aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--README.md54
-rw-r--r--data.csv25
-rw-r--r--docs/ft_linear_regression.pdfbin0 -> 1353377 bytes
-rw-r--r--predict.py40
-rw-r--r--train.py75
-rw-r--r--visualize.py44
7 files changed, 239 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c8119b7
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+thetas.csv
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..4d6833d
--- /dev/null
+++ b/README.md
@@ -0,0 +1,54 @@
+# ft_linear_regression
+
+A simple linear regression implementation using gradient descent to predict car prices based on mileage.
+
+## Requirements
+
+- Python 3
+- matplotlib (for visualization only)
+
+```
+pip install matplotlib
+```
+
+## Usage
+
+### Train the model
+
+```
+python3 train.py <learning_rate> <iterations>
+```
+
+Example:
+```
+python3 train.py 1.0 1000
+```
+
+This trains the model on `data.csv` and saves the resulting parameters (θ0, θ1) to `thetas.csv`.
+
+### Predict a price
+
+```
+python3 predict.py
+```
+
+Prompts for a mileage value and outputs the estimated price. Loops until Ctrl+C.
+If no trained model is found, θ0 and θ1 default to 0.
+
+### Visualize
+
+```
+python3 visualize.py
+```
+
+Displays a scatter plot of the dataset. If a trained model exists, the regression line is drawn on top.
+
+## How it works
+
+The model fits a linear function:
+
+```
+estimatePrice(mileage) = θ0 + θ1 * mileage
+```
+
+Parameters are found via gradient descent with min-max normalization on the input data. After training, thetas are denormalized so they work directly on raw mileage values.
diff --git a/data.csv b/data.csv
new file mode 100644
index 0000000..b875289
--- /dev/null
+++ b/data.csv
@@ -0,0 +1,25 @@
+km,price
+240000,3650
+139800,3800
+150500,4400
+185530,4450
+176000,5250
+114800,5350
+166800,5800
+89000,5990
+144500,5999
+84000,6200
+82029,6390
+63060,6390
+74000,6600
+97500,6800
+67000,6800
+76025,6900
+48235,6900
+93000,6990
+60949,7490
+65674,7555
+54000,7990
+68500,7990
+22899,7990
+61789,8290
diff --git a/docs/ft_linear_regression.pdf b/docs/ft_linear_regression.pdf
new file mode 100644
index 0000000..527e5eb
--- /dev/null
+++ b/docs/ft_linear_regression.pdf
Binary files differ
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..3bbcb47
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,40 @@
+import os
+import sys
+
+THETAS_FILE = "thetas.csv"
+
+
+def load_thetas():
+ if not os.path.exists(THETAS_FILE):
+ return 0.0, 0.0
+ with open(THETAS_FILE) as f:
+ lines = f.readlines()
+ return float(lines[0]), float(lines[1])
+
+
+def estimate_price(mileage, theta0, theta1):
+ return theta0 + theta1 * mileage
+
+
+def main():
+ theta0, theta1 = load_thetas()
+ if theta0 == 0.0 and theta1 == 0.0:
+ print(f"Warning: no trained model found ({THETAS_FILE}), using θ0=0 θ1=0")
+ else:
+ print(f"Loaded θ0={theta0}, θ1={theta1}")
+ while True:
+ try:
+ raw = input("\nMileage (km): ")
+ mileage = float(raw)
+ price = estimate_price(mileage, theta0, theta1)
+ print(f"Estimated price: {price:.2f}")
+ except ValueError:
+ print("Please enter a valid number.")
+
+
+if __name__ == "__main__":
+ try:
+ main()
+ except (KeyboardInterrupt, EOFError):
+ print()
+ sys.exit(0)
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..a9c865b
--- /dev/null
+++ b/train.py
@@ -0,0 +1,75 @@
+import csv
+import sys
+
+DATASET = "data.csv"
+THETAS_FILE = "thetas.csv"
+
+
+def normalize(data):
+ min_val = min(data)
+ max_val = max(data)
+ return [(x - min_val) / (max_val - min_val) for x in data], min_val, max_val
+
+
+def load_data():
+ km = []
+ price = []
+ with open(DATASET) as f:
+ reader = csv.reader(f)
+ next(reader)
+ for row in reader:
+ km.append(float(row[0]))
+ price.append(float(row[1]))
+ return km, price
+
+
+def estimate_price(mileage, theta0, theta1):
+ return theta0 + theta1 * mileage
+
+
+# DV: dependant variable, IV: independant variable
+def train_once(learning_rate, DV, IV, theta0, theta1):
+ tmp0 = (
+ learning_rate
+ * (1.0 / len(DV))
+ * sum(estimate_price(x, theta0, theta1) - y for x, y in zip(DV, IV))
+ )
+ tmp1 = (
+ learning_rate
+ * (1.0 / len(DV))
+ * sum((estimate_price(x, theta0, theta1) - y) * x for x, y in zip(DV, IV))
+ )
+ return tmp0, tmp1
+
+
+def denormalize_thetas(t0, t1, km_min, km_max, price_min, price_max):
+ price_range = price_max - price_min
+ km_range = km_max - km_min
+ real_t1 = t1 * price_range / km_range
+ real_t0 = t0 * price_range + price_min - real_t1 * km_min
+ return real_t0, real_t1
+
+
+def train(learning_rate, iterations):
+ kms, prices = load_data()
+ kms_norm, km_min, km_max = normalize(kms)
+ prices_norm, price_min, price_max = normalize(prices)
+ t0 = 0.0
+ t1 = 0.0
+ for _ in range(iterations):
+ grad0, grad1 = train_once(learning_rate, prices_norm, kms_norm, t0, t1)
+ t0 -= grad0
+ t1 -= grad1
+ return denormalize_thetas(t0, t1, km_min, km_max, price_min, price_max)
+
+
+def save_thetas(theta0, theta1):
+ with open(THETAS_FILE, "w") as f:
+ f.write(f"{theta0}\n{theta1}\n")
+
+
+learning_rate = float(sys.argv[1])
+i = int(sys.argv[2])
+t0, t1 = train(learning_rate, i)
+save_thetas(t0, t1)
+print(f"θ0={t0}, θ1={t1} saved to {THETAS_FILE}")
diff --git a/visualize.py b/visualize.py
new file mode 100644
index 0000000..90f3e15
--- /dev/null
+++ b/visualize.py
@@ -0,0 +1,44 @@
+import csv
+import os
+import matplotlib.pyplot as plt
+
+DATA_FILE = "data.csv"
+THETAS_FILE = "thetas.csv"
+
+
+def load_data():
+ km = []
+ price = []
+ with open(DATA_FILE) as f:
+ reader = csv.reader(f)
+ next(reader)
+ for row in reader:
+ km.append(float(row[0]))
+ price.append(float(row[1]))
+ return km, price
+
+
+def load_thetas():
+ if not os.path.exists(THETAS_FILE):
+ return None, None
+ with open(THETAS_FILE) as f:
+ lines = f.readlines()
+ return float(lines[0]), float(lines[1])
+
+
+def main():
+ km, price = load_data()
+ plt.scatter(km, price)
+ theta0, theta1 = load_thetas()
+ if theta0 is not None:
+ x_line = [min(km), max(km)]
+ y_line = [theta0 + theta1 * x for x in x_line]
+ plt.plot(x_line, y_line, color="red")
+ plt.xlabel("Mileage (km)")
+ plt.ylabel("Price")
+ plt.title("Car price vs mileage")
+ plt.show()
+
+
+if __name__ == "__main__":
+ main()