2020-04-20 22:05:35 +07:00
commit c0e3c634fa
94 changed files with 2237 additions and 0 deletions

4 binary files not shown.

16
convert.py Normal file

@@ -0,0 +1,16 @@
from os import path
from pydub import AudioSegment
import soundfile as sf
import librosa
import numpy as np
# file paths
src = "Tap4_baLan_23m47_23m54_tucgian_nhieu.mp3"
dst = "test.wav"
# convert mp3 to wav (uncomment to run the pydub conversion)
#sound = AudioSegment.from_mp3(src)
#sound.export(dst, format="wav")
s, r = librosa.load(src, sr=None, dtype='int16')  # load at the native sample rate
print(len(s), r)
print(np.mean(s))
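For reference, a minimal sketch of the mp3-to-wav conversion the commented lines above perform, assuming pydub with ffmpeg available; the 16 kHz mono target is an assumption matching the sampling rate used elsewhere in this commit:

from pydub import AudioSegment

sound = AudioSegment.from_mp3("Tap4_baLan_23m47_23m54_tucgian_nhieu.mp3")  # decode mp3 (needs ffmpeg)
sound = sound.set_frame_rate(16000).set_channels(1)  # 16 kHz mono, as the models here expect
sound.export("test.wav", format="wav")               # write PCM wav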

1
data/info.txt Normal file

@@ -0,0 +1 @@
Statistics are stored here, as well as the training lists.

BIN
data/stats.m Normal file

Binary file not shown.

BIN
data/stats.p Normal file

Binary file not shown.

Binary file not shown.

46
deepxi.py Normal file

@@ -0,0 +1,46 @@
## FILE: deepxi.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: 'DeepXi' training and testing.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
import os, sys
sys.path.insert(0, 'lib')
from dev.args import add_args, get_args
from dev.infer import infer
from dev.sample_stats import get_stats
from dev.train import train
import dev.deepxi_net as deepxi_net
import numpy as np
import tensorflow as tf
import dev.utils as utils
np.set_printoptions(threshold=1e6)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
if __name__ == '__main__':
## GET COMMAND LINE ARGUMENTS
args = get_args()
## TRAINING AND TESTING SET ARGUMENTS
args = add_args(args)
## GPU CONFIGURATION
config = utils.gpu_config(args.gpu)
## GET STATISTICS
args = get_stats(args.data_path, args, config)
# print(args) # debug: inspect the parsed arguments.
# exit() # debug halt; left commented out so that training and inference below can run.
## MAKE DEEP XI NNET
net = deepxi_net.deepxi_net(args)
with tf.Session(config=config) as sess:
if args.train: train(sess, net, args)
if args.infer: infer(sess, net, args)

46
flask_server.py Normal file

@@ -0,0 +1,46 @@
import os
import requests
from flask import Flask, escape, request, render_template, jsonify, make_response, session
#from utils import cvtToWavMono16, split
import random
from infer_file import get_model
from flask_cors import CORS
app = Flask(__name__)
CORS(app)
infer = get_model()
@app.route('/')
def index():
return render_template('index.html')
@app.route('/upload', methods=['POST','GET'])
def upload():
if request.method == 'POST':
file = request.files['file']
filename = file.filename
print(filename)
if os.path.splitext(filename)[1][1:].strip() not in ['mp3','wav','flac']:
return render_template('index.html', filename='{}: file type not supported. Select mp3, wav or flac!'.format(filename))
file_path = 'static/upload/' + filename
file.save(file_path)
print('saved file: {}'.format(file_path))
res = make_response(jsonify({"file_path": file_path, "message": "Saved: {} to server".format(filename)}))
return res
return render_template('index.html')
@app.route('/predict/<file_path>')
def predict(file_path):
print(file_path)
file_path = file_path.replace('=','/')
out_file_path = infer(file_path)
print('predict done!!')
res = make_response(jsonify({"out_file_path":out_file_path, "message": "Predict susscess!"}))
return res
if __name__ == "__main__":
app.debug = True
app.secret_key = 'dangvansam'
#app.run(host='192.168.1.254', port='9002')
app.run(port='8080')
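For reference, a minimal client-side sketch of the two routes above using requests (already imported by the server). The host/port follow app.run() above and the test file name is an assumption; note that predict() decodes '=' back into '/', so the client encodes the saved path accordingly:

import requests

base = 'http://127.0.0.1:8080'  # assumed local address; see app.run() above
with open('test.wav', 'rb') as f:  # hypothetical input file
    r = requests.post(base + '/upload', files={'file': f})
file_path = r.json()['file_path']  # e.g. static/upload/test.wav
r = requests.get(base + '/predict/' + file_path.replace('/', '='))  # '=' stands in for '/'
print(r.json()['out_file_path'])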

178
infer_file.py Normal file

@@ -0,0 +1,178 @@
import os, sys
sys.path.insert(0, 'lib')
from dev.infer import infer
from dev.sample_stats import get_stats
from dev.train import train
import dev.deepxi_net as deepxi_net
import numpy as np
import tensorflow as tf
import dev.utils as utils
import argparse
from dev.utils import read_wav
from tqdm import tqdm
import dev.gain as gain
import dev.utils as utils
import dev.xi as xi
import numpy as np
import os
import scipy.io as spio
import librosa
import pickle
from scipy.io.wavfile import write as wav_write
np.set_printoptions(threshold=1e6)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
def str2bool(s): return s.lower() in ("yes", "true", "t", "1")
def create_args():
with open('data/stats.p', 'rb') as f:
stats = pickle.load(f)
parser = argparse.ArgumentParser()
## OPTIONS (GENERAL)
parser.add_argument('--gpu', default='0', type=str, help='GPU selection')
parser.add_argument('--ver', type=str, help='Model version')
parser.add_argument('--epoch', type=int, help='Epoch to use/retrain from')
parser.add_argument('--train', default=False, type=str2bool, help='Training flag')
parser.add_argument('--infer', default=True, type=str2bool, help='Inference flag')
parser.add_argument('--verbose', default=False, type=str2bool, help='Verbose')
parser.add_argument('--model', default='ResNet', type=str, help='Model type')
## OPTIONS (TRAIN)
parser.add_argument('--cont', default=False, type=str2bool, help='Continue training from the last epoch')
parser.add_argument('--mbatch_size', default=10, type=int, help='Mini-batch size')
parser.add_argument('--sample_size', default=1000, type=int, help='Sample size')
parser.add_argument('--max_epochs', default=250, type=int, help='Maximum number of epochs')
parser.add_argument('--grad_clip', default=True, type=str2bool, help='Gradient clipping')
parser.add_argument('--out_type', default='y', type=str, help='Output type for testing')
## GAIN FUNCTION
parser.add_argument('--gain', default='mmse-lsa', type=str, help='Gain function for testing')
## PATHS
parser.add_argument('--model_path', default='model/3f/epoch-175', type=str, help='Model save path')
parser.add_argument('--set_path', default='set', type=str, help='Path to datasets')
parser.add_argument('--data_path', default='data', type=str, help='Save data path')
parser.add_argument('--test_x_path', default='set/test_noisy_speech', type=str, help='Path to the noisy speech test set')
parser.add_argument('--in_filepath', default='test.wav', type=str, help='Input file path')
parser.add_argument('--out_filepath', default='out.wav', type=str, help='Output path')
## FEATURES
parser.add_argument('--min_snr', default=-10, type=int, help='Minimum trained SNR level')
parser.add_argument('--max_snr', default=20, type=int, help='Maximum trained SNR level')
parser.add_argument('--f_s', default=16000, type=int, help='Sampling frequency (Hz)')
parser.add_argument('--T_w', default=32, type=int, help='Window length (ms)')
parser.add_argument('--T_s', default=16, type=int, help='Window shift (ms)')
parser.add_argument('--nconst', default=32768.0, type=float, help='Normalisation constant (see feat.addnoisepad())')
parser.add_argument('--N_w', default=int(16000*32*0.001), type=int, help='window length (samples)')
parser.add_argument('--N_s', default=int(16000*16*0.001), type=int, help='window shift (samples)')
parser.add_argument('--NFFT', default=int(pow(2, np.ceil(np.log2(int(16000*32*0.001))))), type=int, help='number of DFT components')
parser.add_argument('--stats', default=stats)
## NETWORK PARAMETERS
parser.add_argument('--d_in', default=257, type=int, help='Input dimensionality')
parser.add_argument('--d_out', default=257, type=int, help='Output dimensionality')
parser.add_argument('--d_model', default=256, type=int, help='Model dimensions')
parser.add_argument('--n_blocks', default=40, type=int, help='Number of blocks')
parser.add_argument('--d_f', default=64, type=int, help='Number of filters')
parser.add_argument('--k_size', default=3, type=int, help='Kernel size')
parser.add_argument('--max_d_rate', default=16, type=int, help='Maximum dilation rate')
parser.add_argument('--norm_type', default='FrameLayerNorm', type=str, help='Normalisation type')
parser.add_argument('--net_height', default=[4], type=list, help='RDL block height')
args = parser.parse_args()
return args
def build_restore_model(model_path, args, config):
## MAKE DEEP XI NNET
print('Start: Build and Restore model!')
sess = tf.Session(config=config)
net = deepxi_net.deepxi_net(args)
net.saver.restore(sess, args.model_path)
print('Done: Build and Restore model!')
return net, sess
def get_model():
args = create_args()
## GPU CONFIGURATION
config = tf.ConfigProto()
config.allow_soft_placement=True
config.gpu_options.allow_growth=True
config.log_device_placement=False
net, sess = build_restore_model(args.model_path, args, config)
def infer(filename):
print('Start infer file: {}'.format(filename))
#(wav, _) = read_wav(args.in_filepath) # read wav from given file path.
(wav, _) = librosa.load(filename, 16000, mono=True) # read wav from given file path.
wav = np.asarray(np.multiply(wav, 32768.0), dtype=np.int16)
print(max(wav), min(wav), np.mean(wav))
print(wav.shape)
input_feat = sess.run(net.infer_feat, feed_dict={net.s_ph: [wav], net.s_len_ph: [len(wav)]}) # sample of training set.
xi_bar_hat = sess.run(
net.infer_output, feed_dict={net.input_ph: input_feat[0],
net.nframes_ph: input_feat[1], net.training_ph: False}) # output of network.
xi_hat = xi.xi_hat(xi_bar_hat, args.stats['mu_hat'], args.stats['sigma_hat'])
#file_name = filename.split('/')[-1].split('.')
y_MAG = np.multiply(input_feat[0], gain.gfunc(xi_hat, xi_hat+1, gtype=args.gain))
y = np.squeeze(sess.run(net.y, feed_dict={net.y_MAG_ph: y_MAG,
net.x_PHA_ph: input_feat[2], net.nframes_ph: input_feat[1], net.training_ph: False})) # output of network.
if np.isnan(y).any(): raise ValueError('NaN values found in enhanced speech.')
if np.isinf(y).any(): raise ValueError('Inf values found in enhanced speech.')
y = np.asarray(np.multiply(y, 32768.0), dtype=np.int16)
out_filepath = filename.replace('.'+filename.split('.')[-1], '_pred.wav')
wav_write(out_filepath, args.f_s, y)
print('Infer out file: {} done'.format(out_filepath))
return out_filepath
return infer
if __name__ == '__main__':
infer = get_model()
infer('Toàn cảnh phòng chống dịch COVID-19 ngày 18-4-2020 - VTV24.mp3')
#infer('set/test_noisy_speech/198853.wav')
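A note on the scaling convention used by infer() above: librosa.load returns floats in [-1, 1], which are rescaled to the int16 range before entering the graph (the feature graph divides by 32768.0 again; see lib/dev/acoustic/feat/polar.py), and the enhanced float output is rescaled back to int16 for wav_write. A minimal numpy sketch of that round trip:

import numpy as np

wav_float = np.array([0.5, -0.25, 0.99], dtype=np.float32)   # librosa.load output range
wav_int16 = np.asarray(wav_float * 32768.0, dtype=np.int16)  # scaling used by infer() above
restored = wav_int16.astype(np.float32) / 32768.0            # what the feature graph recovers
print(wav_int16, restored)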

58
lib/dev/ResLSTM.py Normal file

@@ -0,0 +1,58 @@
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
import tensorflow as tf
from tensorflow.python.training import moving_averages
from dev.normalisation import Normalisation
import numpy as np
import argparse, math, sys
def ResLSTMBlock(x, d_model, seq_len, block, parallel_iterations=1024):
'''
ResLSTM block.
Input/s:
x - input to block.
d_model - cell size.
seq_len - sequence length.
block - block number.
parallel_iterations - number of parallel iterations.
Output/s:
output of residual block.
'''
with tf.variable_scope( 'block_' + str(block)):
cell = tf.contrib.rnn.LSTMCell(d_model)
# activation, _ = tf.nn.static_rnn(cell, [tf.expand_dims(x[0,0,:],0)]*seq_len[0], dtype=tf.float32)
activation, _ = tf.nn.dynamic_rnn(cell, x, seq_len, swap_memory=True,
parallel_iterations=parallel_iterations, dtype=tf.float32)
return tf.add(x, activation)
def ResLSTM(x, seq_len, norm_type, training=None, d_out=257,
n_blocks=5, d_model=512, out_layer=True, boolean_mask=False):
'''
ResLSTM network.
Input/s:
x - input.
seq_len - length of each sequence.
norm_type - normalisation type.
training - training flag.
d_out - output dimensions.
n_blocks - number of residual blocks.
d_model - cell size.
out_layer - add an output layer.
boolean_mask - convert padded 3D output to unpadded and stacked 2D output.
Output/s:
unactivated output of ResLSTM.
'''
blocks = [tf.nn.relu(Normalisation(tf.layers.dense(x,
d_model, use_bias=False), norm_type, seq_len))] # (W -> Norm -> ReLU).
for i in range(n_blocks): blocks.append(ResLSTMBlock(blocks[-1], d_model, seq_len, i))
if boolean_mask: blocks[-1] = tf.boolean_mask(blocks[-1], tf.sequence_mask(seq_len))
return tf.layers.dense(blocks[-1], d_out, use_bias=True)
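A minimal TF 1.x shape sketch for ResLSTM, assuming the batch-major [batch, time, features] layout used throughout this commit:

import tensorflow as tf
from dev.ResLSTM import ResLSTM

x_ph = tf.placeholder(tf.float32, [None, None, 257])  # [batch, frames, d_in]
len_ph = tf.placeholder(tf.int32, [None])             # frames per sequence
logits = ResLSTM(x_ph, len_ph, 'FrameLayerNorm', d_out=257, n_blocks=5, d_model=512)
# logits: unactivated output, [batch, frames, 257] (stacked 2D if boolean_mask=True)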

89
lib/dev/ResNet.py Normal file

@@ -0,0 +1,89 @@
## FILE: ResNet.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: ResNet with bottleneck blocks and 1D causal dilated convolutional units.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
import tensorflow as tf
from tensorflow.python.training import moving_averages
from dev.normalisation import Normalisation
import numpy as np
import argparse, math, sys
from dev.add_noise import add_noise_batch
def CausalDilatedConv1d(x, d_f, k_size, d_rate=1, use_bias=True):
'''
1D Causal dilated convolutional unit.
Input/s:
x - input.
d_f - filter dimensions.
k_size - kernel dimensions.
d_rate - dilation rate.
use_bias - include use bias vector.
Output/s:
output of convolutional unit.
'''
if k_size > 1: # padding for causality.
x_shape = tf.shape(x)
x = tf.concat([tf.zeros([x_shape[0], (k_size - 1)*d_rate, x_shape[2]]), x], 1)
return tf.layers.conv1d(x, d_f, k_size, dilation_rate=d_rate,
activation=None, padding='valid', use_bias=use_bias)
def BottleneckBlock(x, norm_type, seq_len, d_model, d_f, k_size, d_rate):
'''
Bottleneck block with causal dilated convolutional
units, and normalisation.
Input/s:
x - input to block.
norm_type - normalisation type.
seq_len - length of each sequence.
d_model - model dimensions.
d_f - filter dimensions.
k_size - kernel dimensions.
d_rate - dilation rate.
Output/s:
output of residual block.
'''
layer_1 = CausalDilatedConv1d(tf.nn.relu(Normalisation(x, norm_type, seq_len=seq_len)), d_f, 1, 1, False)
layer_2 = CausalDilatedConv1d(tf.nn.relu(Normalisation(layer_1, norm_type, seq_len=seq_len)), d_f, k_size, d_rate, False)
layer_3 = CausalDilatedConv1d(tf.nn.relu(Normalisation(layer_2, norm_type, seq_len=seq_len)), d_model, 1, 1, True)
return tf.add(x, layer_3)
def ResNet(x, seq_len, norm_type, training=None, d_out=257,
n_blocks=40, d_model=256, d_f=64, k_size=3, max_d_rate=16, out_layer=True, boolean_mask=False):
'''
ResNet with bottleneck blocks, causal dilated convolutional
units, and normalisation. Dilation resets after
exceeding 'max_d_rate'.
Input/s:
x - input to ResNet.
norm_type - normalisation type.
seq_len - length of each sequence.
training - training flag.
d_out - output dimensions.
n_blocks - number of residual blocks.
d_model - model dimensions.
d_f - filter dimensions.
k_size - kernel dimensions.
max_d_rate - maximum dilation rate.
Output/s:
unactivated output of ResNet.
'''
# mask = tf.cast(tf.expand_dims(tf.sequence_mask(seq_len), 2), tf.float32) # convert mask to float.
blocks = [tf.nn.relu(Normalisation(tf.layers.dense(x, d_model, use_bias=False), norm_type, seq_len=seq_len))] # (W -> Norm -> ReLU).
for i in range(n_blocks): blocks.append(BottleneckBlock(blocks[-1], norm_type, seq_len,
d_model, d_f, k_size, int(2**(i%(np.log2(max_d_rate)+1)))))
if boolean_mask: blocks[-1] = tf.boolean_mask(blocks[-1], tf.sequence_mask(seq_len))
return tf.layers.dense(blocks[-1], d_out, use_bias=True)
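The dilation schedule above resets after exceeding 'max_d_rate'; a quick sketch of the rate expression used at the BottleneckBlock call site:

import numpy as np

max_d_rate = 16
rates = [int(2**(i % (np.log2(max_d_rate) + 1))) for i in range(10)]
print(rates)  # [1, 2, 4, 8, 16, 1, 2, 4, 8, 16]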

32 binary files not shown.

71
lib/dev/acoustic/analysis_synthesis/polar.py Normal file

@@ -0,0 +1,71 @@
## FILE: polar.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Analysis and synthesis with the polar representation in the acoustic-domain.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
import tensorflow as tf
import functools
from tensorflow.python.ops.signal import window_ops
def analysis(x, N_w, N_s, NFFT, legacy=False):
'''
Polar form acoustic-domain analysis.
Input/s:
x - noisy speech.
N_w - time-domain window length (samples).
N_s - time-domain window shift (samples).
NFFT - acoustic-domain DFT components.
Output/s:
Magnitude and phase spectrums.
'''
if legacy:
## MAGNITUDE & PHASE SPECTRUMS (ACOUSTIC DOMAIN)
x_DFT = tf.signal.stft(x, N_w, N_s, NFFT, pad_end=True)
x_MAG = tf.abs(x_DFT); x_PHA = tf.angle(x_DFT)
return x_MAG, x_PHA
else:
## MAGNITUDE & PHASE SPECTRUMS (ACOUSTIC DOMAIN)
W = functools.partial(window_ops.hamming_window, periodic=False)
x_DFT = tf.signal.stft(x, N_w, N_s, NFFT, window_fn=W, pad_end=True)
x_MAG = tf.abs(x_DFT); x_PHA = tf.angle(x_DFT)
return x_MAG, x_PHA
def synthesis(y_MAG, x_PHA, N_w, N_s, NFFT, legacy=False):
'''
Polar form acoustic-domain synthesis.
Input/s:
y_MAG - modified magnitude spectrum.
x_PHA - unmodified phase spectrum.
N_w - time-domain window length (samples).
N_s - time-domain window shift (samples).
NFFT - acoustic-domain DFT components.
Output/s:
synthesised signal.
'''
if legacy:
## SYNTHESISED SIGNAL
y_DFT = tf.cast(y_MAG, tf.complex64)*tf.exp(1j*tf.cast(x_PHA, tf.complex64))
return tf.signal.inverse_stft(y_DFT, N_w, N_s, NFFT, tf.signal.inverse_stft_window_fn(N_s))
else:
## SYNTHESISED SIGNAL
W = functools.partial(window_ops.hamming_window, periodic=False)
y_DFT = tf.cast(y_MAG, tf.complex64)*tf.exp(1j*tf.cast(x_PHA, tf.complex64))
return tf.signal.inverse_stft(y_DFT, N_w, N_s, NFFT, tf.signal.inverse_stft_window_fn(N_s, W))
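A minimal TF 1.x round-trip sketch of analysis() followed by synthesis() with an unmodified magnitude spectrum; the window settings match the 32 ms / 16 ms frames at 16 kHz used in this commit:

import numpy as np
import tensorflow as tf
from dev.acoustic.analysis_synthesis import polar

N_w, N_s, NFFT = 512, 256, 512                   # 32 ms window, 16 ms shift at 16 kHz
x_ph = tf.placeholder(tf.float32, [None, None])  # [batch, samples]
x_MAG, x_PHA = polar.analysis(x_ph, N_w, N_s, NFFT)
y = polar.synthesis(x_MAG, x_PHA, N_w, N_s, NFFT)  # near-identity reconstruction
with tf.Session() as sess:
    print(sess.run(y, feed_dict={x_ph: np.random.randn(1, 16000)}).shape)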

2 binary files not shown.

125
lib/dev/acoustic/feat/polar.py Normal file

@@ -0,0 +1,125 @@
## FILE: polar.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: Feature and target generation for polar representation in the acoustic-domain.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
from dev.acoustic.analysis_synthesis import polar
from dev.add_noise import add_noise_batch
from dev.num_frames import num_frames
from dev.utils import log10
import dev.xi as xi
import tensorflow as tf
def input(z, z_len, N_w, N_s, NFFT, f_s):
'''
Input features for polar form acoustic-domain.
Input/s:
z - speech (dtype=tf.int32).
z_len - speech length without padding (samples).
N_w - time-domain window length (samples).
N_s - time-domain window shift (samples).
NFFT - number of acoustic-domain DFT components.
f_s - sampling frequency (Hz).
Output/s:
z_MAG - speech magnitude spectrum.
L - number of time-domain frames for each sequence.
z_PHA - speech phase spectrum.
'''
z = tf.truediv(tf.cast(z, tf.float32), 32768.0)
L = num_frames(z_len, N_s)
z_MAG, z_PHA = polar.analysis(z, N_w, N_s, NFFT)
return z_MAG, L, z_PHA
def input_target_spec(s, d, s_len, d_len, SNR, N_w, N_s, NFFT, f_s):
'''
Input features and target (spectrum) for polar form acoustic-domain.
Inputs:
s - clean speech (dtype=tf.int32).
d - noise (dtype=tf.int32).
s_len - clean speech length without padding (samples).
d_len - noise length without padding (samples).
SNR - SNR level.
N_w - time-domain window length (samples).
N_s - time-domain window shift (samples).
NFFT - number of acoustic-domain DFT components.
f_s - sampling frequency (Hz).
Outputs:
x_MAG - noisy speech magnitude spectrum.
s_MAG - clean speech magnitude spectrum (target).
L - number of time-domain frames for each sequence.
'''
(x, s, _) = add_noise_batch(s, d, s_len, d_len, SNR)
L = num_frames(s_len, N_s) # number of time-domain frames for each sequence (uppercase eta).
x_MAG, _ = polar.analysis(x, N_w, N_s, NFFT)
s_MAG, _ = polar.analysis(s, N_w, N_s, NFFT)
s_MAG = tf.boolean_mask(s_MAG, tf.sequence_mask(L))
return x_MAG, s_MAG, L
def input_target_xi(s, d, s_len, d_len, SNR, N_w, N_s, NFFT, f_s, mu, sigma):
'''
Input features and target (mapped a priori SNR) for polar form acoustic-domain.
Inputs:
s - clean speech (dtype=tf.int32).
d - noise (dtype=tf.int32).
s_len - clean speech length without padding (samples).
d_len - noise length without padding (samples).
SNR - SNR level.
N_w - time-domain window length (samples).
N_s - time-domain window shift (samples).
NFFT - number of acoustic-domain DFT components.
f_s - sampling frequency (Hz).
mu - sample mean.
sigma - sample standard deviation.
Outputs:
x_MAG - noisy speech magnitude spectrum.
xi_mapped - mapped a priori SNR (target).
L - number of time-domain frames for each sequence.
'''
(x, s, d) = add_noise_batch(s, d, s_len, d_len, SNR)
L = num_frames(s_len, N_s) # number of acoustic-domain frames for each sequence (uppercase eta).
x_MAG, _ = polar.analysis(x, N_w, N_s, NFFT)
s_MAG, _ = polar.analysis(s, N_w, N_s, NFFT)
s_MAG = tf.boolean_mask(s_MAG, tf.sequence_mask(L))
d_MAG, _ = polar.analysis(d, N_w, N_s, NFFT)
d_MAG = tf.boolean_mask(d_MAG, tf.sequence_mask(L))
xi_bar = xi.xi_bar(s_MAG, d_MAG, mu, sigma)
return x_MAG, xi_bar, L
def target_xi(s, d, s_len, d_len, SNR, N_w, N_s, NFFT, f_s):
'''
Target (a priori SNR) for polar form acoustic-domain.
Inputs:
s - clean speech (dtype=tf.int32).
d - noise (dtype=tf.int32).
s_len - clean speech length without padding (samples).
d_len - noise length without padding (samples).
SNR - SNR level.
N_w - time-domain window length (samples).
N_s - time-domain window shift (samples).
NFFT - number of acoustic-domain DFT components.
f_s - sampling frequency (Hz).
Outputs:
xi_dB - a priori SNR in dB (target).
L - number of time-domain frames for each sequence.
'''
(_, s, d) = add_noise_batch(s, d, s_len, d_len, SNR)
L = num_frames(s_len, N_s) # number of acoustic-domain frames for each sequence (uppercase eta).
s_MAG, _ = polar.analysis(s, N_w, N_s, NFFT)
d_MAG, _ = polar.analysis(d, N_w, N_s, NFFT)
s_MAG = tf.boolean_mask(s_MAG, tf.sequence_mask(L))
d_MAG = tf.boolean_mask(d_MAG, tf.sequence_mask(L))
xi_dB = xi.xi_dB(s_MAG, d_MAG)
return xi_dB, L
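dev.xi itself is not shown in this commit, but the targets above follow the Deep Xi convention: the instantaneous a priori SNR is the per-bin ratio of clean to noise spectral power, and xi_bar maps its dB value through a Gaussian CDF using the sample statistics mu, sigma from lib/dev/sample_stats.py. A hedged numpy sketch of that mapping:

import numpy as np
from scipy.special import erf

def xi_dB(s_MAG, d_MAG):
    # instantaneous a priori SNR in dB: 10*log10(|S|^2 / |D|^2).
    return 10.0 * np.log10(np.square(s_MAG) / np.square(d_MAG))

def xi_bar(s_MAG, d_MAG, mu, sigma):
    # mapped a priori SNR in [0, 1] via the Gaussian CDF (the training target).
    return 0.5 * (1.0 + erf((xi_dB(s_MAG, d_MAG) - mu) / (sigma * np.sqrt(2.0))))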

79
lib/dev/add_noise.py Normal file

@@ -0,0 +1,79 @@
## FILE: add_noise.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: Add noise to clean speech at set SNR level.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
import tensorflow as tf
def add_noise_batch(s, d, s_len, d_len, SNR):
'''
Creates noisy speech batch from clean speech, noise, and SNR batches.
Input/s:
s - clean waveforms (dtype=tf.int32).
d - noise waveforms (dtype=tf.int32).
s_len - clean waveform lengths without padding (samples).
d_len - noise waveform lengths without padding (samples).
SNR - SNR levels.
Output/s:
tuple consisting of noisy speech, clean speech, and noise (x, s, d).
'''
return tf.map_fn(lambda z: add_noise_pad(z[0], z[1], z[2], z[3], z[4],
tf.reduce_max(s_len)), (s, d, s_len, d_len, SNR), dtype=(tf.float32, tf.float32,
tf.float32))
def add_noise_pad(s, d, s_len, d_len, SNR, P):
'''
Calls add_noise() and pads the waveforms to the length given by P.
Also normalises the waveforms.
Inputs:
s - clean speech waveform.
d - noise waveform.
s_len - length of s.
d_len - length of d.
SNR - SNR level.
P - padded length.
Outputs:
x - padded noisy speech waveform.
s - padded clean speech waveform.
d - truncated, scaled, and padded noise waveform.
'''
s = tf.truediv(tf.cast(tf.slice(s, [0], [s_len]), tf.float32), 32768.0)
d = tf.truediv(tf.cast(tf.slice(d, [0], [d_len]), tf.float32), 32768.0)
(x, d) = add_noise(s, d, SNR)
total_zeros = tf.subtract(P, tf.shape(s)[0])
x = tf.pad(x, [[0, total_zeros]], "CONSTANT")
s = tf.pad(s, [[0, total_zeros]], "CONSTANT")
d = tf.pad(d, [[0, total_zeros]], "CONSTANT")
return (x, s, d)
def add_noise(s, d, SNR):
'''
Adds noise to the clean waveform at a specific SNR value. A random section
of the noise waveform is used.
Inputs:
s - clean waveform.
d - noise waveform.
SNR - SNR level.
Outputs:
x - noisy speech waveform.
d - truncated and scaled noise waveform.
'''
s_len = tf.shape(s)[0]
d_len = tf.shape(d)[0]
i = tf.random_uniform([1], 0, tf.add(1, tf.subtract(d_len, s_len)), tf.int32)
d = tf.slice(d, [i[0]], [s_len])
d = tf.multiply(tf.truediv(d, tf.norm(d)), tf.truediv(tf.norm(s),
tf.pow(10.0, tf.multiply(0.05, SNR))))
x = tf.add(s, d)
return (x, d)
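A numpy sketch verifying the scaling rule in add_noise() above: dividing the noise by its norm and multiplying by ||s|| / 10^(SNR/20) makes the resulting SNR equal the requested level.

import numpy as np

s, d, SNR = np.random.randn(16000), np.random.randn(16000), 5.0
d = (d / np.linalg.norm(d)) * (np.linalg.norm(s) / 10.0**(0.05 * SNR))
print(10.0 * np.log10(np.sum(s**2) / np.sum(d**2)))  # ~5.0 dB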

109
lib/dev/args.py Normal file

@@ -0,0 +1,109 @@
## FILE: args.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Get command line arguments.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
import argparse
import numpy as np
import os
from dev.se_batch import Batch_list, Batch
from os.path import expanduser
## ADD ADDITIONAL ARGUMENTS
def add_args(args, modulation=False):
## DEPENDENT OPTIONS
args.model_path = args.model_path + '/' + args.ver # model save path.
args.train_s_path = args.set_path + '/train_clean_speech' # path to the clean speech training set.
args.train_d_path = args.set_path + '/train_noise' # path to the noise training set.
args.val_s_path = args.set_path + '/val_clean_speech' # path to the clean speech validation set.
args.val_d_path = args.set_path + '/val_noise' # path to the noise validation set.
args.out_path = args.out_path + '/' + args.ver + '/' + 'e' + str(args.epoch) # output path.
args.N_w = int(args.f_s*args.T_w*0.001) # window length (samples).
args.N_s = int(args.f_s*args.T_s*0.001) # window shift (samples).
args.NFFT = int(pow(2, np.ceil(np.log2(args.N_w)))) # number of DFT components.
## DATASETS
if args.train: ## TRAINING AND VALIDATION CLEAN SPEECH AND NOISE SET
args.train_s_list = Batch_list(args.train_s_path, 'clean_speech_' + args.set_path.rsplit('/', 1)[-1], args.data_path) # clean speech training list.
args.train_d_list = Batch_list(args.train_d_path, 'noise_' + args.set_path.rsplit('/', 1)[-1], args.data_path) # noise training list.
if not os.path.exists(args.model_path): os.makedirs(args.model_path) # make model path directory.
args.val_s, args.val_s_len, args.val_snr, _ = Batch(args.val_s_path,
list(range(args.min_snr, args.max_snr + 1))) # clean validation waveforms and lengths.
args.val_d, args.val_d_len, _, _ = Batch(args.val_d_path,
list(range(args.min_snr, args.max_snr + 1))) # noise validation waveforms and lengths.
args.train_steps=int(np.ceil(len(args.train_s_list)/args.mbatch_size))
args.val_steps=int(np.ceil(args.val_s.shape[0]/args.mbatch_size))
## INFERENCE
# if args.infer: args.test_x, args.test_x_len, args.test_snr, args.test_fnames = Batch(args.test_x_path, '*', []) # noisy speech test waveforms and lengths.
if args.infer: args.test_x_list = Batch_list(args.test_x_path, 'test_x', args.data_path, make_new=True)
return args
## STRING TO BOOLEAN
def str2bool(s): return s.lower() in ("yes", "true", "t", "1")
## GET COMMAND LINE ARGUMENTS
def get_args():
parser = argparse.ArgumentParser()
## OPTIONS (GENERAL)
parser.add_argument('--gpu', default='0', type=str, help='GPU selection')
parser.add_argument('--ver', type=str, help='Model version')
parser.add_argument('--epoch', type=int, help='Epoch to use/retrain from')
parser.add_argument('--train', default=False, type=str2bool, help='Training flag')
parser.add_argument('--infer', default=False, type=str2bool, help='Inference flag')
parser.add_argument('--verbose', default=False, type=str2bool, help='Verbose')
parser.add_argument('--model', default='ResNet', type=str, help='Model type')
## OPTIONS (TRAIN)
parser.add_argument('--cont', default=False, type=str2bool, help='Continue training from the last epoch')
parser.add_argument('--mbatch_size', default=10, type=int, help='Mini-batch size')
parser.add_argument('--sample_size', default=1000, type=int, help='Sample size')
parser.add_argument('--max_epochs', default=250, type=int, help='Maximum number of epochs')
parser.add_argument('--grad_clip', default=True, type=str2bool, help='Gradient clipping')
# TEST OUTPUT TYPE
# 'xi_hat' - a priori SNR estimate (.mat),
# 'y' - enhanced speech (.wav).
parser.add_argument('--out_type', default='y', type=str, help='Output type for testing')
## GAIN FUNCTION
# 'ibm' - Ideal Binary Mask (IBM), 'wf' - Wiener Filter (WF), 'srwf' - Square-Root Wiener Filter (SRWF),
# 'cwf' - Constrained Wiener Filter (cWF), 'mmse-stsa' - Minimum-Mean Square Error - Short-Time Spectral Amplitude (MMSE-STSA) estimator,
# 'mmse-lsa' - Minimum-Mean Square Error - Log-Spectral Amplitude (MMSE-LSA) estimator.
parser.add_argument('--gain', default='srwf', type=str, help='Gain function for testing')
## PATHS
parser.add_argument('--model_path', default='model', type=str, help='Model save path')
parser.add_argument('--set_path', default='set', type=str, help='Path to datasets')
parser.add_argument('--data_path', default='data', type=str, help='Save data path')
parser.add_argument('--test_x_path', default='set/test_noisy_speech', type=str, help='Path to the noisy speech test set')
parser.add_argument('--out_path', default='out', type=str, help='Output path')
## FEATURES
parser.add_argument('--min_snr', default=-10, type=int, help='Minimum trained SNR level')
parser.add_argument('--max_snr', default=20, type=int, help='Maximum trained SNR level')
parser.add_argument('--f_s', default=16000, type=int, help='Sampling frequency (Hz)')
parser.add_argument('--T_w', default=32, type=int, help='Window length (ms)')
parser.add_argument('--T_s', default=16, type=int, help='Window shift (ms)')
parser.add_argument('--nconst', default=32768.0, type=float, help='Normalisation constant (see feat.addnoisepad())')
## NETWORK PARAMETERS
parser.add_argument('--d_in', default=257, type=int, help='Input dimensionality')
parser.add_argument('--d_out', default=257, type=int, help='Output dimensionality')
parser.add_argument('--d_model', default=256, type=int, help='Model dimensions')
parser.add_argument('--n_blocks', default=40, type=int, help='Number of blocks')
parser.add_argument('--d_f', default=64, type=int, help='Number of filters')
parser.add_argument('--k_size', default=3, type=int, help='Kernel size')
parser.add_argument('--max_d_rate', default=16, type=int, help='Maximum dilation rate')
parser.add_argument('--norm_type', default='FrameLayerNorm', type=str, help='Normalisation type')
parser.add_argument('--net_height', default=[4], type=list, help='RDL block height')
args = parser.parse_args()
return args
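Hypothetical invocations wiring these flags together via deepxi.py; the version and epoch values mirror the defaults in infer_file.py above:

# python3 deepxi.py --ver 3f --train 1 --mbatch_size 10 --max_epochs 250
# python3 deepxi.py --ver 3f --infer 1 --epoch 175 --out_type y --gain mmse-lsa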

77
lib/dev/deepxi_net.py Normal file

@@ -0,0 +1,77 @@
## FILE: deepxi_net.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: Network employed within the Deep Xi framework.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
## VARIABLE DESCRIPTIONS
# s - clean speech.
# d - noise.
# x - noisy speech.
from dev.acoustic.analysis_synthesis.polar import synthesis
from dev.acoustic.feat import polar
from dev.ResNet import ResNet
import dev.optimisation as optimisation
import numpy as np
import tensorflow as tf
## ARTIFICIAL NEURAL NETWORK
class deepxi_net:
def __init__(self, args):
print('Preparing graph...')
## RESNET
self.input_ph = tf.placeholder(tf.float32, shape=[None, None, args.d_in], name='input_ph') # noisy speech MS placeholder.
self.nframes_ph = tf.placeholder(tf.int32, shape=[None], name='nframes_ph') # noisy speech MS sequence length placeholder.
if args.model == 'ResNet':
self.output = ResNet(self.input_ph, self.nframes_ph, args.norm_type, n_blocks=args.n_blocks, boolean_mask=True, d_out=args.d_out,
d_model=args.d_model, d_f=args.d_f, k_size=args.k_size, max_d_rate=args.max_d_rate)
elif args.model == 'RDLNet':
from dev.RDLNet import RDLNet
self.output = RDLNet(self.input_ph, self.nframes_ph, args.norm_type, n_blocks=args.n_blocks, boolean_mask=True, d_out=args.d_out,
d_f=args.d_f, net_height=args.net_height)
elif args.model == 'ResLSTM':
from dev.ResLSTM import ResLSTM
self.output = ResLSTM(self.input_ph, self.nframes_ph, args.norm_type, n_blocks=args.n_blocks, boolean_mask=True, d_out=args.d_out, d_model=args.d_model)
## TRAINING FEATURE EXTRACTION GRAPH
self.s_ph = tf.placeholder(tf.int16, shape=[None, None], name='s_ph') # clean speech placeholder.
self.d_ph = tf.placeholder(tf.int16, shape=[None, None], name='d_ph') # noise placeholder.
self.s_len_ph = tf.placeholder(tf.int32, shape=[None], name='s_len_ph') # clean speech sequence length placeholder.
self.d_len_ph = tf.placeholder(tf.int32, shape=[None], name='d_len_ph') # noise sequence length placeholder.
self.snr_ph = tf.placeholder(tf.float32, shape=[None], name='snr_ph') # SNR placeholder.
self.train_feat = polar.input_target_xi(self.s_ph, self.d_ph, self.s_len_ph,
self.d_len_ph, self.snr_ph, args.N_w, args.N_s, args.NFFT, args.f_s, args.stats['mu_hat'], args.stats['sigma_hat'])
## INFERENCE FEATURE EXTRACTION GRAPH
self.infer_feat = polar.input(self.s_ph, self.s_len_ph, args.N_w, args.N_s, args.NFFT, args.f_s)
## PLACEHOLDERS
self.x_ph = tf.placeholder(tf.int16, shape=[None, None], name='x_ph') # noisy speech placeholder.
self.x_len_ph = tf.placeholder(tf.int32, shape=[None], name='x_len_ph') # noisy speech sequence length placeholder.
self.target_ph = tf.placeholder(tf.float32, shape=[None, args.d_out], name='target_ph') # training target placeholder.
self.keep_prob_ph = tf.placeholder(tf.float32, name='keep_prob_ph') # keep probability placeholder.
self.training_ph = tf.placeholder(tf.bool, name='training_ph') # training placeholder.
## SYNTHESIS GRAPH
if args.infer:
self.infer_output = tf.nn.sigmoid(self.output)
self.y_MAG_ph = tf.placeholder(tf.float32, shape=[None, None, args.d_in], name='y_MAG_ph')
self.x_PHA_ph = tf.placeholder(tf.float32, [None, None, args.d_in], name='x_PHA_ph')
self.y = synthesis(self.y_MAG_ph, self.x_PHA_ph, args.N_w, args.N_s, args.NFFT)
## LOSS & OPTIMIZER
self.loss = optimisation.loss(self.target_ph, self.output, 'mean_sigmoid_cross_entropy', axis=[1])
self.total_loss = tf.reduce_mean(self.loss, axis=0)
self.trainer, _ = optimisation.optimiser(self.total_loss, optimizer='adam', grad_clip=True)
## SAVE VARIABLES
self.saver = tf.train.Saver(max_to_keep=256)
## NUMBER OF PARAMETERS
args.params = (np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))

129
lib/dev/gain.py Normal file

@@ -0,0 +1,129 @@
## FILE: gain.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: Gain functions and masks for speech enhancement.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
import numpy as np
from scipy.special import exp1, i0, i1
def mmse_stsa(xi, gamma):
'''
Computes the MMSE-STSA gain function.
Input/s:
xi - a priori SNR.
gamma - a posteriori SNR.
Output/s:
G - MMSE-STSA gain function.
'''
nu = np.multiply(xi, np.divide(gamma, np.add(1, xi)))
G = np.multiply(np.multiply(np.multiply(np.divide(np.sqrt(np.pi), 2),
np.divide(np.sqrt(nu), gamma)), np.exp(np.divide(-nu,2))),
np.add(np.multiply(np.add(1, nu), i0(np.divide(nu,2))),
np.multiply(nu, i1(np.divide(nu, 2))))) # MMSE-STSA gain function.
idx = np.isnan(G) | np.isinf(G) # replace by Wiener gain.
G[idx] = np.divide(xi[idx], np.add(1, xi[idx])) # Wiener gain.
return G
def mmse_lsa(xi, gamma):
'''
Computes the MMSE-LSA gain function.
Input/s:
xi - a priori SNR.
gamma - a posteriori SNR.
Output/s:
MMSE-LSA gain function.
'''
nu = np.multiply(np.divide(xi, np.add(1, xi)), gamma)
return np.multiply(np.divide(xi, np.add(1, xi)), np.exp(np.multiply(0.5, exp1(nu)))) # MMSE-LSA gain function.
def wf(xi):
'''
Computes the Wiener filter (WF) gain function.
Input/s:
xi - a priori SNR.
Output/s:
WF gain function.
'''
return np.divide(xi, np.add(xi, 1.0)) # WF gain function.
def srwf(xi):
'''
Computes the square-root Wiener filter (WF) gain function.
Input/s:
xi - a priori SNR.
Output/s:
SRWF gain function.
'''
return np.sqrt(wf(xi)) # SRWF gain function.
def cwf(xi):
'''
Computes the constrained Wiener filter (WF) gain function.
Input/s:
xi - a priori SNR.
Output/s:
cWF gain function.
'''
return wf(np.sqrt(xi)) # cWF gain function.
def irm(xi):
'''
Computes the ideal ratio mask (IRM).
Input/s:
xi - a priori SNR.
Output/s:
IRM.
'''
return srwf(xi) # IRM.
def ibm(xi):
'''
Computes the ideal binary mask (IBM) with a threshold of 0 dB.
Input/s:
xi - a priori SNR.
Output/s:
IBM.
'''
return np.greater(xi, 1.0).astype(np.float32) # IBM (1 corresponds to 0 dB).
def gfunc(xi, gamma=None, gtype='mmse-lsa'):
'''
Computes the selected gain function.
Input/s:
xi - a priori SNR.
gamma - a posteriori SNR.
gtype - gain function type.
Output/s:
G - gain function.
'''
if gtype == 'mmse-lsa': G = mmse_lsa(xi, gamma)
elif gtype == 'mmse-stsa': G = mmse_stsa(xi, gamma)
elif gtype == 'wf': G = wf(xi)
elif gtype == 'srwf': G = srwf(xi)
elif gtype == 'cwf': G = cwf(xi)
elif gtype == 'irm': G = irm(xi)
elif gtype == 'ibm': G = ibm(xi)
else: raise ValueError('Gain function not available.')
return G
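A quick usage sketch of gfunc() with the MMSE-LSA estimator, using the same xi_hat, xi_hat + 1 pairing as the inference code elsewhere in this commit; the arrays are placeholders:

import numpy as np
import dev.gain as gain

xi_hat = np.abs(np.random.randn(3, 257))                 # a priori SNR estimate (linear)
G = gain.gfunc(xi_hat, xi_hat + 1.0, gtype='mmse-lsa')   # gamma approximated as xi + 1
y_MAG = np.multiply(np.abs(np.random.randn(3, 257)), G)  # gain applied to noisy magnitudes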

97
lib/dev/infer.py Normal file

@@ -0,0 +1,97 @@
## FILE: infer.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Inference module.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
from dev.utils import read_wav
from tqdm import tqdm
import dev.gain as gain
import dev.utils as utils
import dev.xi as xi
import numpy as np
import os
import scipy.io as spio
## INFERENCE
def infer(sess, net, args):
print("Inference...", )
print (args.test_x_list)
net.saver.restore(sess, args.model_path + '/epoch-' + str(args.epoch)) # load model from epoch.
if args.out_type == 'xi_hat': args.out_path = args.out_path + '/xi_hat'
elif args.out_type == 'y': args.out_path = args.out_path + '/' + args.gain + '/y'
elif args.out_type == 'ibm_hat': args.out_path = args.out_path + '/ibm_hat'
else: ValueError('Incorrect output type.')
if not os.path.exists(args.out_path): os.makedirs(args.out_path) # make output directory.
for j in tqdm(args.test_x_list):
(wav, _) = read_wav(j['file_path']) # read wav from given file path.
input_feat = sess.run(net.infer_feat, feed_dict={net.s_ph: [wav], net.s_len_ph: [j['seq_len']]}) # sample of training set.
xi_bar_hat = sess.run(net.infer_output, feed_dict={net.input_ph: input_feat[0],
net.nframes_ph: input_feat[1], net.training_ph: False}) # output of network.
xi_hat = xi.xi_hat(xi_bar_hat, args.stats['mu_hat'], args.stats['sigma_hat'])
file_name = j['file_path'].rsplit('/',1)[1].split('.')[0]
if args.out_type == 'xi_hat':
spio.savemat(args.out_path + '/' + file_name + '.mat', {'xi_hat':xi_hat})
elif args.out_type == 'y':
y_MAG = np.multiply(input_feat[0], gain.gfunc(xi_hat, xi_hat+1, gtype=args.gain))
y = np.squeeze(sess.run(net.y, feed_dict={net.y_MAG_ph: y_MAG,
net.x_PHA_ph: input_feat[2], net.nframes_ph: input_feat[1], net.training_ph: False})) # output of network.
if np.isnan(y).any(): raise ValueError('NaN values found in enhanced speech.')
if np.isinf(y).any(): raise ValueError('Inf values found in enhanced speech.')
print (args.out_path + '/' + file_name + '.wav')
utils.save_wav(args.out_path + '/' + file_name + '.wav', args.f_s, y)
elif args.out_type == 'ibm_hat':
ibm_hat = np.greater(xi_hat, 1.0)
spio.savemat(args.out_path + '/' + file_name + '.mat', {'ibm_hat':ibm_hat})
print('Inference complete.')

96
lib/dev/normalisation.py Normal file

@@ -0,0 +1,96 @@
## FILE: normalisation.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Layer/instance/batch normalisation functions.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
from os.path import expanduser
import argparse, os, string
import numpy as np
import tensorflow as tf
def Normalisation(x, norm_type='FrameLayerNorm', seq_len=None, mask=None, training=False, centre=True, scale=True):
'''
Normalisation.
Input/s:
x - unnormalised input.
norm_type - normalisation type.
seq_len - length of each sequence.
mask - sequence mask.
training - training flag.
Output/s:
normalised input.
'''
if norm_type == 'SeqCausalLayerNorm': return SeqCausalLayerNorm(x, seq_len, centre=centre, scale=scale)
elif norm_type == 'FrameLayerNorm': return FrameLayerNorm(x, centre=centre, scale=scale)
elif norm_type == 'unnormalised': return x
else: raise ValueError('Normalisation type does not exist: %s.' % (norm_type))
count = 0
def SeqCausalLayerNorm(x, seq_len, centre=True, scale=True):
'''
Sequence-wise causal layer normalisation with sequence masking (causal layer norm version of https://arxiv.org/pdf/1510.01378.pdf).
Input/s:
x - input.
seq_len - length of each sequence.
centre - centre parameter.
scale - scale parameter.
Output/s:
normalised input.
'''
global count
count += 1
with tf.variable_scope('LayerNorm' + str(count)):
input_size = x.get_shape().as_list()[-1]
mask = tf.cast(tf.sequence_mask(seq_len), tf.float32) # convert mask to float.
den = tf.multiply(tf.range(1.0, tf.add(tf.cast(tf.shape(mask)[-1], tf.float32), 1.0), dtype=tf.float32), input_size)
mu = tf.expand_dims(tf.truediv(tf.cumsum(tf.reduce_sum(x, -1), -1), den), 2)
sigma = tf.expand_dims(tf.truediv(tf.cumsum(tf.reduce_sum(tf.square(tf.subtract(x,
mu)), -1), -1), den),2)
if centre: beta = tf.get_variable("beta", input_size, dtype=tf.float32,
initializer=tf.constant_initializer(0.0), trainable=True)
else: beta = tf.constant(np.zeros(input_size), name="beta", dtype=tf.float32)
if scale: gamma = tf.get_variable("Gamma", input_size, dtype=tf.float32,
initializer=tf.constant_initializer(1.0), trainable=True)
else: gamma = tf.constant(np.ones(input_size), name="Gamma", dtype=tf.float32)
return tf.multiply(tf.nn.batch_normalization(x, mu, sigma, offset=beta, scale=gamma,
variance_epsilon = 1e-12), tf.expand_dims(mask, 2))
count = 0
def FrameLayerNorm(x, centre=True, scale=True):
'''
Frame-wise layer normalisation (layer norm version of https://arxiv.org/pdf/1510.01378.pdf).
Input/s:
x - input.
centre - centre parameter.
scale - scale parameter.
Output/s:
normalised input.
'''
global count
count += 1
with tf.variable_scope('frm_wise_layer_norm' + str(count)):
mu, sigma = tf.nn.moments(x, -1, keepdims=True)
input_size = x.get_shape().as_list()[-1] # get number of input dimensions.
if centre:
beta = tf.get_variable("beta", input_size, dtype=tf.float32,
initializer=tf.constant_initializer(0.0), trainable=True)
else: beta = tf.constant(np.zeros(input_size), name="beta", dtype=tf.float32)
if scale:
gamma = tf.get_variable("Gamma", input_size, dtype=tf.float32,
initializer=tf.constant_initializer(1.0), trainable=True)
else: gamma = tf.constant(np.ones(input_size), name="Gamma", dtype=tf.float32)
return tf.nn.batch_normalization(x, mu, sigma, offset=beta, scale=gamma,
variance_epsilon = 1e-12) # normalise batch.
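A numpy sketch of what FrameLayerNorm computes with beta = 0 and gamma = 1: each frame is normalised across the feature dimension only.

import numpy as np

x = np.random.randn(2, 5, 257)        # [batch, frames, features]
mu = x.mean(-1, keepdims=True)
var = x.var(-1, keepdims=True)
y = (x - mu) / np.sqrt(var + 1e-12)   # mirrors tf.nn.batch_normalization above
print(np.abs(y.mean(-1)).max(), y.std(-1).mean())  # ~0 and ~1 per frame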

43
lib/dev/num_frames.py Normal file

@@ -0,0 +1,43 @@
## FILE: num_frames.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Determines the number of frames in a signal.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
import tensorflow as tf
def num_frames(N, N_s):
'''
Returns the number of frames for a given sequence length, and
frame shift.
Inputs:
N - sequence length (samples).
N_s - frame shift (samples).
Output:
number of frames
'''
return tf.cast(tf.ceil(tf.truediv(tf.cast(N, tf.float32),tf.cast(N_s, tf.float32))), tf.int32) # number of frames.
def acou_num_frames(N, N_s, K_s):
'''
Returns the number of acoustic-domain frames for a given sequence length, and
frame shift.
Inputs:
N - time-domain sequence length (samples).
N_s - time-domain frame shift (samples).
K_s - acoustic-domain frame shift (samples).
Output:
number of acoustic-domain frames
'''
N = tf.cast(N, tf.float32)
N_s = tf.cast(N_s, tf.float32)
K_s = tf.cast(K_s, tf.float32)
return tf.cast(tf.ceil(tf.truediv(tf.truediv(N, N_s), K_s)), tf.int32) # number of frames.
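A worked example of num_frames() for the settings used in this commit (16 kHz audio with a 16 ms shift, so N_s = 256 samples):

import numpy as np

N, N_s = 16000, 256           # 1 s of audio, 16 ms frame shift at 16 kHz
print(int(np.ceil(N / N_s)))  # 63 frames, matching num_frames(N, N_s)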

50
lib/dev/optimisation.py Normal file

@@ -0,0 +1,50 @@
## FILE: optimisation.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Loss functions and algorithms for gradient descent optimisation.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
import tensorflow as tf
## LOSS FUNCTIONS
def loss(target, estimate, loss_fnc, axis=1):
'loss functions for gradient descent.'
with tf.name_scope(loss_fnc + '_loss'):
if loss_fnc == 'quadratic':
loss = tf.reduce_sum(tf.square(tf.subtract(target, estimate)), axis=axis)
if loss_fnc == 'sigmoid_cross_entropy':
loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=estimate), axis=axis)
if loss_fnc == 'mean_quadratic':
loss = tf.reduce_mean(tf.square(tf.subtract(target, estimate)), axis=axis)
if loss_fnc == 'mean_sigmoid_cross_entropy':
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=estimate), axis=axis)
return loss
## GRADIENT DESCENT OPTIMISERS
def optimiser(loss, lr=None, epsilon=None, var_list=None, optimizer='adam', grad_clip=False):
'optimizers for training.'
with tf.name_scope(optimizer + '_opt'):
if optimizer == 'adam':
if lr == None: lr = 0.001 # default.
if epsilon == None: epsilon = 1e-8 # default.
optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=epsilon)
if optimizer == 'lazyadam':
if lr == None: lr = 0.001 # default.
if epsilon == None: epsilon = 1e-8 # default.
optimizer = tf.contrib.opt.LazyAdamOptimizer(learning_rate=lr, epsilon=epsilon)
if optimizer == 'nadam':
if lr == None: lr = 0.001 # default.
if epsilon == None: epsilon = 1e-8 # default.
optimizer = tf.contrib.opt.NadamOptimizer(learning_rate=lr, epsilon=epsilon)
if optimizer == 'sgd':
if lr == None: lr = 0.5 # default.
optimizer = tf.train.GradientDescentOptimizer(lr)
grads_and_vars = optimizer.compute_gradients(loss, var_list=var_list)
if grad_clip: grads_and_vars = [(tf.clip_by_value(gv[0], -1., 1.), gv[1]) for gv in grads_and_vars]
trainer = optimizer.apply_gradients(grads_and_vars)
return trainer, optimizer
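A minimal TF 1.x sketch wiring loss() and optimiser() together, mirroring their use in lib/dev/deepxi_net.py; the dense layer is a stand-in for the real network:

import tensorflow as tf
import dev.optimisation as optimisation

target_ph = tf.placeholder(tf.float32, [None, 257])
logits = tf.layers.dense(target_ph, 257)  # stand-in for the network output
loss = optimisation.loss(target_ph, logits, 'mean_sigmoid_cross_entropy', axis=[1])
total_loss = tf.reduce_mean(loss, axis=0)
trainer, _ = optimisation.optimiser(total_loss, optimizer='adam', grad_clip=True)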

57
lib/dev/sample_stats.py Normal file

@@ -0,0 +1,57 @@
## FILE: sample_stats.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: Get statistics from sample of the training set.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
import numpy as np
import tensorflow as tf
import os, pickle, random
import dev.se_batch as batch
from dev.acoustic.feat import polar
from tqdm import tqdm
import scipy.io as spio
## GET STATISTICS OF SAMPLE
def get_stats(stats_path, args, config):
if os.path.exists(stats_path + '/stats.p'):
print('Loading sample statistics from pickle file...')
with open(stats_path + '/stats.p', 'rb') as f:
args.stats = pickle.load(f)
return args
elif args.infer:
raise ValueError('You have not completed training (no stats.p file exists). In the Deep Xi github repository, data/stats.p is available.')
else:
print('Finding sample statistics...')
random.shuffle(args.train_s_list) # shuffle list.
s_sample, s_sample_seq_len = batch.Clean_mbatch(args.train_s_list,
args.sample_size, 0, args.sample_size) # generate mini-batch of clean training waveforms.
d_sample, d_sample_seq_len = batch.Noise_mbatch(args.train_d_list,
args.sample_size, s_sample_seq_len) # generate mini-batch of noise training waveforms.
snr_sample = np.random.randint(args.min_snr, args.max_snr + 1, args.sample_size) # generate mini-batch of SNR levels.
s_ph = tf.placeholder(tf.int16, shape=[None, None], name='s_ph') # clean speech placeholder.
d_ph = tf.placeholder(tf.int16, shape=[None, None], name='d_ph') # noise placeholder.
s_len_ph = tf.placeholder(tf.int32, shape=[None], name='s_len_ph') # clean speech sequence length placeholder.
d_len_ph = tf.placeholder(tf.int32, shape=[None], name='d_len_ph') # noise sequence length placeholder.
snr_ph = tf.placeholder(tf.float32, shape=[None], name='snr_ph') # SNR placeholder.
analysis = polar.target_xi(s_ph, d_ph, s_len_ph, d_len_ph, snr_ph, args.N_w, args.N_s, args.NFFT, args.f_s)
sample_graph = analysis[0]
samples = []
with tf.Session(config=config) as sess:
for i in tqdm(range(s_sample.shape[0])):
sample = sess.run(sample_graph, feed_dict={s_ph: [s_sample[i]], d_ph: [d_sample[i]], s_len_ph: [s_sample_seq_len[i]],
d_len_ph: [d_sample_seq_len[i]], snr_ph: [snr_sample[i]]}) # sample of training set.
samples.append(sample)
samples = np.vstack(samples)
if len(samples.shape) != 2: raise ValueError('Incorrect shape for sample.')
args.stats = {'mu_hat': np.mean(samples, axis=0), 'sigma_hat': np.std(samples, axis=0)}
if not os.path.exists(stats_path): os.makedirs(stats_path) # make directory.
with open(stats_path + '/stats.p', 'wb') as f:
pickle.dump(args.stats, f)
spio.savemat(stats_path + '/stats.m', mdict={'mu_hat': args.stats['mu_hat'], 'sigma_hat': args.stats['sigma_hat']})
print('Sample statistics saved to pickle file.')
return args

155
lib/dev/se_batch.py Normal file

@@ -0,0 +1,155 @@
## FILE: se_batch.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Generates mini-batches, and creates training and test lists for speech enhancement.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
import contextlib, glob, os, pickle, platform, random, sys, wave
import numpy as np
from dev.utils import read_wav
from scipy.io.wavfile import read
def Batch_list(file_dir, list_name, data_path=None, make_new=False):
'''
Places the file paths and wav lengths of an audio file into a dictionary, which
is then appended to a list. SPHERE format cannot be used. 'glob' is used to
support Unix style pathname pattern expansions. Checks if the training list
has already been pickled, and loads it. If a different dataset is to be
used, delete the pickle file.
Inputs:
file_dir - directory containing the wavs.
list_name - name for the list.
data_path - path to store pickle files.
make_new - re-create list.
Outputs:
batch_list - list of file paths and wav length.
'''
from soundfile import SoundFile, SEEK_END
file_name = ['*.wav', '*.flac', '*.mp3']
if data_path == None: data_path = 'data'
if not make_new:
if os.path.exists(data_path + '/' + list_name + '_list_' + platform.node() + '.p'):
print('Loading ' + list_name + ' list from pickle file...')
with open(data_path + '/' + list_name + '_list_' + platform.node() + '.p', 'rb') as f:
batch_list = pickle.load(f)
if batch_list[0]['file_path'].find(file_dir) != -1:
print('The ' + list_name + ' list has a total of %i entries.' % (len(batch_list)))
return batch_list
print('Creating ' + list_name + ' list, as no pickle file exists...')
batch_list = [] # list for wav paths and lengths.
for fn in file_name:
for file_path in glob.glob(os.path.join(file_dir, fn)):
f = SoundFile(file_path)
seq_len = f.seek(0, SEEK_END)
batch_list.append({'file_path': file_path, 'seq_len': seq_len}) # append dictionary.
if not os.path.exists(data_path): os.makedirs(data_path) # make directory.
with open(data_path + '/' + list_name + '_list_' + platform.node() + '.p', 'wb') as f:
pickle.dump(batch_list, f)
print('The ' + list_name + ' list has a total of %i entries.' % (len(batch_list)))
return batch_list
def Clean_mbatch(clean_list, mbatch_size, start_idx, end_idx):
'''
Creates a padded mini-batch of clean speech wavs.
Inputs:
clean_list - training list for the clean speech files.
mbatch_size - size of the mini-batch.
start_idx - start index into the training list.
end_idx - end index into the training list.
Outputs:
mbatch - matrix of padded wavs stored as a numpy array.
seq_len - length of each wav stored as a numpy array.
'''
mbatch_list = clean_list[start_idx:end_idx] # get mini-batch list from training list.
maxlen = max([dic['seq_len'] for dic in mbatch_list]) # find maximum length wav in mini-batch.
seq_len = [] # list of the wavs lengths.
mbatch = np.zeros([len(mbatch_list), maxlen], np.int16) # numpy array for wav matrix.
for i in range(len(mbatch_list)):
(wav, _) = read_wav(mbatch_list[i]['file_path']) # read wav from given file path.
mbatch[i,:mbatch_list[i]['seq_len']] = wav # add wav to numpy array.
seq_len.append(mbatch_list[i]['seq_len']) # append length of wav to list.
return mbatch, np.array(seq_len, np.int32)
def Noise_mbatch(noise_list, mbatch_size, clean_seq_len):
'''
Creates a padded mini-batch of noise wavs.
Inputs:
noise_list - training list for the noise files.
mbatch_size - size of the mini-batch.
clean_seq_len - sequence length of each clean speech file in the mini-batch.
Outputs:
mbatch - matrix of padded wavs stored as a numpy array.
seq_len - length of each wav stored as a numpy array.
'''
mbatch_list = random.sample(noise_list, mbatch_size) # get mini-batch list from training list.
for i in range(len(clean_seq_len)):
while mbatch_list[i]['seq_len'] < clean_seq_len[i]: # resample until the noise wav is at least as long as the clean wav.
mbatch_list[i] = random.choice(noise_list)
maxlen = max([dic['seq_len'] for dic in mbatch_list]) # find maximum length wav in mini-batch.
seq_len = [] # list of the wav lengths.
mbatch = np.zeros([len(mbatch_list), maxlen], np.int16) # numpy array for wav matrix.
for i in range(len(mbatch_list)):
(wav, _) = read_wav(mbatch_list[i]['file_path']) # read wav from given file path.
mbatch[i,:mbatch_list[i]['seq_len']] = wav # add wav to numpy array.
seq_len.append(mbatch_list[i]['seq_len']) # append length of wav to list.
return mbatch, np.array(seq_len, np.int32)
def Batch(fdir, snr_l):
'''
REQUIRES REWRITING.
Places all of the test waveforms from the list into a numpy array.
SPHERE format cannot be used. 'glob' is used to support Unix style pathname
pattern expansions. Waveforms are padded to the maximum waveform length. The
waveform lengths are recorded so that the correct lengths can be sliced
for feature extraction. The SNR levels of each test file are placed into a
numpy array. Also returns a list of the file names.
Inputs:
fdir - directory containing the waveforms.
snr_l - list of the SNR levels used.
Outputs:
wav_np - matrix of padded waveforms stored as a numpy array.
len_np - length of each waveform stored as a numpy array.
snr_test_np - numpy array of all the SNR levels for the test set.
fname_l - list of filenames.
'''
fname_l = [] # list of file names.
wav_l = [] # list for waveforms.
snr_test_l = [] # list of SNR levels for the test set.
fnames = ['*.wav', '*.flac', '*.mp3'] # supported file extensions.
for fname in fnames:
for fpath in glob.glob(os.path.join(fdir, fname)):
for snr in snr_l:
if fpath.find('_' + str(snr) + 'dB') != -1:
snr_test_l.append(snr) # append SNR level.
(wav, _) = read_wav(fpath) # read waveform from given file path.
if np.isnan(wav).any() or np.isinf(wav).any():
raise ValueError('Error: NaN or Inf value. File path: %s.' % (fpath))
wav_l.append(wav) # append.
fname_l.append(os.path.basename(os.path.splitext(fpath)[0])) # append name.
len_l = [] # list of the waveform lengths.
maxlen = max(len(wav) for wav in wav_l) # maximum length of waveforms.
wav_np = np.zeros([len(wav_l), maxlen], np.int16) # numpy array for waveform matrix.
for (i, wav) in zip(range(len(wav_l)), wav_l):
wav_np[i,:len(wav)] = wav # add waveform to numpy array.
len_l.append(len(wav)) # append length of waveform to list.
return wav_np, np.array(len_l, np.int32), np.array(snr_test_l, np.int32), fname_l
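A hedged usage sketch for the functions above; the directories follow the 'set/' layout described in set/info.txt, and the mini-batch size of 8 is illustrative:

import sys
sys.path.insert(0, 'lib')
from dev.se_batch import Batch_list, Clean_mbatch, Noise_mbatch

# Build (or load the cached) training lists from the 'set/' directories.
s_list = Batch_list('set/train_clean_speech', 'train_clean')
d_list = Batch_list('set/train_noise', 'train_noise')

# One padded mini-batch of eight clean wavs and length-matched noise wavs.
s_mbatch, s_len = Clean_mbatch(s_list, 8, 0, 8)
d_mbatch, d_len = Noise_mbatch(d_list, 8, s_len)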

97
lib/dev/train.py Normal file
View File

@@ -0,0 +1,97 @@
## FILE: train.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Training module.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
import dev.se_batch as batch
import numpy as np
import math, os, random
import tensorflow as tf
from datetime import datetime
from tqdm import tqdm
## TRAINING
def train(sess, net, args):
print("Training...")
random.shuffle(args.train_s_list) # shuffle training list.
## CONTINUE FROM LAST EPOCH
if args.cont:
epoch_size = len(args.train_s_list); epoch_comp = args.epoch; start_idx = 0;
end_idx = args.mbatch_size; val_error = float("inf") # create epoch parameters.
net.saver.restore(sess, args.model_path + '/epoch-' + str(args.epoch)) # load model from last epoch.
## TRAIN RAW NETWORK
else:
epoch_size = len(args.train_s_list); epoch_comp = 0; start_idx = 0;
end_idx = args.mbatch_size; val_error = float("inf") # create epoch parameters.
if args.mbatch_size > epoch_size: raise ValueError('Error: mini-batch size is greater than the epoch size.')
sess.run(tf.global_variables_initializer()) # initialise model variables.
net.saver.save(sess, args.model_path + '/epoch', global_step=epoch_comp) # save model.
## TRAINING LOG
if not os.path.exists('log'): os.makedirs('log') # create log directory.
with open("log/" + args.ver + ".csv", "a") as results: results.write("'"'Validation error'"', '"'Training error'"', '"'Epoch'"', '"'D/T'"'\n")
train_err = 0; mbatch_count = 0
while args.train:
print('Training E%d (ver=%s, gpu=%s, params=%g)...' % (epoch_comp + 1, args.ver, args.gpu, args.params))
for _ in tqdm(range(args.train_steps)):
## MINI-BATCH GENERATION
mbatch_size_iter = end_idx - start_idx # number of examples in mini-batch for the training iteration.
s_mbatch, s_mbatch_seq_len = batch.Clean_mbatch(args.train_s_list,
mbatch_size_iter, start_idx, end_idx) # generate mini-batch of clean training waveforms.
d_mbatch, d_mbatch_seq_len = batch.Noise_mbatch(args.train_d_list,
mbatch_size_iter, s_mbatch_seq_len) # generate mini-batch of noise training waveforms.
snr_mbatch = np.random.randint(args.min_snr, args.max_snr + 1, end_idx - start_idx) # generate mini-batch of SNR levels.
## TRAINING ITERATION
mbatch = sess.run(net.train_feat, feed_dict={net.s_ph: s_mbatch, net.d_ph: d_mbatch,
net.s_len_ph: s_mbatch_seq_len, net.d_len_ph: d_mbatch_seq_len, net.snr_ph: snr_mbatch}) # mini-batch.
[_, mbatch_err] = sess.run([net.trainer, net.total_loss], feed_dict={net.input_ph: mbatch[0], net.target_ph: mbatch[1],
net.nframes_ph: mbatch[2], net.training_ph: True}) # training iteration.
if not math.isnan(mbatch_err):
train_err += mbatch_err; mbatch_count += 1
## UPDATE EPOCH PARAMETERS
start_idx += args.mbatch_size; end_idx += args.mbatch_size # start and end index of mini-batch.
if end_idx > epoch_size: end_idx = epoch_size # if less than the mini-batch size of examples is left.
## VALIDATION SET ERROR
random.shuffle(args.train_s_list) # shuffle training list for the next epoch.
start_idx = 0; end_idx = args.mbatch_size; frames = 0; val_error = 0 # reset mini-batch indices and validation variables.
print('Validation error for E%d...' % (epoch_comp + 1))
for _ in tqdm(range(args.val_steps)):
mbatch = sess.run(net.train_feat, feed_dict={net.s_ph: args.val_s[start_idx:end_idx],
net.d_ph: args.val_d[start_idx:end_idx], net.s_len_ph: args.val_s_len[start_idx:end_idx],
net.d_len_ph: args.val_d_len[start_idx:end_idx], net.snr_ph: args.val_snr[start_idx:end_idx]}) # mini-batch.
val_error_mbatch = sess.run(net.loss, feed_dict={net.input_ph: mbatch[0],
net.target_ph: mbatch[1], net.nframes_ph: mbatch[2], net.training_ph: False}) # validation error for each frame in mini-batch.
val_error += np.sum(val_error_mbatch)
frames += mbatch[1].shape[0] # total number of frames.
print("Validation error for Epoch %d: %3.3f%% complete. " %
(epoch_comp + 1, 100*(end_idx/args.val_s_len.shape[0])), end="\r")
start_idx += args.mbatch_size; end_idx += args.mbatch_size
if end_idx > args.val_s_len.shape[0]: end_idx = args.val_s_len.shape[0]
val_error /= frames # validation error.
epoch_comp += 1 # an epoch has been completed.
net.saver.save(sess, args.model_path + '/epoch', global_step=epoch_comp) # save model.
print("E%d: train err=%3.3f, val err=%3.3f. " %
(epoch_comp, train_err/mbatch_count, val_error))
with open("log/" + args.ver + ".csv", "a") as results:
results.write("%g, %g, %d, %s\n" % (val_error, train_err/mbatch_count,
epoch_comp, datetime.now().strftime('%Y-%m-%d/%H:%M:%S')))
train_err = 0; mbatch_count = 0; start_idx = 0; end_idx = args.mbatch_size
if epoch_comp >= args.max_epochs:
args.train = False
print('\nTraining complete. Validation error for epoch %d: %g. ' %
(epoch_comp, val_error))
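Because the loop above appends one CSV row per epoch to log/<ver>.csv, the best checkpoint can be picked afterwards with a short script; a sketch, assuming version '3f':

import csv

rows = []
with open('log/3f.csv') as f:
    for r in csv.reader(f):
        try:
            rows.append((float(r[0]), float(r[1]), int(r[2])))
        except (ValueError, IndexError):
            continue  # skip the header row written at the start of each run.
best = min(rows)  # tuples sort on validation error first.
print('lowest validation error: %g at epoch %d' % (best[0], best[2]))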

49
lib/dev/utils.py Normal file
View File

@@ -0,0 +1,49 @@
## FILE: utils.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: General utility functions/modules.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
from os.path import expanduser
import argparse, os, string
import numpy as np
from scipy.io.wavfile import write as wav_write
import tensorflow as tf
import soundfile as sf
def save_wav(save_path, f_s, wav):
if isinstance(wav[0], np.float32): wav = np.asarray(np.multiply(wav, 32768.0), dtype=np.int16)
wav_write(save_path, f_s, wav)
def read_wav(path):
wav, f_s = sf.read(path, dtype='int16')
return wav, f_s
def log10(x):
numerator = tf.log(x)
denominator = tf.constant(np.log(10), dtype=numerator.dtype)
return tf.truediv(numerator, denominator)
## CHARACTER DICTIONARIES
def char_dict():
chars = list(" " + string.ascii_lowercase + "'") # space + 26 alphabetic characters + apostrophe = 28 characters (a CTC blank would give 29 classes).
char2idx = dict(zip(chars, [i for i in range(len(chars))]))
idx2char = dict((y,x) for x,y in char2idx.items())
return char2idx, idx2char
## NUMPY SIGMOID FUNCTION
def np_sigmoid(x): return np.divide(1, np.add(1, np.exp(np.negative(x))))
## GPU CONFIGURATION
def gpu_config(gpu_selection, log_device_placement=False):
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]=str(gpu_selection)
config = tf.ConfigProto()
config.allow_soft_placement=True
config.gpu_options.allow_growth=True
config.log_device_placement=log_device_placement
return config
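A short round-trip sketch for the helpers above; 'example.wav' is a hypothetical file and GPU 0 an arbitrary choice:

import sys
sys.path.insert(0, 'lib')
from dev.utils import read_wav, save_wav, gpu_config

wav, f_s = read_wav('example.wav')      # int16 samples and sample rate.
save_wav('example_copy.wav', f_s, wav)  # float32 input is rescaled to int16.
config = gpu_config(0)                  # session config pinned to GPU 0.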

BIN
lib/dev/utils.pyc Normal file

Binary file not shown.

33
lib/dev/xi.py Normal file
View File

@@ -0,0 +1,33 @@
## FILE: xi.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: Functions for computing a priori SNR.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
import numpy as np
import scipy.special as spsp
import tensorflow as tf
def log10(x):
numerator = tf.log(x)
denominator = tf.constant(np.log(10), dtype=numerator.dtype)
return tf.truediv(numerator, denominator)
def xi(s_MAG, d_MAG):
return tf.truediv(tf.square(s_MAG), tf.maximum(tf.square(d_MAG), 1e-12)) # a priori SNR.
def xi_dB(s_MAG, d_MAG):
return tf.multiply(10.0, log10(tf.maximum(xi(s_MAG, d_MAG), 1e-12)))
def xi_bar(s_MAG, d_MAG, mu, sigma):
return tf.multiply(0.5, tf.add(1.0, tf.erf(tf.truediv(tf.subtract(xi_dB(s_MAG, d_MAG), mu),
tf.multiply(sigma, tf.sqrt(2.0))))))
def xi_hat(xi_bar_hat, mu, sigma):
xi_dB_hat = np.add(np.multiply(np.multiply(sigma, np.sqrt(2.0)),
spsp.erfinv(np.subtract(np.multiply(2.0, xi_bar_hat), 1))), mu)
return np.power(10.0, np.divide(xi_dB_hat, 10.0))
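A quick numerical check (sketch) that the numpy inverse xi_hat undoes the TF mapping xi_bar when mu and sigma match; the statistics below are illustrative only:

import numpy as np
import scipy.special as spsp

mu, sigma = 0.0, 10.0  # illustrative statistics for one frequency bin.
xi_dB = 5.0
xi_bar = 0.5 * (1.0 + spsp.erf((xi_dB - mu) / (sigma * np.sqrt(2.0))))
xi_dB_hat = sigma * np.sqrt(2.0) * spsp.erfinv(2.0 * xi_bar - 1.0) + mu
assert np.isclose(xi_dB_hat, xi_dB)
print(10.0 ** (xi_dB_hat / 10.0))  # back to the linear a priori SNR.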

202
model/3f/checkpoint Normal file
View File

@@ -0,0 +1,202 @@
model_checkpoint_path: "epoch-200"
all_model_checkpoint_paths: "epoch-0"
all_model_checkpoint_paths: "epoch-1"
all_model_checkpoint_paths: "epoch-2"
all_model_checkpoint_paths: "epoch-3"
all_model_checkpoint_paths: "epoch-4"
all_model_checkpoint_paths: "epoch-5"
all_model_checkpoint_paths: "epoch-6"
all_model_checkpoint_paths: "epoch-7"
all_model_checkpoint_paths: "epoch-8"
all_model_checkpoint_paths: "epoch-9"
all_model_checkpoint_paths: "epoch-10"
all_model_checkpoint_paths: "epoch-11"
all_model_checkpoint_paths: "epoch-12"
all_model_checkpoint_paths: "epoch-13"
all_model_checkpoint_paths: "epoch-14"
all_model_checkpoint_paths: "epoch-15"
all_model_checkpoint_paths: "epoch-16"
all_model_checkpoint_paths: "epoch-17"
all_model_checkpoint_paths: "epoch-18"
all_model_checkpoint_paths: "epoch-19"
all_model_checkpoint_paths: "epoch-20"
all_model_checkpoint_paths: "epoch-21"
all_model_checkpoint_paths: "epoch-22"
all_model_checkpoint_paths: "epoch-23"
all_model_checkpoint_paths: "epoch-24"
all_model_checkpoint_paths: "epoch-25"
all_model_checkpoint_paths: "epoch-26"
all_model_checkpoint_paths: "epoch-27"
all_model_checkpoint_paths: "epoch-28"
all_model_checkpoint_paths: "epoch-29"
all_model_checkpoint_paths: "epoch-30"
all_model_checkpoint_paths: "epoch-31"
all_model_checkpoint_paths: "epoch-32"
all_model_checkpoint_paths: "epoch-33"
all_model_checkpoint_paths: "epoch-34"
all_model_checkpoint_paths: "epoch-35"
all_model_checkpoint_paths: "epoch-36"
all_model_checkpoint_paths: "epoch-37"
all_model_checkpoint_paths: "epoch-38"
all_model_checkpoint_paths: "epoch-39"
all_model_checkpoint_paths: "epoch-40"
all_model_checkpoint_paths: "epoch-41"
all_model_checkpoint_paths: "epoch-42"
all_model_checkpoint_paths: "epoch-43"
all_model_checkpoint_paths: "epoch-44"
all_model_checkpoint_paths: "epoch-45"
all_model_checkpoint_paths: "epoch-46"
all_model_checkpoint_paths: "epoch-47"
all_model_checkpoint_paths: "epoch-48"
all_model_checkpoint_paths: "epoch-49"
all_model_checkpoint_paths: "epoch-50"
all_model_checkpoint_paths: "epoch-51"
all_model_checkpoint_paths: "epoch-52"
all_model_checkpoint_paths: "epoch-53"
all_model_checkpoint_paths: "epoch-54"
all_model_checkpoint_paths: "epoch-55"
all_model_checkpoint_paths: "epoch-56"
all_model_checkpoint_paths: "epoch-57"
all_model_checkpoint_paths: "epoch-58"
all_model_checkpoint_paths: "epoch-59"
all_model_checkpoint_paths: "epoch-60"
all_model_checkpoint_paths: "epoch-61"
all_model_checkpoint_paths: "epoch-62"
all_model_checkpoint_paths: "epoch-63"
all_model_checkpoint_paths: "epoch-64"
all_model_checkpoint_paths: "epoch-65"
all_model_checkpoint_paths: "epoch-66"
all_model_checkpoint_paths: "epoch-67"
all_model_checkpoint_paths: "epoch-68"
all_model_checkpoint_paths: "epoch-69"
all_model_checkpoint_paths: "epoch-70"
all_model_checkpoint_paths: "epoch-71"
all_model_checkpoint_paths: "epoch-72"
all_model_checkpoint_paths: "epoch-73"
all_model_checkpoint_paths: "epoch-74"
all_model_checkpoint_paths: "epoch-75"
all_model_checkpoint_paths: "epoch-76"
all_model_checkpoint_paths: "epoch-77"
all_model_checkpoint_paths: "epoch-78"
all_model_checkpoint_paths: "epoch-79"
all_model_checkpoint_paths: "epoch-80"
all_model_checkpoint_paths: "epoch-81"
all_model_checkpoint_paths: "epoch-82"
all_model_checkpoint_paths: "epoch-83"
all_model_checkpoint_paths: "epoch-84"
all_model_checkpoint_paths: "epoch-85"
all_model_checkpoint_paths: "epoch-86"
all_model_checkpoint_paths: "epoch-87"
all_model_checkpoint_paths: "epoch-88"
all_model_checkpoint_paths: "epoch-89"
all_model_checkpoint_paths: "epoch-90"
all_model_checkpoint_paths: "epoch-91"
all_model_checkpoint_paths: "epoch-92"
all_model_checkpoint_paths: "epoch-93"
all_model_checkpoint_paths: "epoch-94"
all_model_checkpoint_paths: "epoch-95"
all_model_checkpoint_paths: "epoch-96"
all_model_checkpoint_paths: "epoch-97"
all_model_checkpoint_paths: "epoch-98"
all_model_checkpoint_paths: "epoch-99"
all_model_checkpoint_paths: "epoch-100"
all_model_checkpoint_paths: "epoch-101"
all_model_checkpoint_paths: "epoch-102"
all_model_checkpoint_paths: "epoch-103"
all_model_checkpoint_paths: "epoch-104"
all_model_checkpoint_paths: "epoch-105"
all_model_checkpoint_paths: "epoch-106"
all_model_checkpoint_paths: "epoch-107"
all_model_checkpoint_paths: "epoch-108"
all_model_checkpoint_paths: "epoch-109"
all_model_checkpoint_paths: "epoch-110"
all_model_checkpoint_paths: "epoch-111"
all_model_checkpoint_paths: "epoch-112"
all_model_checkpoint_paths: "epoch-113"
all_model_checkpoint_paths: "epoch-114"
all_model_checkpoint_paths: "epoch-115"
all_model_checkpoint_paths: "epoch-116"
all_model_checkpoint_paths: "epoch-117"
all_model_checkpoint_paths: "epoch-118"
all_model_checkpoint_paths: "epoch-119"
all_model_checkpoint_paths: "epoch-120"
all_model_checkpoint_paths: "epoch-121"
all_model_checkpoint_paths: "epoch-122"
all_model_checkpoint_paths: "epoch-123"
all_model_checkpoint_paths: "epoch-124"
all_model_checkpoint_paths: "epoch-125"
all_model_checkpoint_paths: "epoch-126"
all_model_checkpoint_paths: "epoch-127"
all_model_checkpoint_paths: "epoch-128"
all_model_checkpoint_paths: "epoch-129"
all_model_checkpoint_paths: "epoch-130"
all_model_checkpoint_paths: "epoch-131"
all_model_checkpoint_paths: "epoch-132"
all_model_checkpoint_paths: "epoch-133"
all_model_checkpoint_paths: "epoch-134"
all_model_checkpoint_paths: "epoch-135"
all_model_checkpoint_paths: "epoch-136"
all_model_checkpoint_paths: "epoch-137"
all_model_checkpoint_paths: "epoch-138"
all_model_checkpoint_paths: "epoch-139"
all_model_checkpoint_paths: "epoch-140"
all_model_checkpoint_paths: "epoch-141"
all_model_checkpoint_paths: "epoch-142"
all_model_checkpoint_paths: "epoch-143"
all_model_checkpoint_paths: "epoch-144"
all_model_checkpoint_paths: "epoch-145"
all_model_checkpoint_paths: "epoch-146"
all_model_checkpoint_paths: "epoch-147"
all_model_checkpoint_paths: "epoch-148"
all_model_checkpoint_paths: "epoch-149"
all_model_checkpoint_paths: "epoch-150"
all_model_checkpoint_paths: "epoch-151"
all_model_checkpoint_paths: "epoch-152"
all_model_checkpoint_paths: "epoch-153"
all_model_checkpoint_paths: "epoch-154"
all_model_checkpoint_paths: "epoch-155"
all_model_checkpoint_paths: "epoch-156"
all_model_checkpoint_paths: "epoch-157"
all_model_checkpoint_paths: "epoch-158"
all_model_checkpoint_paths: "epoch-159"
all_model_checkpoint_paths: "epoch-160"
all_model_checkpoint_paths: "epoch-161"
all_model_checkpoint_paths: "epoch-162"
all_model_checkpoint_paths: "epoch-163"
all_model_checkpoint_paths: "epoch-164"
all_model_checkpoint_paths: "epoch-165"
all_model_checkpoint_paths: "epoch-166"
all_model_checkpoint_paths: "epoch-167"
all_model_checkpoint_paths: "epoch-168"
all_model_checkpoint_paths: "epoch-169"
all_model_checkpoint_paths: "epoch-170"
all_model_checkpoint_paths: "epoch-171"
all_model_checkpoint_paths: "epoch-172"
all_model_checkpoint_paths: "epoch-173"
all_model_checkpoint_paths: "epoch-174"
all_model_checkpoint_paths: "epoch-175"
all_model_checkpoint_paths: "epoch-176"
all_model_checkpoint_paths: "epoch-177"
all_model_checkpoint_paths: "epoch-178"
all_model_checkpoint_paths: "epoch-179"
all_model_checkpoint_paths: "epoch-180"
all_model_checkpoint_paths: "epoch-181"
all_model_checkpoint_paths: "epoch-182"
all_model_checkpoint_paths: "epoch-183"
all_model_checkpoint_paths: "epoch-184"
all_model_checkpoint_paths: "epoch-185"
all_model_checkpoint_paths: "epoch-186"
all_model_checkpoint_paths: "epoch-187"
all_model_checkpoint_paths: "epoch-188"
all_model_checkpoint_paths: "epoch-189"
all_model_checkpoint_paths: "epoch-190"
all_model_checkpoint_paths: "epoch-191"
all_model_checkpoint_paths: "epoch-192"
all_model_checkpoint_paths: "epoch-193"
all_model_checkpoint_paths: "epoch-194"
all_model_checkpoint_paths: "epoch-195"
all_model_checkpoint_paths: "epoch-196"
all_model_checkpoint_paths: "epoch-197"
all_model_checkpoint_paths: "epoch-198"
all_model_checkpoint_paths: "epoch-199"
all_model_checkpoint_paths: "epoch-200"

Binary file not shown.

BIN
model/3f/epoch-175.index Normal file

Binary file not shown.

BIN
model/3f/epoch-175.meta Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

25
server.py Normal file
View File

@@ -0,0 +1,25 @@
from flask import Flask, request, render_template, url_for, session, jsonify, make_response, flash, redirect, abort, g
app = Flask(__name__)
# NOTE: 'photos' (an upload set, e.g. from Flask-Uploads) and the 'Photo' model
# are assumed to be defined elsewhere; this file sketches the upload flow only.
@app.route('/', methods=['GET', 'POST'])
def upload():
if request.method == 'POST' and 'photo' in request.files:
filename = photos.save(request.files['photo'])
rec = Photo(filename=filename, user=g.user.id)
rec.store()
flash("Photo saved.")
return redirect(url_for('show', id=rec.id))
return render_template('upload.html')
@app.route('/photo/<id>')
def show(id):
photo = Photo.load(id)
if photo is None:
abort(404)
url = photos.url(photo.filename)
return render_template('show.html', url=url, photo=photo)
if __name__ == '__main__':
app.run(host='192.168.1.254',port=9100,debug=True)

6
set/info.txt Normal file
View File

@@ -0,0 +1,6 @@
Directories used to store clean speech and noise for validation and training, as well as noisy speech for testing.
For the validation set only:
Identical filenames for the clean speech and noise must be placed in 'val_clean_speech' and 'val_noise', with the SNR level at the end of the filename. As an example:
'./val_clean_speech/198_19-198-0003_Machinery17_15dB.wav' contains the clean speech, and './val_noise/198_19-198-0003_Machinery17_15dB.wav' contains noise of the same length. The two are mixed at the SNR level specified in the filename.
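A small sketch of recovering the SNR level encoded at the end of such filenames, matching the '_<snr>dB' convention that lib/dev/se_batch.py searches for:

import re

fname = '198_19-198-0003_Machinery17_15dB.wav'
m = re.search(r'_(-?\d+)dB\.wav$', fname)
print(int(m.group(1)))  # -> 15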

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

1
set/test_noisy_speech/info.txt Normal file
View File

@@ -0,0 +1 @@
Place all noisy speech .wav files for testing here.

1
set/train_clean_speech/info.txt Normal file
View File

@@ -0,0 +1 @@
Place all clean speech .wav files for training here.

1
set/train_noise/info.txt Normal file
View File

@@ -0,0 +1 @@
Place all noise .wav files for training here.

1
set/val_clean_speech/info.txt Normal file
View File

@@ -0,0 +1 @@
Place all clean speech .wav files for validation here.

1
set/val_noise/info.txt Normal file
View File

@@ -0,0 +1 @@
Place all noise .wav files for validation here.

214
templates/index.html Normal file
View File

@@ -0,0 +1,214 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.2.1/css/bootstrap.min.css" integrity="sha384-GJzZqFGwb1QTTN6wy59ffF1BuGJpLSa9DkKMp0DgiMDm4iYMj70gZWKYbI706tWS" crossorigin="anonymous">
<title>deep-xi</title>
</head>
<style>
body{
text-align: center;
}
.custom_btn{
background-color: lightgreen;
border: none;
padding: 8px;
border-radius: 3px;
margin-top: 10px;
}
.container{
padding-left: 20%;
padding-right: 20%;
}
</style>
<body>
<div class="container">
<div class="row">
<div class="col">
<div class="mb-3 mt-3">
<h2 class="mb-3" style="font-weight: 300">DeepXi</h2>
<div class="form-group mb-3">
<div class="custom-file">
<input type="file" class="custom-file-input" name="file_input" id="file_input" oninput="input_filename();">
<label id="file_input_label" class="custom-file-label" >Select file (wav, mp3, flac)</label>
</div>
</div>
<div id="progress_wrapper" class="d-none">
<label id="progress_status"></label>
<div class="progress mb-3">
<div id="progress" class="progress-bar" role="progressbar" aria-valuenow="25" aria-valuemin="0" aria-valuemax="100"></div>
</div>
</div>
<div id="alert_wrapper" class="d-none"></div>
<button onclick="upload('/upload');" id="upload_btn" class="btn btn-primary">Upload</button>
<button class="btn btn-primary d-none" id="loading_btn" type="button" disabled>
<span class="spinner-border spinner-border-sm" role="status" aria-hidden="true"></span> Uploading...
</button>
<button type="button" id="cancel_btn" class="btn btn-secondary d-none">Cancel</button>
</div>
<div id="original_play" style="margin: 10px;" class="d-none">
<h4>original</h4>
<audio id="original_audio" controls src="" type="audio/*">
</div>
<div id="predict_play" style="margin: 10px;" class="d-none">
<h4>predict</h4>
<audio id="predict_audio" controls src="" type="audio/*">
</div>
</div>
</div>
</div>
<script>
var input = document.getElementById("file_input")
var file_input_label = document.getElementById("file_input_label")
var progress = document.getElementById("progress");
var progress_wrapper = document.getElementById("progress_wrapper");
var progress_status = document.getElementById("progress_status");
var upload_btn = document.getElementById("upload_btn");
var loading_btn = document.getElementById("loading_btn");
var cancel_btn = document.getElementById("cancel_btn");
var alert_wrapper = document.getElementById("alert_wrapper");
var original_play = document.getElementById("original_play")
var predict_play = document.getElementById("predict_play")
var original_audio = document.getElementById("original_audio")
var predict_audio = document.getElementById("predict_audio")
function input_filename() {
file_input_label.innerText = input.files[0].name;}
function show_alert(message, alert, autohide=false) {
alert_wrapper.classList.remove("d-none")
alert_wrapper.innerHTML = `
<div id="alert" class="alert alert-${alert} alert-dismissible fade show" role="alert">
<span>${message}</span>
<button type="button" class="close" data-dismiss="alert" aria-label="Close">
<span aria-hidden="true">&times;</span>
</button>
</div>
`
// if (autohide){
// setTimeout(() => {alert_wrapper.classList.add("d-none")}, 5000)
// }
}
function upload(url) {
console.log(url)
if (!input.value) {
show_alert("No file selected", "warning")
return;
}
var data = new FormData();
var request = new XMLHttpRequest();
request.responseType = "json";
alert_wrapper.innerHTML = "";
input.disabled = true;
upload_btn.classList.add("d-none");
loading_btn.classList.remove("d-none");
cancel_btn.classList.remove("d-none");
progress_wrapper.classList.remove("d-none");
original_play.classList.add("d-none")
predict_play.classList.add("d-none")
var file = input.files[0];
var filename = file.name;
var filesize = file.size;
document.cookie = `filesize=${filesize}`;
data.append("file", file);
request.upload.addEventListener("progress", function (e) {
var loaded = e.loaded;
var total = e.total
var percent_complete = (loaded / total) * 100;
progress.setAttribute("style", `width: ${Math.floor(percent_complete)}%`);
progress_status.innerHTML = `${Math.floor(percent_complete)}% uploaded`;
if (percent_complete == 100){
show_alert("Saving file to server...", "primary");
loading_btn.innerHTML = "<span class='spinner-border spinner-border-sm' role='status' aria-hidden='true'></span>" + " Predict..."
}
})
// request load handler (transfer complete)
request.addEventListener("load", function (e) {
if (request.status == 200) {
show_alert(`${request.response.message}`, "success", true);
original_play.classList.remove("d-none")
original_audio.src = request.response.file_path
original_audio.autoplay = true
progress_wrapper.classList.add("d-none")
//loading_btn.innerHTML = "<span class='spinner-border spinner-border-sm' role='status' aria-hidden='true'></span>" + " Predict..."
var predict_rq = predict('/predict/' + String(request.response.file_path).split('/').join('='))
if (predict_rq.status == 200){
var json = JSON.parse(predict_rq.responseText)
show_alert(json.message, "success")
console.log(predict_rq.responseText)
console.log(json.out_file_path)
predict_play.classList.remove("d-none")
predict_audio.src = json.out_file_path
}
else{
show_alert(`Error predicting file: ${filename}`, "danger");
}
loading_btn.classList.add("d-none")
loading_btn.innerHTML = "<span class='spinner-border spinner-border-sm' role='status' aria-hidden='true'></span>" + " Uploading..."
}
else {
show_alert(`Error uploading file: ${filename}`, "danger");
}
reset();
});
// request error handler
request.addEventListener("error", function (e) {
reset();
show_alert(`Error uploading file`, "danger");
});
// request abort handler
request.addEventListener("abort", function (e) {
reset();
show_alert(`Upload cancelled`, "danger");
});
// Open and send the request
request.open("post", url);
request.send(data);
cancel_btn.addEventListener("click", function () {
request.abort();
})
}
function predict(url){
var xmlHttp = new XMLHttpRequest();
//xmlHttp.responseType = "json";
xmlHttp.open( "GET", url, false ); // false for synchronous request
xmlHttp.send( null );
//console.log(xmlHttp.response.out_file_path)
return xmlHttp;
}
function reset() {
// Clear the input
input.value = null;
// Hide the cancel button
cancel_btn.classList.add("d-none");
// Reset the input element
input.disabled = false;
// Show the upload button
upload_btn.classList.remove("d-none");
// Hide the loading button
loading_btn.classList.add("d-none");
//loading_btn_text.textContent = "Uploading..."
// Hide the progress bar
progress_wrapper.classList.add("d-none");
// Reset the progress bar state
progress.setAttribute("style", `width: 0%`);
// Reset the input placeholder
file_input_label.innerText = "Select file";
}
</script>
</body>
</html>

5
templates1/index.html Normal file
View File

@@ -0,0 +1,5 @@
<form method=POST enctype=multipart/form-data action="{{ url_for('upload') }}">
...
<input type=file name=photo>
...
</form>

5
templates1/upload.html Normal file
View File

@@ -0,0 +1,5 @@
<form method=POST enctype=multipart/form-data action="{{ url_for('upload') }}">
...
<input type=file name=photo>
...
</form>

9
testload.py Normal file
View File

@@ -0,0 +1,9 @@
import librosa
print(librosa.__version__)
filename = 'music.mp3'
data2, sr2 = librosa.load(filename, sr=None) # sr=None keeps the native sample rate.
print(sr2, data2)

65
utils.py Normal file
View File

@@ -0,0 +1,65 @@
import os
import numpy as np
import librosa
import scipy.io.wavfile # imported explicitly for scipy.io.wavfile.read/write below.
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from inaSpeechSegmenter import Segmenter, seg2csv
def split(file_name, out_dir):
print('\nREMOVE MUSIC AND CUT')
seg = Segmenter()
segmentation = seg(file_name)
sample_rate, raw_audio = scipy.io.wavfile.read(file_name)
#raw_audio , sr = librosa.load(file_name, sr=16000)
speech = []
print(segmentation)
count = 1
if not os.path.exists(out_dir):
os.mkdir(out_dir)
list_file = []
for s in segmentation:
if s[0] != 'Music' and s[0] != 'NOACTIVITY':
print(str(count),'dur of sen:',s[2]-s[1])
start = max(0, int(s[1]*sample_rate) - sample_rate//4) # pad ~0.25 s either side, clamped at the start of the file.
speech_data = raw_audio[start:int(s[2]*sample_rate) + sample_rate//4]
speech_data = np.array(speech_data)
print(len(speech_data), len(speech_data)/sample_rate)
if len(speech_data)/sample_rate < 0.5 or len(speech_data)/sample_rate > 20:
continue
else:
out_filename = out_dir + '/' + file_name.split('/')[-1].replace('.wav','') + '_' + str(count) + '.wav'
list_file.append(out_filename)
scipy.io.wavfile.write(out_filename, sample_rate, speech_data)
count += 1
return list_file
def cvtToWavMono16(filename):
# Convert any supported audio file to 16 kHz mono wav.
print('start convert to wav 16000 mono')
sig, rate = librosa.load(filename, sr=16000, mono=True)
new_filename = os.path.splitext(filename)[0] + '.wav'
librosa.output.write_wav(new_filename, sig, rate)
print('converted to wav 16000 mono')
return new_filename
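A hedged usage sketch chaining the two helpers above; the paths are hypothetical:

wav_path = cvtToWavMono16('upload.mp3')        # -> 'upload.wav', 16 kHz mono.
segments = split(wav_path, 'static/segments')  # speech-only clips of 0.5-20 s.
print(segments)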