aboutsummaryrefslogtreecommitdiffstats
path: root/scatter_plot.py
blob: 8cdfea2c9aaebc3529efa2f38c2d757882464ba9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from module.constants import (
    HOUSE_COLORS,
    NUMERICAL_FEATURE_CSV_TITLES,
    HOUSE_FEATURE_CSV_TITLE,
)
from module.dataset_manip import parse_csv
import itertools
import matplotlib.pyplot as plt
import os
import pandas as pd
import sys

# Usage errors go to stderr and exit through sys.exit(): the bare `exit()`
# builtin is injected by the `site` module and is not guaranteed to exist
# (e.g. under `python -S` or in frozen executables).
if len(sys.argv) < 2:
    print(f"Usage: python {__file__} <dataset.csv>", file=sys.stderr)
    sys.exit(1)

# Get data from CSV file
dataset_filename = sys.argv[1]
data = parse_csv(
    dataset_filename, NUMERICAL_FEATURE_CSV_TITLES, [HOUSE_FEATURE_CSV_TITLE]
)
df = pd.DataFrame(data)

# The per-house row subsets and their plot colors do not depend on which
# feature pair is being plotted, so build them once up front rather than once
# per plot. Rows with a missing (NaN) house are left out of every plot.
# NOTE(review): assumes every house value is a string whose lowercase form is
# a key of HOUSE_COLORS — an unknown house raises KeyError here.
house_groups = [
    (
        house,
        df.loc[df[HOUSE_FEATURE_CSV_TITLE] == house],
        HOUSE_COLORS[house.lower()],
    )
    for house in df[HOUSE_FEATURE_CSV_TITLE].dropna().unique()
]

# All plots are written into the same directory; create it once.
output_dir = "output/scatter_plot"
os.makedirs(output_dir, exist_ok=True)

# Save one scatter plot per unordered pair of numerical features, in the
# same (i, j) order as the feature list (i < j).
for x_axis_feature, y_axis_feature in itertools.combinations(
    NUMERICAL_FEATURE_CSV_TITLES, 2
):
    title = f"{x_axis_feature} vs {y_axis_feature} Scatter Plot"
    fig, ax = plt.subplots()
    for house, house_df, color in house_groups:
        ax.scatter(
            house_df[x_axis_feature],
            house_df[y_axis_feature],
            alpha=0.8,
            color=color,
            label=house,
            s=10,
        )
    ax.set_title(title)
    ax.set_xlabel(f"{x_axis_feature} Score")
    ax.set_ylabel(f"{y_axis_feature} Score")
    # A legend with no labelled artists only produces a matplotlib warning.
    if house_groups:
        ax.legend()

    # Save to png file
    save_filename = f"{output_dir}/{title}.png"
    fig.savefig(save_filename)
    # Close explicitly so figures do not accumulate across iterations.
    plt.close(fig)
    print(f"Saved {title} to {save_filename}")