Added multiprocessing of the undersampling process

ACK-J
2022-06-08 19:23:29 -04:00
parent 78a338f115
commit f67ed53150
2 changed files with 103 additions and 52 deletions

Dataset_Files/.gitignore vendored Normal file

@@ -0,0 +1,4 @@
+# Ignore everything in this directory
+*
+# Except this file
+!.gitignore
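This is the standard trick for keeping an otherwise-ignored output directory under version control: git does not track empty directories, so the .gitignore is the one committed file, while everything the pipeline later writes into Dataset_Files/ stays out of the repository.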


@@ -450,7 +450,7 @@ def discover_wallet_directories(dir_to_search):
     num_bad_txs = 0
     pool = Pool(processes=NUM_PROCESSES)  # Multiprocessing pool
     # Multiprocess combining the 6 csv files for each wallet
-    for wallet_tx_data in tqdm(pool.imap_unordered(func=combine_files, iterable=Wallet_info), desc="(Multiprocessing) Combining Exported Wallet Transactions", total=len(Wallet_info), colour='blue'):
+    for wallet_tx_data in tqdm(pool.imap_unordered(func=combine_files, iterable=Wallet_info), desc="(Multiprocessing) Combining Exported Wallet Files", total=len(Wallet_info), colour='blue'):
         # Make sure there are transactions in the data before adding it to the dataset
         for tx_hash, tx_data in wallet_tx_data.items():
             if "Input_True_Rings" in tx_data.keys():
@@ -459,7 +459,7 @@ def discover_wallet_directories(dir_to_search):
         else:
             num_bad_txs += 1
     print("There were " + str(num_bad_txs) + " bad transactions that were deleted out of a total " + str(total_txs) + " transactions!")
-    print("Dataset now includes " + str(len(data)) + " transactions.")
+    print("The dataset now includes " + str(len(data)) + " transactions.")

 def clean_transaction(transaction):
@@ -526,15 +526,15 @@ def create_feature_set(database):
     Valid_Transactions = []
     num_errors = 0
     feature_set = dict()
-    num_of_valid_txs = 0
-    # Iterate through each tx hash
+    num_of_valid_txs = 0  # Counter that excludes invalid txs
+    # Iterate through each tx hash in the database dict
     for idx, tx_hash in tqdm(enumerate(database.keys()), total=len(database), colour='blue', desc="Cleaning Transactions"):
         # Pass the transaction (by reference) to be stripped of non-features and receive the labels back
         try:
             private_info = clean_transaction(database[tx_hash])
         except Exception as e:
             num_errors += 1
-            continue
+            continue  # Don't process the tx; move on to the next one
         # add tx hash to good list
         #Valid_Transactions.append(DataFrame(CherryPicker(database[tx_hash]).flatten(delim='.').get(), index=[idx]))
         # Flatten each transaction and iterate over each feature
@@ -550,7 +550,7 @@ def create_feature_set(database):
         else:  # If the feature is already in the feature set
             # Check if there are any transactions that did not have this feature
             if len(feature_set[k]) < num_of_valid_txs:
-                # Add -1 for those occurences
+                # Add -1 for those occurrences
                 for i in range(num_of_valid_txs - len(feature_set[k]) - 1):
                     feature_set[k].append(-1)
                 # Append the feature
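The -1 back-fill keeps ragged feature columns aligned: when a feature reappears after some valid transactions lacked it, the column is padded so the value appended next lands in the row for the current transaction. A toy illustration of the arithmetic (the values are made up):

feature_set = {"fee": [0.1, 0.2]}  # hypothetical column present in only 2 of the first 4 txs
num_of_valid_txs = 5               # the current tx is the 5th valid one

# Pad with -1 up to (but not including) the slot for the current tx's value
for i in range(num_of_valid_txs - len(feature_set["fee"]) - 1):
    feature_set["fee"].append(-1)
feature_set["fee"].append(0.3)     # append the current tx's value

assert feature_set["fee"] == [0.1, 0.2, -1, -1, 0.3]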
@@ -588,6 +588,55 @@ def create_feature_set(database):
     return feature_set_df, labels

+def undersample_processing(y, ns, min_occurrences, occurrences):
+    """
+    :param y: tuple of (row index, ring data dict) produced by enumerate(y)
+    :param ns: Manager Namespace whose ns.X attribute holds the shared feature DataFrame
+    :param min_occurrences: maximum number of samples to keep per class
+    :param occurrences: shared dict counting how many samples each class has so far
+    :return: tuple of (list of single-row DataFrames, list of their labels)
+    """
+    undersampled_y = []
+    new_X = []
+    y_idx, ring_array = y
+    # For each array of ring members iterate over each index
+    for ring_array_idx in range(len(ring_array["True_Ring_Pos"])):
+        # Get the true ring position (label) for the current iteration
+        ring_pos = int(ring_array["True_Ring_Pos"][ring_array_idx].split("/")[0])
+        total_rings = int(ring_array["True_Ring_Pos"][ring_array_idx].split("/")[1])
+        # Check to see if we hit the maximum number of labels for this position and that the
+        # number of ring members is what we expect.
+        if occurrences[ring_pos] < min_occurrences and total_rings == NUM_RING_MEMBERS:
+            occurrences[ring_pos] = occurrences[ring_pos] + 1
+            # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.iloc.html#pandas.DataFrame.iloc
+            # Slice out the row from the dataframe but keep it as a dataframe
+            temp_df = ns.X.iloc[[y_idx]]
+            # Go through each column name in the temp dataframe
+            for col_name in temp_df.columns:
+                # Check if the column name has data relating to irrelevant ring signatures
+                if "Inputs." in col_name and "." + str(ring_array_idx) + "." not in col_name:
+                    # Delete the column
+                    temp_df = temp_df.drop([col_name], axis=1)
+                # Check if the column name is for the current ring signature
+                elif "Inputs." in col_name and "." + str(ring_array_idx) + "." in col_name:
+                    # Rename the column so it no longer carries the .0. / .1. positioning information
+                    temp_df.rename(columns={col_name: col_name.replace("Inputs." + str(ring_array_idx) + ".", "Input.")}, inplace=True)
+            # Add to the new X and y lists
+            new_X.append(temp_df)
+            undersampled_y.append(ring_pos)
+    return new_X, undersampled_y
+
+def undersample_processing_wrapper(y_X_min_occurrences_Occurrences):
+    """
+    Unpacks the zipped argument tuple so undersample_processing can be driven by
+    Pool.imap_unordered, which passes exactly one argument to the worker.
+    :param y_X_min_occurrences_Occurrences: tuple of (y, ns, min_occurrences, occurrences)
+    :return: the result of undersample_processing
+    """
+    return undersample_processing(*y_X_min_occurrences_Occurrences)

 def undersample(X, y):
     """
@@ -615,43 +664,38 @@ def undersample(X, y):
     # Find the smallest number of occurrences
     min_occurrences = labels_distribution.most_common()[len(labels_distribution) - 1][1]
-    print("Undersampling to " + str(min_occurrences) + " transactions per class.")
+    print("Undersampling to " + str(min_occurrences) + " transactions per class. A total of " + str(min_occurrences * NUM_RING_MEMBERS) + " transactions.")
     #max_occurrences = labels_distribution.most_common(1)[0][1]
-    # Create a dictionary for all 11 spots in a ring signature
-    occurrences = {}
     undersampled_y = []
     new_X = []
-    for i in range(NUM_RING_MEMBERS):
-        occurrences[i+1] = 0
-    # Enumerate each index in the df and get the array of ring labels
-    for y_idx, ring_array in tqdm(enumerate(y), total=len(y), colour='blue', desc="Undersampling Dataset"):
-        # For each array of ring members iterate over each index
-        for ring_array_idx in range(len(ring_array["True_Ring_Pos"])):
-            # Get the true ring position (label) for the current iteration
-            ring_pos = int(ring_array["True_Ring_Pos"][ring_array_idx].split("/")[0])
-            total_rings = int(ring_array["True_Ring_Pos"][ring_array_idx].split("/")[1])
-            # Check to see if we hit the maximum number of labels for this position and that the
-            # number of ring members is what we expect.
-            if occurrences[ring_pos] < min_occurrences and total_rings == NUM_RING_MEMBERS:
-                occurrences[ring_pos] = occurrences[ring_pos] + 1
-                # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.iloc.html#pandas.DataFrame.iloc
-                # Slice out the row from the dataframe but keep it as a dataframe
-                temp_df = X.iloc[[y_idx]]
-                # Go through each column name in the temp dataframe
-                for col_name in temp_df.columns:
-                    # Check if the column name has data relating to irrelevant ring signatures
-                    if "Inputs." in col_name and "." + str(ring_array_idx) + "." not in col_name:
-                        # Delete the columns
-                        temp_df = temp_df.drop([col_name], axis=1)
-                    # Check if the column name is for the current ring signature
-                    elif "Inputs." in col_name and "." + str(ring_array_idx) + "." in col_name:
-                        # Rename the column such that it doesn't have the .0. or .1. positioning information
-                        temp_df.rename(columns={col_name: col_name.replace("Inputs." + str(ring_array_idx) + ".", "Input.")}, inplace=True)
-                # Add to the new X and y dataframes
-                new_X.append(temp_df)
-                undersampled_y.append(ring_pos)
+    with Manager() as manager:
+        # https://stackoverflow.com/questions/19887087/how-to-share-pandas-dataframe-object-between-processes
+        ns = manager.Namespace()
+        ns.X = X
+        # Create a dictionary for all 11 spots in a ring signature
+        occurrences = manager.dict()
+        for i in range(NUM_RING_MEMBERS):
+            occurrences[i + 1] = 0
+        # Multiprocess the undersampling of each transaction
+        with manager.Pool(processes=NUM_PROCESSES) as pool:
+            for result in tqdm(pool.imap_unordered(func=undersample_processing_wrapper,
+                                                   iterable=zip(list(enumerate(y)),
+                                                                repeat(ns, len(y)),
+                                                                repeat(min_occurrences, len(y)),
+                                                                repeat(occurrences, len(y)))),
+                               desc="(Multiprocessing) Undersampling Dataset",
+                               total=len(y),
+                               colour='blue'):
+                subset_new_X = result[0]
+                subset_undersampled_y = result[1]
+                # Add to the new X and y lists
+                new_X = new_X + subset_new_X
+                undersampled_y = undersampled_y + subset_undersampled_y
     del X  # Remove the old dataset to save RAM
     collect()  # Garbage collector
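Because each worker process has its own memory, the commit shares X through a Manager: ns.X lives in the manager process and every proxy access pickles the DataFrame over to the worker, which is convenient but not free, and the shared occurrences dict is updated with a non-atomic read-then-write (a plausible source of the off-by-one race the next hunk guards against). A stripped-down sketch of the sharing mechanics, using a plain Pool rather than the commit's manager.Pool:

from itertools import repeat
from multiprocessing import Manager, Pool
import pandas as pd

def fetch_row(args):
    idx, ns = args
    # Attribute access on the proxy ships the whole DataFrame from the
    # manager process, so workers see it without inheriting it via fork
    return ns.X.iloc[[idx]]

if __name__ == "__main__":
    X = pd.DataFrame({"a": range(4), "b": range(4)})
    with Manager() as manager:
        ns = manager.Namespace()
        ns.X = X
        with Pool(processes=2) as pool:
            rows = list(pool.imap_unordered(fetch_row,
                                            zip(range(len(X)), repeat(ns, len(X)))))
    print(len(rows))  # 4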
@@ -661,6 +705,9 @@ def undersample(X, y):
     del new_X
     collect()  # Garbage collector
+    # Sometimes a race condition gives a class one extra sample (most often while debugging)
+    assert len(undersampled_X) == len(undersampled_y) == (min_occurrences * NUM_RING_MEMBERS)
     # Shuffle the data one last time
     undersampled_X, undersampled_y = shuffle(undersampled_X, undersampled_y)
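The new assert expects exactly min_occurrences * NUM_RING_MEMBERS rows. min_occurrences itself comes from Counter.most_common() earlier in undersample(): most_common() sorts classes by descending frequency, so the last entry is the rarest class, and indexing with [-1] is equivalent to the [len(...) - 1] form the code uses. A two-line illustration with made-up labels:

from collections import Counter

labels_distribution = Counter([1, 2, 2, 5, 5, 5])
min_occurrences = labels_distribution.most_common()[-1][1]
print(min_occurrences)  # 1, the count of the rarest class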
@@ -686,10 +733,11 @@ def main():
     # Create the dataset from files on disk #
     ###########################################
     global data
-    print("Opening " + str(argv[1]) + "\n")
-    # Find where the wallets are stored and combine the exported files
+    print(blue + "Opening " + str(argv[1]) + reset + "\n")
+    # Find where the wallets are stored and combine the exported csv files
     discover_wallet_directories(argv[1])
+    # Multiprocessing References
     # https://leimao.github.io/blog/Python-tqdm-Multiprocessing/
     # https://thebinarynotes.com/python-multiprocessing/
     # https://docs.python.org/3/library/multiprocessing.html
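blue and reset here are the script's terminal colour helpers, defined elsewhere in the file and not shown in this diff; under the usual ANSI-escape convention they would look something like this (assumed definitions, not the commit's own):

blue = "\033[94m"   # assumed ANSI escape for bright blue
reset = "\033[0m"   # assumed ANSI escape to restore default colours
print(blue + "Opening ./exported_wallets" + reset + "\n")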
@@ -702,14 +750,14 @@ def main():
         data[tx_hash] = transaction_entry  # Set the enriched version of the tx
     # Save the raw database to disk
-    with open("dataset.pkl", "wb") as fp:
+    with open("./Dataset_Files/dataset.pkl", "wb") as fp:
         pickle.dump(data, fp)
-    print("dataset.pkl written to disk!")
+    print("./Dataset_Files/dataset.pkl written to disk!")
     #################################
     # Remove Unnecessary Features  #
     #################################
-    with open("dataset.pkl", "rb") as fp:
+    with open("./Dataset_Files/dataset.pkl", "rb") as fp:
         data = pickle.load(fp)
     # Feature selection on raw dataset
     X, y = create_feature_set(data)
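The relocated paths assume ./Dataset_Files exists: open() will not create the directory, which is why the commit also checks in the .gitignore that keeps the otherwise-empty directory in the repository. A defensive variant of the save/load round trip (the makedirs guard is an addition for illustration, not part of the commit):

import os
import pickle

os.makedirs("./Dataset_Files", exist_ok=True)  # open() alone raises FileNotFoundError
data = {"tx_hash": {"fee": 0.1}}
with open("./Dataset_Files/dataset.pkl", "wb") as fp:
    pickle.dump(data, fp)
with open("./Dataset_Files/dataset.pkl", "rb") as fp:
    assert pickle.load(fp) == data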
@@ -717,34 +765,33 @@ def main():
     collect()  # Garbage collector
     # Save data and labels to disk for future AI training
-    with open("X.pkl", "wb") as fp:
+    with open("./Dataset_Files/X.pkl", "wb") as fp:
         pickle.dump(X, fp)
-    with open("y.pkl", "wb") as fp:
+    with open("./Dataset_Files/y.pkl", "wb") as fp:
         pickle.dump(y, fp)
     # Error checking; labels and data should be the same length
     assert len(X) == len(y)
-    print("X.pkl and y.pkl written to disk!")
+    print("./Dataset_Files/X.pkl and ./Dataset_Files/y.pkl written to disk!")
     ###################
     # Undersampling  #
     ###################
-    with open("X.pkl", "rb") as fp:
+    with open("./Dataset_Files/X.pkl", "rb") as fp:
         X = pickle.load(fp)
-    with open("y.pkl", "rb") as fp:
+    with open("./Dataset_Files/y.pkl", "rb") as fp:
         y = pickle.load(fp)
-    print("Starting to undersample the dataset...")
     X_Undersampled, y_Undersampled = undersample(X, y)
     del X
     collect()  # Garbage collector
-    with open("X_Undersampled.pkl", "wb") as fp:
+    with open("./Dataset_Files/X_Undersampled.pkl", "wb") as fp:
         pickle.dump(X_Undersampled, fp)
-    with open("y_Undersampled.pkl", "wb") as fp:
+    with open("./Dataset_Files/y_Undersampled.pkl", "wb") as fp:
         pickle.dump(y_Undersampled, fp)
-    print("X_Undersampled.pkl and y_Undersampled.pkl written to disk!\nFinished")
+    print("./Dataset_Files/X_Undersampled.pkl and ./Dataset_Files/y_Undersampled.pkl written to disk!\nFinished")

 if __name__ == '__main__':