diff options
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | README.md | 54 | ||||
| -rw-r--r-- | data.csv | 25 | ||||
| -rw-r--r-- | docs/ft_linear_regression.pdf | bin | 0 -> 1353377 bytes | |||
| -rw-r--r-- | predict.py | 40 | ||||
| -rw-r--r-- | train.py | 75 | ||||
| -rw-r--r-- | visualize.py | 44 |
7 files changed, 239 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c8119b7 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +thetas.csv diff --git a/README.md b/README.md new file mode 100644 index 0000000..4d6833d --- /dev/null +++ b/README.md @@ -0,0 +1,54 @@ +# ft_linear_regression + +A simple linear regression implementation using gradient descent to predict car prices based on mileage. + +## Requirements + +- Python 3 +- matplotlib (for visualization only) + +``` +pip install matplotlib +``` + +## Usage + +### Train the model + +``` +python3 train.py <learning_rate> <iterations> +``` + +Example: +``` +python3 train.py 1.0 1000 +``` + +This trains the model on `data.csv` and saves the resulting parameters (θ0, θ1) to `thetas.csv`. + +### Predict a price + +``` +python3 predict.py +``` + +Prompts for a mileage value and outputs the estimated price. Loops until Ctrl+C. +If no trained model is found, θ0 and θ1 default to 0. + +### Visualize + +``` +python3 visualize.py +``` + +Displays a scatter plot of the dataset. If a trained model exists, the regression line is drawn on top. + +## How it works + +The model fits a linear function: + +``` +estimatePrice(mileage) = θ0 + θ1 * mileage +``` + +Parameters are found via gradient descent with min-max normalization on the input data. After training, thetas are denormalized so they work directly on raw mileage values. 
diff --git a/data.csv b/data.csv new file mode 100644 index 0000000..b875289 --- /dev/null +++ b/data.csv @@ -0,0 +1,25 @@ +km,price +240000,3650 +139800,3800 +150500,4400 +185530,4450 +176000,5250 +114800,5350 +166800,5800 +89000,5990 +144500,5999 +84000,6200 +82029,6390 +63060,6390 +74000,6600 +97500,6800 +67000,6800 +76025,6900 +48235,6900 +93000,6990 +60949,7490 +65674,7555 +54000,7990 +68500,7990 +22899,7990 +61789,8290 diff --git a/docs/ft_linear_regression.pdf b/docs/ft_linear_regression.pdf Binary files differnew file mode 100644 index 0000000..527e5eb --- /dev/null +++ b/docs/ft_linear_regression.pdf diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..3bbcb47 --- /dev/null +++ b/predict.py @@ -0,0 +1,40 @@ +import os +import sys + +THETAS_FILE = "thetas.csv" + + +def load_thetas(): + if not os.path.exists(THETAS_FILE): + return 0.0, 0.0 + with open(THETAS_FILE) as f: + lines = f.readlines() + return float(lines[0]), float(lines[1]) + + +def estimate_price(mileage, theta0, theta1): + return theta0 + theta1 * mileage + + +def main(): + theta0, theta1 = load_thetas() + if theta0 == 0.0 and theta1 == 0.0: + print(f"Warning: no trained model found ({THETAS_FILE}), using θ0=0 θ1=0") + else: + print(f"Loaded θ0={theta0}, θ1={theta1}") + while True: + try: + raw = input("\nMileage (km): ") + mileage = float(raw) + price = estimate_price(mileage, theta0, theta1) + print(f"Estimated price: {price:.2f}") + except ValueError: + print("Please enter a valid number.") + + +if __name__ == "__main__": + try: + main() + except (KeyboardInterrupt, EOFError): + print() + sys.exit(0) diff --git a/train.py b/train.py new file mode 100644 index 0000000..a9c865b --- /dev/null +++ b/train.py @@ -0,0 +1,75 @@ +import csv +import sys + +DATASET = "data.csv" +THETAS_FILE = "thetas.csv" + + +def normalize(data): + min_val = min(data) + max_val = max(data) + return [(x - min_val) / (max_val - min_val) for x in data], min_val, max_val + + +def load_data(): + km = [] 
+ price = [] + with open(DATASET) as f: + reader = csv.reader(f) + next(reader) + for row in reader: + km.append(float(row[0])) + price.append(float(row[1])) + return km, price + + +def estimate_price(mileage, theta0, theta1): + return theta0 + theta1 * mileage + + +# DV: dependent variable, IV: independent variable +def train_once(learning_rate, DV, IV, theta0, theta1): + tmp0 = ( + learning_rate + * (1.0 / len(DV)) + * sum(estimate_price(x, theta0, theta1) - y for x, y in zip(DV, IV)) + ) + tmp1 = ( + learning_rate + * (1.0 / len(DV)) + * sum((estimate_price(x, theta0, theta1) - y) * x for x, y in zip(DV, IV)) + ) + return tmp0, tmp1 + + +def denormalize_thetas(t0, t1, km_min, km_max, price_min, price_max): + price_range = price_max - price_min + km_range = km_max - km_min + real_t1 = t1 * price_range / km_range + real_t0 = t0 * price_range + price_min - real_t1 * km_min + return real_t0, real_t1 + + +def train(learning_rate, iterations): + kms, prices = load_data() + kms_norm, km_min, km_max = normalize(kms) + prices_norm, price_min, price_max = normalize(prices) + t0 = 0.0 + t1 = 0.0 + for _ in range(iterations): + grad0, grad1 = train_once(learning_rate, prices_norm, kms_norm, t0, t1) + t0 -= grad0 + t1 -= grad1 + return denormalize_thetas(t0, t1, km_min, km_max, price_min, price_max) + + +def save_thetas(theta0, theta1): + with open(THETAS_FILE, "w") as f: + f.write(f"{theta0}\n{theta1}\n") + + +learning_rate = float(sys.argv[1]) +i = int(sys.argv[2]) +t0, t1 = train(learning_rate, i) +save_thetas(t0, t1) +print(f"θ0={t0}, θ1={t1} saved to {THETAS_FILE}") diff --git a/visualize.py b/visualize.py new file mode 100644 index 0000000..90f3e15 --- /dev/null +++ b/visualize.py @@ -0,0 +1,44 @@ +import csv +import os +import matplotlib.pyplot as plt + +DATA_FILE = "data.csv" +THETAS_FILE = "thetas.csv" + + +def load_data(): + km = [] + price = [] + with open(DATA_FILE) as f: + reader = csv.reader(f) + next(reader) + for row in reader: + km.append(float(row[0])) + 
price.append(float(row[1])) + return km, price + + +def load_thetas(): + if not os.path.exists(THETAS_FILE): + return None, None + with open(THETAS_FILE) as f: + lines = f.readlines() + return float(lines[0]), float(lines[1]) + + +def main(): + km, price = load_data() + plt.scatter(km, price) + theta0, theta1 = load_thetas() + if theta0 is not None: + x_line = [min(km), max(km)] + y_line = [theta0 + theta1 * x for x in x_line] + plt.plot(x_line, y_line, color="red") + plt.xlabel("Mileage (km)") + plt.ylabel("Price") + plt.title("Car price vs mileage") + plt.show() + + +if __name__ == "__main__": + main() |
