aboutsummaryrefslogtreecommitdiffstats
path: root/scatter_matrix.py
diff options
context:
space:
mode:
authorThomas Vanbesien <tvanbesi@proton.me>2026-04-01 17:42:04 +0200
committerThomas Vanbesien <tvanbesi@proton.me>2026-04-01 17:42:04 +0200
commit32cd9b2be1763f872c800b17e1fa63f852fe91c1 (patch)
tree8aee9bd7e81d8204faca701c0a852bcf7dc45de6 /scatter_matrix.py
downloadDSLR-32cd9b2be1763f872c800b17e1fa63f852fe91c1.tar.gz
DSLR-32cd9b2be1763f872c800b17e1fa63f852fe91c1.zip
Import from github.comHEADmaster
Diffstat (limited to 'scatter_matrix.py')
-rw-r--r--scatter_matrix.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/scatter_matrix.py b/scatter_matrix.py
new file mode 100644
index 0000000..67ec0ab
--- /dev/null
+++ b/scatter_matrix.py
@@ -0,0 +1,52 @@
+from module.constants import (
+ HOUSE_COLORS,
+ NUMERICAL_FEATURE_CSV_TITLES,
+ HOUSE_FEATURE_CSV_TITLE,
+)
+from module.dataset_manip import impute_mean, parse_csv
+import matplotlib.pyplot as plt
+import os
+import pandas as pd
+import sys
+
+if len(sys.argv) < 2:
+ print(f"Usage: python {__file__} <dataset.csv>")
+ exit(-1)
+
+# Get data from CSV file
+filename = sys.argv[1]
+data = parse_csv(filename, NUMERICAL_FEATURE_CSV_TITLES, [HOUSE_FEATURE_CSV_TITLE])
+df = pd.DataFrame(data)
+
+# Drop rows without a house label, then impute missing numerical values with the mean
+df = df.loc[df[HOUSE_FEATURE_CSV_TITLE].notna() & (df[HOUSE_FEATURE_CSV_TITLE] != "")]
+for feature in NUMERICAL_FEATURE_CSV_TITLES:
+ df.loc[:, feature] = impute_mean(df[feature].tolist())
+
+# Assign colors based on house
+color_map = {house.capitalize(): color for house, color in HOUSE_COLORS.items()}
+colors = df[HOUSE_FEATURE_CSV_TITLE].map(color_map)
+
+title = "Hogwarts Course Score Scatter Matrix"
+scatter_df = df[NUMERICAL_FEATURE_CSV_TITLES]
+assert isinstance(scatter_df, pd.DataFrame)
+axes = pd.plotting.scatter_matrix(
+ scatter_df,
+ figsize=(1920 * 2 / 100, 1080 * 2 / 100),
+ alpha=0.5,
+ color=colors,
+ diagonal="hist",
+)
+
+# Remove diagonal plots
+for i in range(len(NUMERICAL_FEATURE_CSV_TITLES)):
+ axes[i, i].set_visible(False)
+
+plt.suptitle(title)
+
+# Save to png file
+os.makedirs("output/scatter_matrix", exist_ok=True)
+save_filename = f"output/scatter_matrix/{title}.png"
+plt.savefig(save_filename)
+plt.close()
+print(f"Saved {title} to {save_filename}")