diff options
Diffstat (limited to 'scatter_matrix.py')
| -rw-r--r-- | scatter_matrix.py | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/scatter_matrix.py b/scatter_matrix.py new file mode 100644 index 0000000..67ec0ab --- /dev/null +++ b/scatter_matrix.py @@ -0,0 +1,52 @@ +from module.constants import ( + HOUSE_COLORS, + NUMERICAL_FEATURE_CSV_TITLES, + HOUSE_FEATURE_CSV_TITLE, +) +from module.dataset_manip import impute_mean, parse_csv +import matplotlib.pyplot as plt +import os +import pandas as pd +import sys + +if len(sys.argv) < 2: + print(f"Usage: python {__file__} <dataset.csv>") + exit(-1) + +# Get data from CSV file +filename = sys.argv[1] +data = parse_csv(filename, NUMERICAL_FEATURE_CSV_TITLES, [HOUSE_FEATURE_CSV_TITLE]) +df = pd.DataFrame(data) + +# Drop rows without a house label, then impute missing numerical values with the mean +df = df.loc[df[HOUSE_FEATURE_CSV_TITLE].notna() & (df[HOUSE_FEATURE_CSV_TITLE] != "")] +for feature in NUMERICAL_FEATURE_CSV_TITLES: + df.loc[:, feature] = impute_mean(df[feature].tolist()) + +# Assign colors based on house +color_map = {house.capitalize(): color for house, color in HOUSE_COLORS.items()} +colors = df[HOUSE_FEATURE_CSV_TITLE].map(color_map) + +title = "Hogwarts Course Score Scatter Matrix" +scatter_df = df[NUMERICAL_FEATURE_CSV_TITLES] +assert isinstance(scatter_df, pd.DataFrame) +axes = pd.plotting.scatter_matrix( + scatter_df, + figsize=(1920 * 2 / 100, 1080 * 2 / 100), + alpha=0.5, + color=colors, + diagonal="hist", +) + +# Remove diagonal plots +for i in range(len(NUMERICAL_FEATURE_CSV_TITLES)): + axes[i, i].set_visible(False) + +plt.suptitle(title) + +# Save to png file +os.makedirs("output/scatter_matrix", exist_ok=True) +save_filename = f"output/scatter_matrix/{title}.png" +plt.savefig(save_filename) +plt.close() +print(f"Saved {title} to {save_filename}") |
