mirror of https://github.com/MAGICGrants/Monero-Dataset-Pipeline.git
synced 2026-01-09 13:37:57 -05:00
Added multiprocessing of the undersampling process
Dataset_Files/.gitignore (vendored, new file, +4)
@@ -0,0 +1,4 @@
+# Ignore everything in this directory
+*
+# Except this file
+!.gitignore
@@ -450,7 +450,7 @@ def discover_wallet_directories(dir_to_search):
     num_bad_txs = 0
     pool = Pool(processes=NUM_PROCESSES)  # Multiprocessing pool
     # Multiprocess combining the 6 csv files for each wallet
-    for wallet_tx_data in tqdm(pool.imap_unordered(func=combine_files, iterable=Wallet_info), desc="(Multiprocessing) Combining Exported Wallet Transactions", total=len(Wallet_info), colour='blue'):
+    for wallet_tx_data in tqdm(pool.imap_unordered(func=combine_files, iterable=Wallet_info), desc="(Multiprocessing) Combining Exported Wallet Files", total=len(Wallet_info), colour='blue'):
         # Make sure there are transactions in the data before adding it to the dataset
         for tx_hash, tx_data in wallet_tx_data.items():
             if "Input_True_Rings" in tx_data.keys():
@@ -459,7 +459,7 @@ def discover_wallet_directories(dir_to_search):
             else:
                 num_bad_txs += 1
     print("There were " + str(num_bad_txs) + " bad transactions that were deleted out of a total " + str(total_txs) + " transactions!")
-    print("Dataset now includes " + str(len(data)) + " transactions.")
+    print("The dataset now includes " + str(len(data)) + " transactions.")


 def clean_transaction(transaction):
@@ -526,15 +526,15 @@ def create_feature_set(database):
     Valid_Transactions = []
     num_errors = 0
     feature_set = dict()
-    num_of_valid_txs = 0
-    # Iterate through each tx hash
+    num_of_valid_txs = 0  # Incrementer which doesn't count invalid txs
+    # Iterate through each tx hash in the database dict
     for idx, tx_hash in tqdm(enumerate(database.keys()), total=len(database), colour='blue', desc="Cleaning Transactions"):
         # Pass the transaction ( by reference ) to be stripped of non-features and receive the labels back
         try:
             private_info = clean_transaction(database[tx_hash])
         except Exception as e:
             num_errors += 1
-            continue
+            continue  # Don't process the tx and loop
         # add tx hash to good list
         #Valid_Transactions.append(DataFrame(CherryPicker(database[tx_hash]).flatten(delim='.').get(), index=[idx]))
         # Flatten each transaction and iterate over each feature
@@ -550,7 +550,7 @@ def create_feature_set(database):
         else:  # If the feature is already in the feature set
             # Check if there are any transactions that did not have this feature
             if len(feature_set[k]) < num_of_valid_txs:
-                # Add -1 for those occurences
+                # Add -1 for those occurrences
                 for i in range(num_of_valid_txs-len(feature_set[k])-1):
                     feature_set[k].append(-1)
             # Append the feature
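
For readers following create_feature_set: the function builds the matrix column-wise, mapping each flattened feature name to a list of values and back-filling -1 wherever a transaction lacked that feature, so every column stays one value per valid transaction. A simplified sketch of that padding scheme (toy transactions; the real code drives this from the cleaning loop above):

    feature_set = {}      # feature name -> column of values
    num_of_valid_txs = 0  # rows committed so far

    transactions = [{"fee": 1}, {"fee": 2, "size": 300}, {"size": 310}]
    for tx in transactions:
        for k, v in tx.items():
            # New feature: back-fill -1 for every earlier transaction.
            column = feature_set.setdefault(k, [-1] * num_of_valid_txs)
            # Existing feature missing from some rows: pad up to the current row.
            column.extend([-1] * (num_of_valid_txs - len(column)))
            column.append(v)
        num_of_valid_txs += 1

    # Pad trailing gaps so all columns end up the same length.
    for column in feature_set.values():
        column.extend([-1] * (num_of_valid_txs - len(column)))

    print(feature_set)  # {'fee': [1, 2, -1], 'size': [-1, 300, 310]}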
@@ -588,6 +588,55 @@ def create_feature_set(database):
     return feature_set_df, labels


+def undersample_processing(y, ns, min_occurrences, occurrences):
+    """
+    Undersample one transaction's ring data inside a worker process.
+    :param y: one (index, ring label dict) pair produced by enumerate(y)
+    :param ns: manager Namespace whose ns.X holds the shared feature DataFrame
+    :param min_occurrences: per-class cap on the number of samples to keep
+    :param occurrences: shared dict counting the samples kept per ring position
+    :return: (new_X, undersampled_y) lists for this transaction
+    """
+    undersampled_y = []
+    new_X = []
+    y_idx, ring_array = y
+    # For each array of ring members iterate over each index
+    for ring_array_idx in range(len(ring_array["True_Ring_Pos"])):
+        # Get the true ring position (label) for the current iteration
+        ring_pos = int(ring_array["True_Ring_Pos"][ring_array_idx].split("/")[0])
+        total_rings = int(ring_array["True_Ring_Pos"][ring_array_idx].split("/")[1])
+        # Check to see if we hit the maximum number of labels for this position and that the
+        # number of ring members is what we expect.
+        if occurrences[ring_pos] < min_occurrences and total_rings == NUM_RING_MEMBERS:
+            occurrences[ring_pos] = occurrences[ring_pos] + 1
+            # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.iloc.html#pandas.DataFrame.iloc
+            # Slice out the row from the dataframe but keep it as a dataframe
+            temp_df = ns.X.iloc[[y_idx]]
+            # Go through each column name in the temp dataframe
+            for col_name in temp_df.columns:
+                # Check if the column name has data relating to irrelevant ring signatures
+                if "Inputs." in col_name and "." + str(ring_array_idx) + "." not in col_name:
+                    # Delete the columns
+                    temp_df = temp_df.drop([col_name], axis=1)
+                # Check if the column name is for the current ring signature
+                elif "Inputs." in col_name and "." + str(ring_array_idx) + "." in col_name:
+                    # Rename the column such that it doesn't have the .0. or .1. positioning information
+                    temp_df.rename(columns={col_name: col_name.replace("Inputs." + str(ring_array_idx) + ".", "Input.")}, inplace=True)
+            # Add to the new X and y dataframes
+            new_X.append(temp_df)
+            undersampled_y.append(ring_pos)
+    return new_X, undersampled_y
+
+
+def undersample_processing_wrapper(y_X_min_occurrences_Occurrences):
+    """
+    Unpack the packed argument tuple, since Pool.imap_unordered passes a single argument to its worker.
+    :param y_X_min_occurrences_Occurrences: packed (y, ns, min_occurrences, occurrences) tuple
+    :return: the result of undersample_processing on the unpacked arguments
+    """
+    return undersample_processing(*y_X_min_occurrences_Occurrences)
+
+
 def undersample(X, y):
     """

@@ -615,43 +664,38 @@ def undersample(X, y):

     # Find the smallest number of occurrences
     min_occurrences = labels_distribution.most_common()[len(labels_distribution)-1][1]
-    print("Undersampling to " + str(min_occurrences) + " transactions per class.")
+    print("Undersampling to " + str(min_occurrences) + " transactions per class. A total of " + str(min_occurrences*NUM_RING_MEMBERS) + " transactions.")
     #max_occurrences = labels_distribution.most_common(1)[0][1]

-    # Create a dictionary for all 11 spots in a ring signature
-    occurrences = {}
     undersampled_y = []
     new_X = []
-    # Enumerate each index in the df and get the array of ring labels
-    for y_idx, ring_array in tqdm(enumerate(y), total=len(y), colour='blue', desc="Undersampling Dataset"):
-        # For each array of ring members iterate over each index
-        for ring_array_idx in range(len(ring_array["True_Ring_Pos"])):
-            # Get the true ring position (label) for the current iteration
-            ring_pos = int(ring_array["True_Ring_Pos"][ring_array_idx].split("/")[0])
-            total_rings = int(ring_array["True_Ring_Pos"][ring_array_idx].split("/")[1])
-            # Check to see if we hit the maximum number of labels for this position and that the
-            # number of ring members is what we expect.
-            if occurrences[ring_pos] < min_occurrences and total_rings == NUM_RING_MEMBERS:
-                occurrences[ring_pos] = occurrences[ring_pos] + 1
-                # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.iloc.html#pandas.DataFrame.iloc
-                # Slice out the row from the dataframe but keep it as a dataframe
-                temp_df = X.iloc[[y_idx]]
-                # Go through each column name in the temp dataframe
-                for col_name in temp_df.columns:
-                    # Check if the column name has data relating to irrelevant ring signatures
-                    if "Inputs." in col_name and "." + str(ring_array_idx) + "." not in col_name:
-                        # Delete the columns
-                        temp_df = temp_df.drop([col_name], axis=1)
-                    # Check if the column name is for the current ring signature
-                    elif "Inputs." in col_name and "." + str(ring_array_idx) + "." in col_name:
-                        # Rename the column such that it doesn't have the .0. or .1. positioning information
-                        temp_df.rename(columns={col_name: col_name.replace("Inputs." + str(ring_array_idx) + ".", "Input.")}, inplace=True)
-                # Add to the new X and y dataframes
-                new_X.append(temp_df)
-                undersampled_y.append(ring_pos)
+    with Manager() as manager:
+        # https://stackoverflow.com/questions/19887087/how-to-share-pandas-dataframe-object-between-processes
+        ns = manager.Namespace()
+        ns.X = X
+        # Create a dictionary for all 11 spots in a ring signature
+        occurrences = manager.dict()
+        for i in range(NUM_RING_MEMBERS):
+            occurrences[i + 1] = 0
+
+        # Multiprocessing enriching each transaction
+        with manager.Pool(processes=NUM_PROCESSES) as pool:
+            for result in tqdm(pool.imap_unordered(func=undersample_processing_wrapper,
+                                                   iterable=zip(list(enumerate(y)),
+                                                                repeat(ns, len(y)),
+                                                                repeat(min_occurrences, len(y)),
+                                                                repeat(occurrences, len(y)))),
+                               desc="(Multiprocessing) Undersampling Dataset",
+                               total=len(y),
+                               colour='blue'):
+                subset_new_X = result[0]
+                subset_undersampled_y = result[1]
+                new_X = new_X + subset_new_X
+                undersampled_y = undersampled_y + subset_undersampled_y

     del X  # Remove the old dataset to save RAM
     collect()  # Garbage collector
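
The manager.Namespace() trick above (see the linked Stack Overflow thread) keeps one copy of the DataFrame in the manager process and hands each worker a proxy; every ns.X access pickles the frame over the manager connection, so it trades duplicated memory for IPC overhead. A minimal, hedged sketch of the pattern on toy data:

    from multiprocessing import Manager

    import pandas as pd


    def row_sum(args):
        ns, idx = args
        # Each ns.X access fetches the frame from the manager process by proxy.
        return idx, int(ns.X.iloc[[idx]].sum(axis=1).iloc[0])


    if __name__ == '__main__':
        df = pd.DataFrame({"a": [1, 2, 3], "b": [10, 20, 30]})
        with Manager() as manager:
            ns = manager.Namespace()
            ns.X = df  # stored once in the manager process, shared by proxy
            with manager.Pool(processes=2) as pool:
                for idx, total in pool.imap_unordered(row_sum, [(ns, i) for i in range(len(df))]):
                    print(idx, total)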
@@ -661,6 +705,9 @@ def undersample(X, y):
     del new_X
     collect()  # Garbage collector

+    # Sometimes there is a race condition where a class will get +1 samples in the class (most of the time this happens while debugging)
+    assert len(undersampled_X) == len(undersampled_y) == (min_occurrences * NUM_RING_MEMBERS)
+
     # Shuffle the data one last time
     undersampled_X, undersampled_y = shuffle(undersampled_X, undersampled_y)
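
The added comment is candid about why the assert exists: occurrences[ring_pos] = occurrences[ring_pos] + 1 is a read-then-write on a manager.dict(), so two workers can both pass the < min_occurrences check before either write lands, and a class can end up one sample over its cap. If exact counts were required, one possible fix (not in this commit) is to serialize the check-and-increment behind a shared lock; the helper below is hypothetical:

    from multiprocessing import Manager


    def try_claim_slot(occurrences, lock, ring_pos, min_occurrences):
        # Hypothetical helper, not in the commit: holding the lock makes the
        # read-check-write on the shared dict atomic across workers.
        with lock:
            if occurrences[ring_pos] < min_occurrences:
                occurrences[ring_pos] += 1
                return True
            return False


    if __name__ == '__main__':
        with Manager() as manager:
            occurrences = manager.dict({i + 1: 0 for i in range(11)})
            lock = manager.Lock()
            print(try_claim_slot(occurrences, lock, 3, 1))  # True: slot claimed
            print(try_claim_slot(occurrences, lock, 3, 1))  # False: class 3 is full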
@@ -686,10 +733,11 @@ def main():
     # Create the dataset from files on disk #
     ###########################################
     global data
-    print("Opening " + str(argv[1]) + "\n")
-    # Find where the wallets are stored and combine the exported files
+    print(blue + "Opening " + str(argv[1]) + reset + "\n")
+    # Find where the wallets are stored and combine the exported csv files
     discover_wallet_directories(argv[1])

     # Multiprocessing References
     # https://leimao.github.io/blog/Python-tqdm-Multiprocessing/
     # https://thebinarynotes.com/python-multiprocessing/
     # https://docs.python.org/3/library/multiprocessing.html
@@ -702,14 +750,14 @@ def main():
             data[tx_hash] = transaction_entry  # Set the enriched version of the tx

     # Save the raw database to disk
-    with open("dataset.pkl", "wb") as fp:
+    with open("./Dataset_Files/dataset.pkl", "wb") as fp:
         pickle.dump(data, fp)
-    print("dataset.pkl written to disk!")
+    print("./Dataset_Files/dataset.pkl written to disk!")

     #################################
     # Remove Unnecessary Features  #
     #################################
-    with open("dataset.pkl", "rb") as fp:
+    with open("./Dataset_Files/dataset.pkl", "rb") as fp:
         data = pickle.load(fp)
     # Feature selection on raw dataset
     X, y = create_feature_set(data)
@@ -717,34 +765,33 @@ def main():
     collect()  # Garbage collector

     # Save data and labels to disk for future AI training
-    with open("X.pkl", "wb") as fp:
+    with open("./Dataset_Files/X.pkl", "wb") as fp:
         pickle.dump(X, fp)
-    with open("y.pkl", "wb") as fp:
+    with open("./Dataset_Files/y.pkl", "wb") as fp:
         pickle.dump(y, fp)
     # Error checking; labels and data should be the same length
     assert len(X) == len(y)
-    print("X.pkl and y.pkl written to disk!")
+    print("./Dataset_Files/X.pkl and ./Dataset_Files/y.pkl written to disk!")

     ###################
     # Undersampling  #
     ###################

-    with open("X.pkl", "rb") as fp:
+    with open("./Dataset_Files/X.pkl", "rb") as fp:
         X = pickle.load(fp)
-    with open("y.pkl", "rb") as fp:
+    with open("./Dataset_Files/y.pkl", "rb") as fp:
         y = pickle.load(fp)

     print("Starting to undersample the dataset...")
     X_Undersampled, y_Undersampled = undersample(X, y)
     del X
     collect()  # Garbage collector

-    with open("X_Undersampled.pkl", "wb") as fp:
+    with open("./Dataset_Files/X_Undersampled.pkl", "wb") as fp:
         pickle.dump(X_Undersampled, fp)
-    with open("y_Undersampled.pkl", "wb") as fp:
+    with open("./Dataset_Files/y_Undersampled.pkl", "wb") as fp:
         pickle.dump(y_Undersampled, fp)

-    print("X_Undersampled.pkl and y_Undersampled.pkl written to disk!\nFinished")
+    print("./Dataset_Files/X_Undersampled.pkl and ./Dataset_Files/y_Undersampled.pkl written to disk!\nFinished")


 if __name__ == '__main__':
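
Since the outputs are plain pickles under ./Dataset_Files/, downstream training only needs the matching load side of the round-trip. A minimal sketch of consuming the saved artifacts (file names come from this commit; the rest is illustrative):

    import pickle

    # Load the undersampled feature matrix and labels this pipeline writes out.
    with open("./Dataset_Files/X_Undersampled.pkl", "rb") as fp:
        X_Undersampled = pickle.load(fp)
    with open("./Dataset_Files/y_Undersampled.pkl", "rb") as fp:
        y_Undersampled = pickle.load(fp)

    # Same invariant the pipeline asserts before saving.
    assert len(X_Undersampled) == len(y_Undersampled)
    print("Loaded " + str(len(X_Undersampled)) + " undersampled transactions.")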