"""Render a scatter matrix of the numerical features and save it as a PNG.

Usage: python <this script> <dataset.csv>

The CSV is loaded through ``parse_csv``; rows whose house label is missing or
empty are dropped, gaps in the numerical columns are filled by
``impute_mean``, and every point is colored according to its house via
``HOUSE_COLORS``. The figure is written under ``output/scatter_matrix/``.
"""

import os
import sys

import matplotlib.pyplot as plt
import pandas as pd

from module.constants import (
    HOUSE_COLORS,
    NUMERICAL_FEATURE_CSV_TITLES,
    HOUSE_FEATURE_CSV_TITLE,
)
from module.dataset_manip import impute_mean, parse_csv

if len(sys.argv) < 2:
    print(f"Usage: python {__file__} <dataset.csv>")
    # sys.exit rather than the site-provided exit(): the latter is not
    # available when the interpreter runs without the site module.
    sys.exit(-1)

# Get data from CSV file
filename = sys.argv[1]
data = parse_csv(filename, NUMERICAL_FEATURE_CSV_TITLES, [HOUSE_FEATURE_CSV_TITLE])
df = pd.DataFrame(data)

# Drop rows without a house label, then impute missing numerical values.
# The explicit copy makes the filtered frame independent of the original, so
# the column assignments below cannot trigger SettingWithCopyWarning.
has_house = df[HOUSE_FEATURE_CSV_TITLE].notna() & (df[HOUSE_FEATURE_CSV_TITLE] != "")
df = df.loc[has_house].copy()
for feature in NUMERICAL_FEATURE_CSV_TITLES:
    df.loc[:, feature] = impute_mean(df[feature].tolist())

# Assign colors based on house; the map is keyed by the capitalized form of
# each HOUSE_COLORS key.
color_map = {house.capitalize(): color for house, color in HOUSE_COLORS.items()}
colors = df[HOUSE_FEATURE_CSV_TITLE].map(color_map)

title = "Hogwarts Course Score Scatter Matrix"
scatter_df = df[NUMERICAL_FEATURE_CSV_TITLES]
# Narrowing for type checkers only: list indexing returns a DataFrame.
assert isinstance(scatter_df, pd.DataFrame)
axes = pd.plotting.scatter_matrix(
    scatter_df,
    figsize=(1920 * 2 / 100, 1080 * 2 / 100),
    alpha=0.5,
    color=colors,
    diagonal="hist",
)

# Remove diagonal plots. Walking the grid's own diagonal keeps this in step
# with however many rows/columns scatter_matrix actually produced.
for diagonal_ax in axes.diagonal():
    diagonal_ax.set_visible(False)

plt.suptitle(title)

# Save to png file
output_dir = "output/scatter_matrix"
os.makedirs(output_dir, exist_ok=True)
save_filename = f"{output_dir}/{title}.png"
plt.savefig(save_filename)
plt.close()
print(f"Saved {title} to {save_filename}")