aboutsummaryrefslogtreecommitdiffstats
path: root/scatter_plot.py
diff options
context:
space:
mode:
authorThomas Vanbesien <tvanbesi@proton.me>2026-04-01 17:42:04 +0200
committerThomas Vanbesien <tvanbesi@proton.me>2026-04-01 17:42:04 +0200
commit32cd9b2be1763f872c800b17e1fa63f852fe91c1 (patch)
tree8aee9bd7e81d8204faca701c0a852bcf7dc45de6 /scatter_plot.py
downloadDSLR-32cd9b2be1763f872c800b17e1fa63f852fe91c1.tar.gz
DSLR-32cd9b2be1763f872c800b17e1fa63f852fe91c1.zip
Import from github.comHEADmaster
Diffstat (limited to 'scatter_plot.py')
-rw-r--r--scatter_plot.py51
1 files changed, 51 insertions, 0 deletions
diff --git a/scatter_plot.py b/scatter_plot.py
new file mode 100644
index 0000000..8cdfea2
--- /dev/null
+++ b/scatter_plot.py
@@ -0,0 +1,51 @@
+from module.constants import (
+ HOUSE_COLORS,
+ NUMERICAL_FEATURE_CSV_TITLES,
+ HOUSE_FEATURE_CSV_TITLE,
+)
+from module.dataset_manip import parse_csv
+import matplotlib.pyplot as plt
+import os
+import pandas as pd
+import sys
+
+if len(sys.argv) < 2:
+ print(f"Usage: python {__file__} <dataset.csv>")
+ exit(-1)
+
+# Get data from CSV file
+dataset_filename = sys.argv[1]
+data = parse_csv(
+ dataset_filename, NUMERICAL_FEATURE_CSV_TITLES, [HOUSE_FEATURE_CSV_TITLE]
+)
+df = pd.DataFrame(data)
+
+# Show a scatter plot for each combination of numerical feature
+numerical_feature_count = len(NUMERICAL_FEATURE_CSV_TITLES)
+for i in range(numerical_feature_count):
+ for j in range(i + 1, numerical_feature_count):
+ x_axis_feature = NUMERICAL_FEATURE_CSV_TITLES[i]
+ y_axis_feature = NUMERICAL_FEATURE_CSV_TITLES[j]
+ title = f"{x_axis_feature} vs {y_axis_feature} Scatter Plot"
+ fig, ax = plt.subplots()
+ for house in df[HOUSE_FEATURE_CSV_TITLE].dropna().unique():
+ house_df = df.loc[df[HOUSE_FEATURE_CSV_TITLE] == house]
+ ax.scatter(
+ house_df.loc[:, x_axis_feature],
+ house_df.loc[:, y_axis_feature],
+ alpha=0.8,
+ color=HOUSE_COLORS[house.lower()],
+ label=house,
+ s=10,
+ )
+ ax.set_title(title)
+ ax.set_xlabel(f"{x_axis_feature} Score")
+ ax.set_ylabel(f"{y_axis_feature} Score")
+ ax.legend()
+
+ # Save to png file
+ os.makedirs("output/scatter_plot", exist_ok=True)
+ save_filename = f"output/scatter_plot/{title}.png"
+ fig.savefig(save_filename)
+ plt.close(fig)
+ print(f"Saved {title} to {save_filename}")