aboutsummaryrefslogtreecommitdiffstats
path: root/train.py
diff options
context:
space:
mode:
authorThomas Vanbesien <tvanbesi@proton.me>2026-03-30 17:10:31 +0200
committerThomas Vanbesien <tvanbesi@proton.me>2026-03-30 17:22:03 +0200
commitb998b2cdfe454c9d177e06304c2c01c63747335c (patch)
tree4f55811de78a23dc67ca62a7da052beb47145c85 /train.py
parentfd5fe70ce5271f09303b51dae34b42acc47f5730 (diff)
downloadft_linear_regression-b998b2cdfe454c9d177e06304c2c01c63747335c.tar.gz
ft_linear_regression-b998b2cdfe454c9d177e06304c2c01c63747335c.zip
Rename train_once to compute_gradients, clean up normalize, document normalization in README
Diffstat (limited to 'train.py')
-rw-r--r--train.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/train.py b/train.py
index a9c865b..8b15547 100644
--- a/train.py
+++ b/train.py
@@ -8,7 +8,7 @@ THETAS_FILE = "thetas.csv"
def normalize(data):
min_val = min(data)
max_val = max(data)
- return [(x - min_val) / (max_val - min_val) for x in data], min_val, max_val
+ return [(x - min_val) / (max_val - min_val) for x in data]
def load_data():
@@ -28,7 +28,7 @@ def estimate_price(mileage, theta0, theta1):
# DV: dependant variable, IV: independant variable
-def train_once(learning_rate, DV, IV, theta0, theta1):
+def compute_gradients(learning_rate, DV, IV, theta0, theta1):
tmp0 = (
learning_rate
* (1.0 / len(DV))
@@ -52,12 +52,14 @@ def denormalize_thetas(t0, t1, km_min, km_max, price_min, price_max):
def train(learning_rate, iterations):
kms, prices = load_data()
- kms_norm, km_min, km_max = normalize(kms)
- prices_norm, price_min, price_max = normalize(prices)
+ km_min, km_max = min(kms), max(kms)
+ price_min, price_max = min(prices), max(prices)
+ kms_norm = normalize(kms)
+ prices_norm = normalize(prices)
t0 = 0.0
t1 = 0.0
for _ in range(iterations):
- grad0, grad1 = train_once(learning_rate, prices_norm, kms_norm, t0, t1)
+ grad0, grad1 = compute_gradients(learning_rate, prices_norm, kms_norm, t0, t1)
t0 -= grad0
t1 -= grad1
return denormalize_thetas(t0, t1, km_min, km_max, price_min, price_max)