aboutsummaryrefslogtreecommitdiffstats
path: root/describe.py
diff options
context:
space:
mode:
Diffstat (limited to 'describe.py')
-rw-r--r--describe.py91
1 files changed, 91 insertions, 0 deletions
diff --git a/describe.py b/describe.py
new file mode 100644
index 0000000..4a1c338
--- /dev/null
+++ b/describe.py
@@ -0,0 +1,91 @@
+import sys
+from module.constants import NUMERICAL_FEATURE_CSV_TITLES
+from module.dataset_manip import parse_csv
+from module.math import get_mean, get_std, get_min, get_max, get_quartiles
+
+
+def print_table(number_rows, h_headers, v_headers, decimal_precision=3):
+    """
+    Prints a table of numbers in a human-readable way on the standard output.
+
+    All cells share one width: that of the widest number once it is formatted
+    with decimal_precision decimal places. A header longer than that width is
+    cut to fit and its last visible character is replaced with ".".
+
+    Parameters:
+    number_rows (list): A list of list of the numbers to print, one list for each row
+    h_headers (list): A list of strings of the horizontal table headers
+    v_headers (list): A list of strings of the vertical table headers, one per row
+    decimal_precision (int): The amount of decimal points to display for each number
+
+    Raises:
+    IndexError: If number_rows has fewer rows than v_headers has entries
+    """
+
+    def format_number(n):
+        # Single definition of how a number is rendered inside a cell
+        return f"{n:.{decimal_precision}f}"
+
+    def fit_header(header):
+        # Shorten a header that overflows the column and end it with "."
+        if len(header) <= max_column_width:
+            return header
+        return header[:max_column_width][:-1] + "."
+
+    # Measure the text that is actually printed instead of estimating it from the
+    # integer part: such an estimate counts the minus sign twice (str(int(-5.2))
+    # is already "-5"), misses a rounding carry (9.9996 prints as "10.000") and
+    # fails on nan/inf, which int() cannot convert.
+    max_column_width = max(
+        (len(format_number(n)) for row in number_rows for n in row), default=0
+    )
+
+    truncated_h_headers = [fit_header(s) for s in h_headers]
+    truncated_v_headers = [fit_header(s) for s in v_headers]
+
+    # Header line: an empty corner cell followed by right-aligned column headers
+    header_cells = [" " * max_column_width]
+    header_cells += [s.rjust(max_column_width) for s in truncated_h_headers]
+    print("|" + "|".join(header_cells) + "|")
+
+    # Data lines: a left-aligned row header followed by right-aligned numbers
+    for i, v_header in enumerate(truncated_v_headers):
+        row_cells = [v_header.ljust(max_column_width)]
+        row_cells += [
+            format_number(n).rjust(max_column_width) for n in number_rows[i]
+        ]
+        print("|" + "|".join(row_cells) + "|")
+
+
+# Script entry point: print summary statistics of the dataset's numerical features
+if len(sys.argv) < 2:
+    print(f"Usage: python {__file__} <dataset.csv>")
+    # sys.exit() rather than the exit() builtin: the latter is added by the site
+    # module for interactive use and is missing when Python is started with -S
+    sys.exit(-1)
+
+# Get data from CSV file
+dataset_filename = sys.argv[1]
+data = parse_csv(dataset_filename, NUMERICAL_FEATURE_CSV_TITLES)
+# Remove None values from each feature vector
+for k in data:
+    data[k] = [v for v in data[k] if v is not None]
+
+# Get horizontal headers
+features_names = list(data)
+# Get vertical headers
+information_names = ["Count", "Mean", "Std", "Min", "25%", "50%", "75%", "Max"]
+# Get rows of data, one row per statistic, in the same order as information_names
+feature_value_lists = list(data.values())
+rows = [
+    [len(values) for values in feature_value_lists],
+    [get_mean(values) for values in feature_value_lists],
+    [get_std(values) for values in feature_value_lists],
+    [get_min(values) for values in feature_value_lists],
+]
+# Indexes 0, 1 and 2 of each get_quartiles() result fill the 25%, 50%, 75% rows
+quartiles_per_feature = [get_quartiles(values) for values in feature_value_lists]
+rows.extend([q[i] for q in quartiles_per_feature] for i in range(3))
+rows.append([get_max(values) for values in feature_value_lists])
+
+print_table(rows, features_names, information_names)