aboutsummaryrefslogtreecommitdiffstats
path: root/train.py
blob: a9c865b1947649c4493effd2cbfe70e3ccc676f2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import csv
import sys

# CSV read by load_data(): a header row, then mileage,price rows.
DATASET: str = "data.csv"
# File save_thetas() writes the trained theta0 and theta1 to, one per line.
THETAS_FILE: str = "thetas.csv"


def normalize(data):
    """Min-max scale *data* into the range [0, 1].

    Args:
        data: Iterable of numbers. It is materialised once, so one-shot
            iterators are accepted.

    Returns:
        A tuple ``(scaled, min_val, max_val)``. ``scaled`` is a new list and
        ``min_val``/``max_val`` are the bounds needed to undo the scaling.
        When every value is equal, the range is zero and every scaled value
        is 0.0 instead of raising ZeroDivisionError.

    Raises:
        ValueError: if *data* is empty (propagated from ``min``).
    """
    values = list(data)
    min_val = min(values)
    max_val = max(values)
    span = max_val - min_val
    if span == 0:
        # Constant column: every point sits at the minimum; avoid x / 0.
        return [0.0] * len(values), min_val, max_val
    return [(x - min_val) / span for x in values], min_val, max_val


def load_data(path=None):
    """Read (mileage, price) pairs from a CSV file that has a header row.

    Args:
        path: CSV file to read. Defaults to the module-level ``DATASET``.

    Returns:
        A tuple ``(km, price)`` of float lists. Column 0 of each data row is
        the mileage and column 1 is the price. Blank lines are skipped, and
        an empty file yields two empty lists.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a cell cannot be parsed as a float.
        IndexError: if a non-blank row has fewer than two columns.
    """
    if path is None:
        path = DATASET
    km = []
    price = []
    # newline="" is required by the csv module for correct line handling.
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header; tolerate a completely empty file
        for row in reader:
            if not row:
                continue  # csv yields [] for blank lines (e.g. a trailing newline)
            km.append(float(row[0]))
            price.append(float(row[1]))
    return km, price


def estimate_price(mileage, theta0, theta1):
    """Evaluate the linear hypothesis ``theta0 + theta1 * mileage``."""
    slope_term = theta1 * mileage
    return theta0 + slope_term


# DV: dependent variable (the target being predicted);
# IV: independent variable (the model input fed to estimate_price).
def train_once(learning_rate, DV, IV, theta0, theta1):
    """Compute one batch gradient-descent step for ``DV ~ theta0 + theta1 * IV``.

    Args:
        learning_rate: Step size multiplying the gradient.
        DV: Sequence of target values.
        IV: Sequence of input values, the same length as ``DV``.
        theta0: Current intercept.
        theta1: Current slope.

    Returns:
        ``(tmp0, tmp1)``: learning_rate times the partial derivatives of half
        the mean squared error with respect to theta0 and theta1. The caller
        subtracts these from the current thetas.

    Raises:
        ValueError: if the inputs are empty or have different lengths.
    """
    m = len(DV)
    if m == 0:
        raise ValueError("train_once needs at least one sample")
    if len(IV) != m:
        raise ValueError("DV and IV must have the same length")
    # x is the model input (IV) and y the target (DV). Iterating
    # zip(DV, IV) instead would regress the input on the target.
    errors = [estimate_price(x, theta0, theta1) - y for x, y in zip(IV, DV)]
    scale = learning_rate * (1.0 / m)
    tmp0 = scale * sum(errors)
    tmp1 = scale * sum(err * x for err, x in zip(errors, IV))
    return tmp0, tmp1


def denormalize_thetas(t0, t1, km_min, km_max, price_min, price_max):
    """Convert thetas learned on min-max-scaled data back to raw units.

    The scaled model is ``p' = t0 + t1 * k'``, where
    ``p' = (p - price_min) / price_range`` and ``k' = (k - km_min) / km_range``.
    Substituting gives ``p = real_t0 + real_t1 * k``, where:

    * ``real_t1 = t1 * price_range / km_range``
    * ``real_t0 = t0 * price_range + price_min - real_t1 * km_min``

    If every mileage was identical (``km_range == 0``), no slope can be
    identified. The slope is then returned as 0.0 and the intercept is
    ``t0`` mapped back to price units, instead of raising ZeroDivisionError.

    Returns:
        ``(real_t0, real_t1)`` in raw price/mileage units.
    """
    price_range = price_max - price_min
    km_range = km_max - km_min
    if km_range == 0:
        real_t1 = 0.0
    else:
        real_t1 = t1 * price_range / km_range
    real_t0 = t0 * price_range + price_min - real_t1 * km_min
    return real_t0, real_t1


def train(learning_rate, iterations):
    """Run gradient descent on the dataset and return raw-unit thetas.

    Loads the data and min-max scales mileages and prices separately. It then
    runs ``iterations`` steps of ``train_once``, with prices as the dependent
    variable and mileages as the independent variable, starting from zero
    thetas. Finally, ``denormalize_thetas`` maps the result back to the
    original units.

    Returns:
        ``(theta0, theta1)`` as produced by ``denormalize_thetas``.
    """
    mileages, prices = load_data()
    scaled_km, km_lo, km_hi = normalize(mileages)
    scaled_price, price_lo, price_hi = normalize(prices)
    theta0 = theta1 = 0.0
    for _ in range(iterations):
        step0, step1 = train_once(
            learning_rate, scaled_price, scaled_km, theta0, theta1
        )
        theta0, theta1 = theta0 - step0, theta1 - step1
    return denormalize_thetas(theta0, theta1, km_lo, km_hi, price_lo, price_hi)


def save_thetas(theta0, theta1):
    """Overwrite THETAS_FILE with theta0 and theta1, one value per line."""
    with open(THETAS_FILE, "w") as out:
        out.writelines([f"{value}\n" for value in (theta0, theta1)])


def main(argv=None):
    """Command-line entry point: ``train.py <learning_rate> <iterations>``.

    Parses and validates the arguments, trains, saves the thetas and prints
    a confirmation. Bad or missing arguments exit with a message on stderr
    instead of a traceback.

    Args:
        argv: Argument list excluding the program name. Defaults to
            ``sys.argv[1:]``.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        sys.exit(f"usage: {sys.argv[0]} <learning_rate> <iterations>")
    try:
        learning_rate = float(args[0])
        iterations = int(args[1])
    except ValueError as err:
        sys.exit(f"invalid argument: {err}")
    # `not x > 0` also rejects NaN, which float() happily parses.
    if not learning_rate > 0:
        sys.exit("learning_rate must be a positive number")
    if iterations < 0:
        sys.exit("iterations must be non-negative")
    t0, t1 = train(learning_rate, iterations)
    save_thetas(t0, t1)
    print(f"θ0={t0}, θ1={t1} saved to {THETAS_FILE}")


if __name__ == "__main__":
    main()