"""Plot car price against mileage, overlaying a fitted line when available.

Reads the data points from ``data.csv`` and, if ``thetas.csv`` exists,
draws the line ``price = theta0 + theta1 * km`` over the scatter plot.
"""

import csv
import os
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt

DATA_FILE = "data.csv"
THETAS_FILE = "thetas.csv"


def load_data() -> Tuple[List[float], List[float]]:
    """Read the mileage and price columns from DATA_FILE.

    The first row is treated as a header and skipped. For every following
    row, column 0 is appended to ``km`` and column 1 to ``price``. Blank
    rows (such as a trailing empty line) are ignored.

    Returns:
        Two parallel lists ``(km, price)``; both are empty if the file has
        no data rows.

    Raises:
        OSError: if DATA_FILE cannot be opened.
        ValueError: if a non-blank row has fewer than two columns, or a
            value cannot be parsed as a float.
    """
    km: List[float] = []
    price: List[float] = []
    # The csv module documentation requires newline="" so that line endings
    # inside and between records are handled by the reader, not the file.
    with open(DATA_FILE, newline="", encoding="utf-8") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header; tolerate a completely empty file
        for row in reader:
            # csv.reader yields [] for an empty line; skip those and rows
            # that contain only whitespace cells.
            if not any(cell.strip() for cell in row):
                continue
            if len(row) < 2:
                raise ValueError(
                    f"{DATA_FILE}:{reader.line_num}: expected at least 2 "
                    f"columns, got {len(row)}"
                )
            km.append(float(row[0]))
            price.append(float(row[1]))
    return km, price


def load_thetas() -> Tuple[Optional[float], Optional[float]]:
    """Read the two line coefficients from THETAS_FILE.

    The file is expected to hold theta0 on its first non-blank line and
    theta1 on its second non-blank line; any further lines are ignored.

    Returns:
        ``(theta0, theta1)``, or ``(None, None)`` if THETAS_FILE does not
        exist.

    Raises:
        ValueError: if the file holds fewer than two non-blank lines, or a
            value cannot be parsed as a float.
    """
    if not os.path.exists(THETAS_FILE):
        return None, None
    with open(THETAS_FILE, encoding="utf-8") as f:
        values = [line.strip() for line in f if line.strip()]
    if len(values) < 2:
        raise ValueError(
            f"{THETAS_FILE}: expected 2 values (theta0, theta1), "
            f"found {len(values)}"
        )
    return float(values[0]), float(values[1])


def main() -> None:
    """Show a scatter plot of price vs mileage, plus the fitted line if any."""
    km, price = load_data()
    plt.scatter(km, price)

    theta0, theta1 = load_thetas()
    # The line spans the observed mileage range, which is undefined when
    # there is no data (min()/max() of an empty list would raise).
    if theta0 is not None and km:
        x_line = [min(km), max(km)]
        y_line = [theta0 + theta1 * x for x in x_line]
        plt.plot(x_line, y_line, color="red")

    plt.xlabel("Mileage (km)")
    plt.ylabel("Price")
    plt.title("Car price vs mileage")
    plt.show()


if __name__ == "__main__":
    main()