aboutsummaryrefslogtreecommitdiffstats
path: root/visualize.py
blob: 90f3e158a270bdad9981cf1d9357ef12733864cd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import csv
import os
import matplotlib.pyplot as plt

# Input paths, resolved relative to the current working directory.
#   DATA_FILE   - CSV read by load_data(): a header row, then one row per car
#                 with mileage (km) in column 0 and price in column 1.
#   THETAS_FILE - read by load_thetas(): theta0 then theta1, one per line.
DATA_FILE, THETAS_FILE = "data.csv", "thetas.csv"


def load_data(path=None):
    """Read the (mileage, price) data set from a CSV file.

    The first row is treated as a header and skipped. Every following
    non-blank row must hold the mileage (km) in its first column and the
    price in its second; any extra columns are ignored.

    Args:
        path: CSV file to read. Defaults to ``DATA_FILE`` when omitted.

    Returns:
        A ``(km, price)`` tuple of two equally long lists of floats. Both
        lists are empty when the file has no data rows (or is empty).

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a row has fewer than two columns or a cell is not
            a number.
    """
    if path is None:
        path = DATA_FILE
    km = []
    price = []
    # The csv module requires newline="" so that \r\n endings and quoted
    # fields containing newlines are parsed correctly.
    with open(path, newline="") as f:
        reader = csv.reader(f)
        # Skip the header; the default avoids StopIteration on an empty file.
        next(reader, None)
        for row in reader:
            if not row:
                continue  # csv.reader yields [] for blank lines
            if len(row) < 2:
                raise ValueError(
                    f"{path}: line {reader.line_num}: expected mileage and "
                    f"price columns, got {row!r}"
                )
            km.append(float(row[0]))
            price.append(float(row[1]))
    return km, price


def load_thetas(path=None):
    """Read the linear-model parameters ``theta0`` and ``theta1``.

    ``theta0`` is taken from the first non-blank line of the file and
    ``theta1`` from the second; any further lines are ignored.

    Args:
        path: file to read. Defaults to ``THETAS_FILE`` when omitted.

    Returns:
        ``(theta0, theta1)`` as floats, or ``(None, None)`` when the file
        does not exist, so the caller can still plot the data on its own.

    Raises:
        ValueError: if the file holds fewer than two non-blank lines or a
            value is not a number.
        OSError: for open failures other than a missing file.
    """
    if path is None:
        path = THETAS_FILE
    # EAFP: attempting the open avoids the exists()/open() race.
    try:
        f = open(path)
    except FileNotFoundError:
        return None, None
    with f:
        values = [line.strip() for line in f if line.strip()]
    if len(values) < 2:
        raise ValueError(
            f"{path}: expected theta0 and theta1 on separate lines, "
            f"found {len(values)} value(s)"
        )
    return float(values[0]), float(values[1])


def main():
    """Plot price against mileage and, if available, the fitted line.

    Draws a scatter plot of the data returned by ``load_data()``. When
    ``load_thetas()`` finds model parameters and there is at least one data
    point, the line ``price = theta0 + theta1 * km`` is drawn in red across
    the observed mileage range. Finally the figure is shown with
    ``plt.show()``.
    """
    km, price = load_data()
    plt.scatter(km, price)
    theta0, theta1 = load_thetas()
    # min()/max() raise on an empty list, so only draw the line when there
    # is data to span.
    if theta0 is not None and km:
        # A straight line is fully determined by its two end points.
        x_line = [min(km), max(km)]
        y_line = [theta0 + theta1 * x for x in x_line]
        plt.plot(x_line, y_line, color="red")
    plt.xlabel("Mileage (km)")
    plt.ylabel("Price")
    plt.title("Car price vs mileage")
    plt.show()


# Render the plot only when executed as a script, not when imported.
if "__main__" == __name__:
    main()