# ---------------------------------------------------------------------------
# Addition to src/tlo/util.py
# ---------------------------------------------------------------------------
def scale_up_population_dataframe(df: DataFrame, initial_pop: int, census_pop: int) -> DataFrame:
    """
    Return a copy of ``df`` whose numeric columns are multiplied by
    ``census_pop / initial_pop``.

    The input frame is never modified. If the frame's index is not already named
    ``"date"`` and a ``"date"`` column exists, that column becomes the index of the
    returned frame. Columns that are not selected by ``select_dtypes(include="number")``
    are carried over unchanged.

    Scaled columns replace the originals rather than being written into them in place,
    so integer columns scaled by a non-integer factor come back as floats instead of
    triggering an incompatible-dtype assignment.

    :param df: The input DataFrame containing numeric data to be scaled.
    :param initial_pop: Initial population value. Must be a positive, finite real number
        (Python or NumPy scalar).
    :param census_pop: Census population value. Must be a positive, finite real number
        (Python or NumPy scalar).
    :return: A copy of the input DataFrame with numeric columns scaled proportionally
        by the calculated scale factor.
    :raises ValueError: If `initial_pop` or `census_pop` is not a positive number, or
        if no numeric columns are found in the input DataFrame.
    """
    # Imported locally so the function does not depend on module-level imports
    # that are outside the visible part of the file.
    from math import isfinite
    from numbers import Real

    # --- Input validation ---
    # ``Real`` also matches NumPy integer/floating scalars (e.g. the result of a
    # ``.sum()``), which a plain ``(int, float)`` check would reject. NaN and
    # infinity are refused because they would silently poison every scaled value.
    for arg_name, arg_value in (("initial_pop", initial_pop), ("census_pop", census_pop)):
        if not isinstance(arg_value, Real) or not isfinite(arg_value) or arg_value <= 0:
            raise ValueError(f"{arg_name} must be a positive number")

    scaled_df = df.copy()

    # --- Ensure date is index when available ---
    if scaled_df.index.name != "date" and "date" in scaled_df.columns:
        scaled_df = scaled_df.set_index("date")

    # --- Compute scale factor ---
    scale_factor = census_pop / initial_pop

    # --- Scale numeric values only ---
    numeric = scaled_df.select_dtypes(include="number")
    if numeric.shape[1] == 0:
        raise ValueError("No numeric columns found to scale.")

    # Column replacement (``df[cols] = ...``) swaps in the new float arrays;
    # ``.loc[:, cols] = ...`` would instead try to write floats into the
    # existing (possibly integer) arrays.
    scaled_df[numeric.columns] = numeric.mul(scale_factor)
    return scaled_df


