diff options
Diffstat (limited to 'scatter_plot.py')
| -rw-r--r-- | scatter_plot.py | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/scatter_plot.py b/scatter_plot.py new file mode 100644 index 0000000..8cdfea2 --- /dev/null +++ b/scatter_plot.py @@ -0,0 +1,51 @@ +from module.constants import ( + HOUSE_COLORS, + NUMERICAL_FEATURE_CSV_TITLES, + HOUSE_FEATURE_CSV_TITLE, +) +from module.dataset_manip import parse_csv +import matplotlib.pyplot as plt +import os +import pandas as pd +import sys + +if len(sys.argv) < 2: + print(f"Usage: python {__file__} <dataset.csv>") + exit(-1) + +# Get data from CSV file +dataset_filename = sys.argv[1] +data = parse_csv( + dataset_filename, NUMERICAL_FEATURE_CSV_TITLES, [HOUSE_FEATURE_CSV_TITLE] +) +df = pd.DataFrame(data) + +# Show a scatter plot for each combination of numerical feature +numerical_feature_count = len(NUMERICAL_FEATURE_CSV_TITLES) +for i in range(numerical_feature_count): + for j in range(i + 1, numerical_feature_count): + x_axis_feature = NUMERICAL_FEATURE_CSV_TITLES[i] + y_axis_feature = NUMERICAL_FEATURE_CSV_TITLES[j] + title = f"{x_axis_feature} vs {y_axis_feature} Scatter Plot" + fig, ax = plt.subplots() + for house in df[HOUSE_FEATURE_CSV_TITLE].dropna().unique(): + house_df = df.loc[df[HOUSE_FEATURE_CSV_TITLE] == house] + ax.scatter( + house_df.loc[:, x_axis_feature], + house_df.loc[:, y_axis_feature], + alpha=0.8, + color=HOUSE_COLORS[house.lower()], + label=house, + s=10, + ) + ax.set_title(title) + ax.set_xlabel(f"{x_axis_feature} Score") + ax.set_ylabel(f"{y_axis_feature} Score") + ax.legend() + + # Save to png file + os.makedirs("output/scatter_plot", exist_ok=True) + save_filename = f"output/scatter_plot/{title}.png" + fig.savefig(save_filename) + plt.close(fig) + print(f"Saved {title} to {save_filename}") |
