from typing import Optional

import pandas as pd


# NOTE(review): reconstructed from the post-patch side of the data_utils.py
# hunk.  The same patch also adds a keyword-only ``na_as_category=None``
# parameter to ``chart_utils.get_categorical_stats_plot`` and forwards it to
# this function; that definition is truncated in the reviewed diff and is
# therefore not reproduced here.
def get_categorical_stats(
    df: pd.DataFrame,
    category_col: str,
    value_col: str,
    *,
    na_as_category: Optional[str] = None,
) -> pd.DataFrame:
    """Return per-category summary statistics of a numeric column.

    The input frame is copied, so the caller's ``df`` is never modified.
    ``value_col`` is coerced with ``pd.to_numeric(errors="coerce")``; rows
    whose value is non-numeric, missing, zero or negative are discarded
    before grouping.

    Args:
        df: Source records.
        category_col: Name of the column to group by.
        value_col: Name of the column to aggregate.
        na_as_category: When truthy, rows whose category is missing are
            relabelled with this string and kept as their own group.  When
            ``None`` (or an empty string) the category column is left as-is,
            and rows with a missing category drop out under ``groupby``'s
            default ``dropna=True``.

    Returns:
        A frame indexed by category with the columns ``mean``, ``min``,
        ``max``, ``median`` and ``count`` computed over ``value_col``.
    """
    # Drop records where value is not numeric (or not positive) before
    # grouping...
    df = df.copy()
    df[value_col] = pd.to_numeric(df[value_col], errors="coerce")
    # The parentheses are required: ``&`` binds tighter than ``>``, so the
    # unparenthesised form evaluates ``(notna & values) > 0``, which drops
    # even integers and lets negative floats through.
    is_positive = df[value_col].notna() & (df[value_col] > 0)
    df = df[is_positive]

    if na_as_category:
        # Go through the nullable "string" dtype so the placeholder label can
        # be assigned regardless of the column's original dtype, then convert
        # to categorical for grouping.
        df[category_col] = df[category_col].astype("string")
        df.loc[df[category_col].isna(), category_col] = na_as_category
        df[category_col] = df[category_col].astype("category")

    # ... then carry on.
    group = df[[category_col, value_col]].groupby(category_col)
    return group[value_col].agg(["mean", "min", "max", "median", "count"])