# ---------------------------------------------------------------------------
# Addition to tests/test_utils.py
# (that file also adds ``scale_up_population_dataframe`` to its existing
# import list, next to ``read_csv_files``)
# ---------------------------------------------------------------------------
def test_scale_up_population_dataframe():
    """
    Test function for verifying the behavior and correctness of
    `scale_up_population_dataframe`. This test suite includes both
    positive and negative cases to ensure the function handles
    different configurations of input dataframes and census population
    inputs correctly.

    The test validates several scenarios:
    - Flat columns with specified index.
    - MultiIndex columns with specified index.
    - Cases when the census population is less than or equal to
      the initial population.
    - DataFrames where the date is a column instead of the index.
    - Integer-valued columns scaled by a non-integer factor.
    - Frames mixing numeric and non-numeric columns.
    - NumPy scalar population sizes.
    - Invalid population sizes and frames without numeric columns.

    Each scenario ensures that the scaling is accurately performed, and
    key properties like the equality of indices, column structures and
    value-based correctness are tested.

    :raises AssertionError: If the function does not behave as expected.
    """
    import pandas.testing as pdt

    init_pop_size = 1_000_000
    males_per = 0.48
    date = pd.to_datetime(["2010-01-01"])[0]

    # =========================================================
    # Positive: Flat columns (date already index)
    # =========================================================
    census_pop = 1_500_000
    scale_factor = census_pop / init_pop_size

    df_flat = pd.DataFrame(
        {"M": males_per * init_pop_size, "F": (1 - males_per) * init_pop_size},
        index=pd.to_datetime(["2010-01-01"]),
    )
    df_flat.index.name = "date"
    df_flat_before = df_flat.copy()

    scaled_flat = scale_up_population_dataframe(df_flat, init_pop_size, census_pop)

    assert not df_flat.equals(scaled_flat)
    assert (scaled_flat > df_flat).all().all()
    pdt.assert_frame_equal(scaled_flat, df_flat * scale_factor)
    # the caller's frame must be left untouched
    pdt.assert_frame_equal(df_flat, df_flat_before)

    # =========================================================
    # Positive: MultiIndex columns (date already index)
    # =========================================================
    sex = ["F", "M"]
    is_alive = [False, True]
    columns = pd.MultiIndex.from_product([sex, is_alive], names=["sex", "is_alive"])

    df_mi = pd.DataFrame(
        np.zeros((1, len(columns))),
        index=pd.Index(pd.to_datetime(["2010-01-01"]), name="date"),
        columns=columns,
    )
    # non-zero cells to enable strict comparisons
    df_mi.loc[date, ("F", True)] = (1 - males_per) * init_pop_size
    df_mi.loc[date, ("M", True)] = males_per * init_pop_size

    scaled_mi = scale_up_population_dataframe(df_mi, init_pop_size, census_pop)

    assert isinstance(df_mi.columns, pd.MultiIndex)
    assert scaled_mi.columns.equals(df_mi.columns)
    assert scaled_mi.index.equals(df_mi.index)

    assert not df_mi.equals(scaled_mi)

    nonzero_mask = df_mi.to_numpy() != 0
    assert (scaled_mi.to_numpy()[nonzero_mask] > df_mi.to_numpy()[nonzero_mask]).all()

    zero_mask = ~nonzero_mask
    assert (scaled_mi.to_numpy()[zero_mask] == 0).all()

    pdt.assert_frame_equal(scaled_mi, df_mi * scale_factor)

    # =========================================================
    # Negative: census <= initial (should NOT satisfy "scaled > original")
    # =========================================================
    census_pop_small = init_pop_size  # equal => scale_factor = 1
    scaled_equal = scale_up_population_dataframe(df_flat, init_pop_size, census_pop_small)
    assert scaled_equal.equals(df_flat)  # exactly the same
    assert not (scaled_equal > df_flat).all().all()

    census_pop_smaller = 900_000  # smaller => scale_factor < 1
    scaled_down = scale_up_population_dataframe(df_flat, init_pop_size, census_pop_smaller)
    assert not (scaled_down > df_flat).all().all()
    assert (scaled_down < df_flat).all().all()

    # =========================================================
    # Date-not-index-but-in-columns: function should set index
    # =========================================================
    df_date_col = pd.DataFrame(
        {
            "date": pd.to_datetime(["2010-01-01"]),
            "M": [males_per * init_pop_size],
            "F": [(1 - males_per) * init_pop_size],
        }
    )

    scaled_date_col = scale_up_population_dataframe(df_date_col, init_pop_size, census_pop)

    assert scaled_date_col.index.name == "date"
    assert "date" not in scaled_date_col.columns  # moved to index
    pdt.assert_frame_equal(scaled_date_col, df_flat * scale_factor)
    assert "date" in df_date_col.columns  # input frame keeps its own layout

    # =========================================================
    # Integer columns with a non-integer scale factor
    # =========================================================
    df_int = pd.DataFrame(
        {"M": [480_000], "F": [520_000]},
        index=pd.Index(pd.to_datetime(["2010-01-01"]), name="date"),
    )
    df_int_before = df_int.copy()

    scaled_int = scale_up_population_dataframe(df_int, init_pop_size, census_pop)

    pdt.assert_frame_equal(scaled_int, df_int * scale_factor)
    pdt.assert_frame_equal(df_int, df_int_before)

    # =========================================================
    # Mixed dtypes: non-numeric columns are carried over unchanged
    # =========================================================
    df_mixed = df_flat.assign(region="north")

    scaled_mixed = scale_up_population_dataframe(df_mixed, init_pop_size, census_pop)

    assert list(scaled_mixed.columns) == ["M", "F", "region"]
    assert (scaled_mixed["region"] == "north").all()
    pdt.assert_frame_equal(scaled_mixed[["M", "F"]], df_flat * scale_factor)

    # =========================================================
    # NumPy scalar population sizes are accepted
    # =========================================================
    scaled_np = scale_up_population_dataframe(
        df_flat, np.int64(init_pop_size), np.int64(census_pop)
    )
    pdt.assert_frame_equal(scaled_np, df_flat * scale_factor)

    # =========================================================
    # Negative: invalid population sizes / nothing numeric to scale
    # =========================================================
    invalid_calls = [
        (df_flat, 0, census_pop),
        (df_flat, -1, census_pop),
        (df_flat, init_pop_size, 0),
        (df_flat, "1000000", census_pop),
        (df_flat, float("nan"), census_pop),
        (df_flat, init_pop_size, float("inf")),
        (pd.DataFrame({"region": ["north"]}), init_pop_size, census_pop),
    ]
    for bad_df, bad_initial, bad_census in invalid_calls:
        try:
            scale_up_population_dataframe(bad_df, bad_initial, bad_census)
        except ValueError:
            continue
        raise AssertionError(
            f"expected ValueError for initial_pop={bad_initial!r}, census_pop={bad_census!r}"
        )