From fd5fe70ce5271f09303b51dae34b42acc47f5730 Mon Sep 17 00:00:00 2001 From: Thomas Vanbesien Date: Mon, 23 Mar 2026 21:17:11 +0100 Subject: Initial commit: linear regression for car price prediction Training, prediction, and visualization programs using gradient descent with min-max normalization. --- visualize.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 visualize.py (limited to 'visualize.py') diff --git a/visualize.py b/visualize.py new file mode 100644 index 0000000..90f3e15 --- /dev/null +++ b/visualize.py @@ -0,0 +1,44 @@ +import csv +import os +import matplotlib.pyplot as plt + +DATA_FILE = "data.csv" +THETAS_FILE = "thetas.csv" + + +def load_data(): + km = [] + price = [] + with open(DATA_FILE) as f: + reader = csv.reader(f) + next(reader) + for row in reader: + km.append(float(row[0])) + price.append(float(row[1])) + return km, price + + +def load_thetas(): + if not os.path.exists(THETAS_FILE): + return None, None + with open(THETAS_FILE) as f: + lines = f.readlines() + return float(lines[0]), float(lines[1]) + + +def main(): + km, price = load_data() + plt.scatter(km, price) + theta0, theta1 = load_thetas() + if theta0 is not None: + x_line = [min(km), max(km)] + y_line = [theta0 + theta1 * x for x in x_line] + plt.plot(x_line, y_line, color="red") + plt.xlabel("Mileage (km)") + plt.ylabel("Price") + plt.title("Car price vs mileage") + plt.show() + + +if __name__ == "__main__": + main() -- cgit v1.2.3