# Source: scatter_matrix.py (repository root)
# Blob: 67ec0ab193a2c1ea390fe568d4b9316e7fa1e810
from module.constants import (
    HOUSE_COLORS,
    NUMERICAL_FEATURE_CSV_TITLES,
    HOUSE_FEATURE_CSV_TITLE,
)
from module.dataset_manip import impute_mean, parse_csv
import matplotlib.pyplot as plt
import os
import pandas as pd
import sys

# Script: draw a scatter matrix of every numerical course score, coloured by
# house, and save it as a PNG under output/scatter_matrix/.
#
# Usage: python scatter_matrix.py <dataset.csv>

if len(sys.argv) < 2:
    # Usage errors go to stderr. sys.exit() is used rather than the bare
    # exit() helper, which is only injected by the `site` module and is
    # missing under `python -S` or in frozen/embedded interpreters.
    print(f"Usage: python {__file__} <dataset.csv>", file=sys.stderr)
    sys.exit(-1)

# Get data from CSV file
filename = sys.argv[1]
data = parse_csv(filename, NUMERICAL_FEATURE_CSV_TITLES, [HOUSE_FEATURE_CSV_TITLE])
df = pd.DataFrame(data)

# Drop rows without a house label. The explicit copy() detaches the filtered
# frame from the original one, so the column writes below operate on an
# independent object and cannot raise SettingWithCopyWarning.
has_house = df[HOUSE_FEATURE_CSV_TITLE].notna() & (df[HOUSE_FEATURE_CSV_TITLE] != "")
df = df.loc[has_house].copy()

# Impute missing numerical values with the mean. Plain column assignment
# replaces the whole column, so its dtype is inferred from the imputed values
# instead of being forced into the dtype of the pre-imputation column.
for feature in NUMERICAL_FEATURE_CSV_TITLES:
    df[feature] = impute_mean(df[feature].tolist())

# Assign colors based on house. HOUSE_COLORS keys are capitalized before the
# lookup so they are compared against the house labels in capitalized form.
color_map = {house.capitalize(): color for house, color in HOUSE_COLORS.items()}
colors = df[HOUSE_FEATURE_CSV_TITLE].map(color_map)

title = "Hogwarts Course Score Scatter Matrix"
scatter_df = df[NUMERICAL_FEATURE_CSV_TITLES]
# Narrows the type for static checkers: indexing with a list yields a DataFrame.
assert isinstance(scatter_df, pd.DataFrame)
axes = pd.plotting.scatter_matrix(
    scatter_df,
    figsize=(1920 * 2 / 100, 1080 * 2 / 100),
    alpha=0.5,
    color=colors,
    diagonal="hist",
)

# Remove diagonal plots. Walking the diagonal of the returned axes grid
# (instead of indexing by the number of configured features) stays in bounds
# whatever size the grid actually has.
for diagonal_ax in axes.diagonal():
    diagonal_ax.set_visible(False)

plt.suptitle(title)

# Save to png file
os.makedirs("output/scatter_matrix", exist_ok=True)
save_filename = f"output/scatter_matrix/{title}.png"
plt.savefig(save_filename)
plt.close()
print(f"Saved {title} to {save_filename}")