From b928b582c0f7da94a3be91f1ce722c5dbe9b3884 Mon Sep 17 00:00:00 2001 From: ParticularlyPythonicBS Date: Mon, 12 Jan 2026 15:09:35 -0500 Subject: [PATCH] typing: adding types to utilities --- pyproject.toml | 2 +- stubs/pyomo/core/base/component.pyi | 1 + temoa/utilities/capacity_analyzer.py | 38 +++--- temoa/utilities/clear_db_outputs.py | 8 +- temoa/utilities/database_util.py | 17 ++- temoa/utilities/db_migration_to_v3.py | 30 +++-- temoa/utilities/db_migration_v3_1_to_v4.py | 5 +- temoa/utilities/db_migration_v3_to_v3_1.py | 121 +++++++++++--------- temoa/utilities/graph_utils.py | 10 +- temoa/utilities/run_all_v4_migrations.py | 2 +- temoa/utilities/sql_migration_v3_1_to_v4.py | 5 +- temoa/utilities/unit_cost_explorer.py | 69 ++++++++--- 12 files changed, 186 insertions(+), 122 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2810d00d..2ac6c8c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,7 +144,7 @@ python_version = "3.12" mypy_path = "stubs" # Exclude specific directories from type checking will try to add them back gradually -exclude = "(?x)(^temoa/extensions/|^temoa/utilities/|^stubs/)" +exclude = "(?x)(^temoa/extensions/|^stubs/)" # Strict typing for our own code disallow_untyped_defs = true diff --git a/stubs/pyomo/core/base/component.pyi b/stubs/pyomo/core/base/component.pyi index dcf415e8..4549744a 100644 --- a/stubs/pyomo/core/base/component.pyi +++ b/stubs/pyomo/core/base/component.pyi @@ -159,6 +159,7 @@ class ComponentData(ComponentBase): def __idiv__(self, other: typingAny) -> typingAny: ... def __itruediv__(self, other: typingAny) -> typingAny: ... def __ipow__(self, other: typingAny) -> typingAny: ... + def set_value(self, val, skip_validation: bool = False) -> None: ... class ActiveComponentData(ComponentData): def __init__(self, component) -> None: ... diff --git a/temoa/utilities/capacity_analyzer.py b/temoa/utilities/capacity_analyzer.py index 81e77c01..660955bf 100644 --- a/temoa/utilities/capacity_analyzer.py +++ b/temoa/utilities/capacity_analyzer.py @@ -5,10 +5,9 @@ """ import itertools -import os.path import sqlite3 +from typing import Any -from definitions import PROJECT_ROOT from matplotlib import pyplot as plt # Written by: J. F. Hyink @@ -17,34 +16,31 @@ # Created on: 7/18/23 -# filename of db to analyze... -db = 'US_9R_8D_CT500.sqlite' +# filepath of db to analyze... +source_db_file = 'US_9R_8D_CT500.sqlite' + -source_db_file = os.path.join(PROJECT_ROOT, 'data_files', 'untracked_data', db) print(source_db_file) -res = [] +res: list[Any] = [] try: - con = sqlite3.connect(source_db_file) - cur = con.cursor() - cur.execute('SELECT max_cap FROM max_capacity') - for row in cur: - res.append(row) + with sqlite3.connect(source_db_file) as con: + cur: sqlite3.Cursor = con.cursor() + cur.execute('SELECT max_cap FROM max_capacity') + for row in cur: + res.append(row) except sqlite3.Error as e: print(e) -finally: - con.close() - # chain them together into a list -caps = list(itertools.chain(*res)) +caps: list[float] = list(itertools.chain(*res)) cutoff = 1 # GW : An arbitrary cutoff between big and small capacity systems. -small_cap_sources = [c for c in caps if c <= cutoff] -large_cap_sources = [c for c in caps if c > cutoff] +small_cap_sources: list[float] = [c for c in caps if c <= cutoff] +large_cap_sources: list[float] = [c for c in caps if c > cutoff] -aggregate_small_cap = sum(small_cap_sources) -aggregate_large_cap = sum(large_cap_sources) +aggregate_small_cap: float = sum(small_cap_sources) +aggregate_large_cap: float = sum(large_cap_sources) print(f'{len(small_cap_sources)} small cap sources account for: {aggregate_small_cap: 0.1f} GW') print(f'{len(large_cap_sources)} large cap sources account for: {aggregate_large_cap: 0.1f} GW') @@ -56,8 +52,8 @@ # make a cumulative contribution plot, and find a 5% cutoff cutoff_num_sources = 0 caps.sort() -total_cap = sum(caps) -cumulative_caps = [ +total_cap: float = sum(caps) +cumulative_caps: list[float] = [ caps[0] / total_cap, ] for i, cap in enumerate(caps[1:]): diff --git a/temoa/utilities/clear_db_outputs.py b/temoa/utilities/clear_db_outputs.py index d8614f14..93da3cd9 100644 --- a/temoa/utilities/clear_db_outputs.py +++ b/temoa/utilities/clear_db_outputs.py @@ -7,7 +7,7 @@ import sys from pathlib import Path -basic_output_tables = [ +basic_output_tables: list[str] = [ 'output_built_capacity', 'output_cost', 'output_curtailment', @@ -19,15 +19,15 @@ 'output_objective', 'output_retired_capacity', ] -optional_output_tables = ['output_flow_out_summary', 'myopic_efficiency'] +optional_output_tables: list[str] = ['output_flow_out_summary', 'myopic_efficiency'] if len(sys.argv) != 2: print('this utility file expects a CLA for the path to the database to clear') sys.exit(-1) -target_db_str = sys.argv[1] +target_db_str: str = sys.argv[1] -proceed = input('This will clear ALL output tables in ' + target_db_str + '? (y/n): ') +proceed: str = input('This will clear ALL output tables in ' + target_db_str + '? (y/n): ') if proceed == 'y': target_db = Path(target_db_str) if not target_db.exists(): diff --git a/temoa/utilities/database_util.py b/temoa/utilities/database_util.py index b49002de..6c25028c 100644 --- a/temoa/utilities/database_util.py +++ b/temoa/utilities/database_util.py @@ -7,10 +7,13 @@ periods, and regions. """ +from __future__ import annotations + import os import re import sqlite3 from os import PathLike +from typing import Any, cast import deprecated import pandas as pd @@ -55,9 +58,9 @@ def __init__(self, database_path: str | PathLike[str], scenario: str | None = No def close(self) -> None: """Closes the database cursor and connection.""" - if self.cur: + if hasattr(self, 'cur') and self.cur: self.cur.close() - if self.con: + if hasattr(self, 'con') and self.con: self.con.close() @staticmethod @@ -115,6 +118,8 @@ def get_time_peridos_for_flags(self, flags: list[str] | None = None) -> set[int] query = f'SELECT period FROM time_period WHERE flag IN ({in_clause})' self.cur.execute(query) + # cast to int because sqlite might return strings or ints depending on how data was inserted + # but type hint says set[int] return {int(row[0]) for row in self.cur} def get_technologies_for_flags(self, flags: list[str] | None = None) -> set[str]: @@ -125,7 +130,7 @@ def get_technologies_for_flags(self, flags: list[str] | None = None) -> set[str] in_clause = ', '.join(f"'{flag}'" for flag in flags) query = f'SELECT tech FROM Technology WHERE flag IN ({in_clause})' - return {row[0] for row in self.cur.execute(query)} + return {cast('str', row[0]) for row in self.cur.execute(query)} def get_commodities_and_tech( self, inp_comm: str | None, inp_tech: str | None, region: str | None @@ -171,7 +176,7 @@ def get_commodities_for_flags(self, flags: list[str] | None = None) -> set[str]: in_clause = ', '.join(f"'{flag}'" for flag in flags) query = f'SELECT name FROM Commodity WHERE flag IN ({in_clause})' - return {row[0] for row in self.cur.execute(query)} + return {cast('str', row[0]) for row in self.cur.execute(query)} def get_commodities_by_technology( self, region: str | None, comm_type: str = 'input' @@ -187,11 +192,11 @@ def get_commodities_by_technology( if region: query += f" WHERE region LIKE '%{region}%'" - return {tuple(row) for row in self.cur.execute(query)} + return {cast('tuple[str, str]', row) for row in self.cur.execute(query)} def get_capacity_for_tech_and_period( self, tech: str | None = None, period: int | None = None, region: str | None = None - ) -> pd.DataFrame | pd.Series: + ) -> pd.DataFrame | pd.Series[Any]: """Retrieves capacity data, aggregated by technology.""" if not self.scenario: raise ValueError('A scenario must be set for output-related queries') diff --git a/temoa/utilities/db_migration_to_v3.py b/temoa/utilities/db_migration_to_v3.py index 19c7999e..b87a8475 100644 --- a/temoa/utilities/db_migration_to_v3.py +++ b/temoa/utilities/db_migration_to_v3.py @@ -12,6 +12,7 @@ import sys from collections import defaultdict from pathlib import Path +from typing import Any parser = argparse.ArgumentParser() parser.add_argument( @@ -36,9 +37,9 @@ new_db_name = legacy_db.stem + '_v3.sqlite' new_db_path = Path(legacy_db.parent, new_db_name) -con_old = sqlite3.connect(legacy_db) -con_new = sqlite3.connect(new_db_path) -cur = con_new.cursor() +con_old: sqlite3.Connection = sqlite3.connect(legacy_db) +con_new: sqlite3.Connection = sqlite3.connect(new_db_path) +cur: sqlite3.Cursor = con_new.cursor() # bring in the new schema and execute with open(schema_file) as src: @@ -50,7 +51,7 @@ # table mapping for DIRECT transfers # fmt: off -direct_transfer_tables = [ +direct_transfer_tables: list[tuple[str, str]] = [ ("", "CapacityCredit"), ("", "CapacityFactorProcess"), ("", "CapacityFactorTech"), @@ -105,14 +106,14 @@ ("SegFrac", "TimeSegmentFraction"), ] -units_added_tables = [ +units_added_tables: list[tuple[str, str]] = [ ("", "MaxActivityGroup"), ("", "MaxCapacityGroup"), ("", "MinCapacityGroup"), ("", "MinActivityGroup"), ] -sequence_added_tables = [ +sequence_added_tables: list[tuple[str, str]] = [ ("time_season", "TimeSeason"), ("time_periods", "time_period"), ("time_of_day", "TimeOfDay"), @@ -126,12 +127,16 @@ if old_name == '': old_name = new_name - new_columns = [c[1] for c in con_new.execute(f'PRAGMA table_info({new_name});').fetchall()] - old_columns = [c[1] for c in con_old.execute(f'PRAGMA table_info({old_name});').fetchall()] + new_columns: list[str] = [ + c[1] for c in con_new.execute(f'PRAGMA table_info({new_name});').fetchall() + ] + old_columns: list[str] = [ + c[1] for c in con_old.execute(f'PRAGMA table_info({old_name});').fetchall() + ] cols = str(old_columns[0 : len(new_columns)])[1:-1].replace("'", '') try: - data = con_old.execute(f'SELECT {cols} FROM {old_name}').fetchall() + data: list[Any] = con_old.execute(f'SELECT {cols} FROM {old_name}').fetchall() except sqlite3.OperationalError: print('TABLE NOT FOUND: ' + old_name) data = [] @@ -222,10 +227,11 @@ # let's ensure all the non-global entries are consistent (same techs in each region) skip_rps = False try: - rps_entries = con_old.execute('SELECT * FROM tech_rps').fetchall() + rps_entries: list[tuple[str, str, str]] = con_old.execute('SELECT * FROM tech_rps').fetchall() except sqlite3.OperationalError: print('source does not appear to include RPS techs...skipping') skip_rps = True + rps_entries = [] if not skip_rps: for region, tech, _notes in rps_entries: groups[region].add(tech) @@ -239,7 +245,7 @@ for group, techs in groups.items(): print(f'group: {group} mismatches: {common ^ techs}') if group != 'global': - techs_common &= not common ^ techs + techs_common &= not (common ^ techs) if not techs_common: print( 'combining RPS techs failed. Some regions are not same. Must be done ' @@ -357,7 +363,7 @@ data = con_old.execute(read_qry).fetchall() if unlim_cap_present: # need to convert null -> 0 for unlim_cap to match new schema that does not allow null - new_data = [] + new_data: list[Any] = [] for row in data: new_row = list(row) if new_row[4] is None: diff --git a/temoa/utilities/db_migration_v3_1_to_v4.py b/temoa/utilities/db_migration_v3_1_to_v4.py index 55e6e2e6..af6d4528 100644 --- a/temoa/utilities/db_migration_v3_1_to_v4.py +++ b/temoa/utilities/db_migration_v3_1_to_v4.py @@ -16,6 +16,7 @@ import re import sqlite3 from pathlib import Path +from typing import Any # ---------- Mapping configuration ---------- CUSTOM_MAP: dict[str, str] = { @@ -93,7 +94,7 @@ def map_token_no_cascade(token: str) -> str: return to_snake_case(token) -def get_table_info(conn: sqlite3.Connection, table: str) -> list[tuple]: +def get_table_info(conn: sqlite3.Connection, table: str) -> list[tuple[Any, ...]]: try: return conn.execute(f'PRAGMA table_info({table});').fetchall() except sqlite3.OperationalError: @@ -131,7 +132,7 @@ def migrate_direct_table( return len(filtered) -def migrate_all(args) -> None: +def migrate_all(args: argparse.Namespace) -> None: src = Path(args.source) schema = Path(args.schema) out = Path(args.out) if args.out else src.with_suffix('.v4.sqlite') diff --git a/temoa/utilities/db_migration_v3_to_v3_1.py b/temoa/utilities/db_migration_v3_to_v3_1.py index 63e7b7a8..44dc8344 100644 --- a/temoa/utilities/db_migration_v3_to_v3_1.py +++ b/temoa/utilities/db_migration_v3_to_v3_1.py @@ -3,20 +3,16 @@ """ import argparse -import os import sqlite3 import sys from pathlib import Path +from typing import Any import pandas as pd -from temoa.core.model import TemoaModel - # Just to get the default lifetime... -this_dir = os.path.dirname(__file__) -root_dir = os.path.abspath(os.path.join(this_dir, '../..')) -sys.path.append(root_dir) - +# Assumes temoa is installed in the environment +from temoa.core.model import TemoaModel parser = argparse.ArgumentParser() parser.add_argument( @@ -40,9 +36,9 @@ new_db_name = legacy_db.stem + '_v3_1.sqlite' new_db_path = Path(legacy_db.parent, new_db_name) -con_old = sqlite3.connect(legacy_db) -con_new = sqlite3.connect(new_db_path) -cur = con_new.cursor() +con_old: sqlite3.Connection = sqlite3.connect(legacy_db) +con_new: sqlite3.Connection = sqlite3.connect(new_db_path) +cur: sqlite3.Cursor = con_new.cursor() # bring in the new schema and execute with open(schema_file) as src: @@ -79,7 +75,7 @@ def column_check(old_name: str, new_name: str) -> bool: # table mapping for DIRECT transfers # fmt: off -direct_transfer_tables = [ +direct_transfer_tables: list[tuple[str, str]] = [ ("", "CapacityCredit"), ("", "CapacityToActivity"), ("", "Commodity"), @@ -114,7 +110,7 @@ def column_check(old_name: str, new_name: str) -> bool: ("", "TimePeriodType"), ] -period_added_tables = [ +period_added_tables: list[tuple[str, str]] = [ ("", "CapacityFactorProcess"), ("", "CapacityFactorTech"), ("", "DemandSpecificDistribution"), @@ -122,7 +118,7 @@ def column_check(old_name: str, new_name: str) -> bool: ("", "TimeSegmentFraction"), ] -operator_added_tables = { +operator_added_tables: dict[str, tuple[str, str]] = { "EmissionLimit": ("LimitEmission", "le"), "TechOutputSplit": ("LimitTechOutputSplit", "ge"), "TechInputSplitAnnual": ("LimitTechInputSplitAnnual", "ge"), @@ -153,7 +149,7 @@ def column_check(old_name: str, new_name: str) -> bool: "MaxResource": ("LimitResource", "le"), } -no_transfer = { +no_transfer: dict[str, str] = { "MinSeasonalActivity": "LimitSeasonalCapacityFactor", "MaxSeasonalActivity": "LimitSeasonalCapacityFactor", "StorageInit": "LimitStorageLevelFraction", @@ -174,12 +170,12 @@ def column_check(old_name: str, new_name: str) -> bool: for old_name, (new_name, operator) in operator_added_tables.items(): try: - data = con_old.execute(f"SELECT * FROM {old_name}").fetchall() + data_rows: list[Any] = con_old.execute(f"SELECT * FROM {old_name}").fetchall() except sqlite3.OperationalError: print("TABLE NOT FOUND: " + old_name) continue - if not data: + if not data_rows: print("No data for: " + old_name) continue @@ -187,14 +183,18 @@ def column_check(old_name: str, new_name: str) -> bool: c[1] for c in con_new.execute(f"PRAGMA table_info({new_name});").fetchall() ] op_index = new_cols.index("operator") - data = [(*row[0:op_index], operator, *row[op_index:len(new_cols)-1]) for row in data] + data_transformed = [ + (*row[0:op_index], operator, *row[op_index : len(new_cols) - 1]) for row in data_rows + ] # construct the query with correct number of placeholders - num_placeholders = len(data[0]) + if not data_transformed: + continue + num_placeholders = len(data_transformed[0]) placeholders = ",".join(["?" for _ in range(num_placeholders)]) query = f"INSERT OR REPLACE INTO {new_name} VALUES ({placeholders})" - con_new.executemany(query, data) - print(f"Transfered {len(data)} rows from {old_name} to {new_name}") + con_new.executemany(query, data_transformed) + print(f"Transfered {len(data_transformed)} rows from {old_name} to {new_name}") # It wasn't active anyway... can't be bothered # StorageInit -> LimitStorageLevelFraction @@ -212,45 +212,49 @@ def column_check(old_name: str, new_name: str) -> bool: continue old_columns = [c[1] for c in con_old.execute(f"PRAGMA table_info({old_name});").fetchall()] - new_columns = [c[1] for c in con_new.execute(f"PRAGMA table_info({new_name});").fetchall()] + new_columns = [ + c[1] for c in con_new.execute(f"PRAGMA table_info({new_name});").fetchall() + ] cols = [c for c in new_columns if c in old_columns] - data = con_old.execute(f'SELECT {str(cols)[1:-1].replace("'","")} FROM {old_name}').fetchall() + data_rows = con_old.execute( + f'SELECT {str(cols)[1:-1].replace("'","")} FROM {old_name}' + ).fetchall() - if not data: + if not data_rows: print("No data for: " + old_name) continue # construct the query with correct number of placeholders - num_placeholders = len(data[0]) + num_placeholders = len(data_rows[0]) placeholders = ",".join(["?" for _ in range(num_placeholders)]) query = ( "INSERT OR REPLACE INTO " f"{new_name}{tuple(c for c in cols) if len(cols)>1 else f'({cols[0]})'} " f"VALUES ({placeholders})" ) - con_new.executemany(query, data) - print(f"Transfered {len(data)} rows from {old_name} to {new_name}") + con_new.executemany(query, data_rows) + print(f"Transfered {len(data_rows)} rows from {old_name} to {new_name}") -time_all = [ +time_all: list[str] = [ p[0] for p in cur.execute("SELECT period FROM TimePeriod").fetchall() ] time_all = sorted(time_all)[0:-1] # Exclude horizon end # get lifetimes. Major headache but needs to be done -lifetime_process = {} -data = cur.execute("SELECT region, tech, vintage FROM Efficiency").fetchall() -for rtv in data: +lifetime_process: dict[tuple[Any, ...], float] = {} +data_rows = cur.execute("SELECT region, tech, vintage FROM Efficiency").fetchall() +for rtv in data_rows: lifetime_process[rtv] = TemoaModel.default_lifetime_tech -data = cur.execute("SELECT region, tech, lifetime FROM LifetimeTech").fetchall() -for rtl in data: +data_rows = cur.execute("SELECT region, tech, lifetime FROM LifetimeTech").fetchall() +for rtl in data_rows: for v in time_all: - lifetime_process[*rtl[0:2], v] = rtl[2] -data = cur.execute("SELECT region, tech, vintage, lifetime FROM LifetimeProcess").fetchall() -for rtvl in data: + lifetime_process[(*rtl[0:2], v)] = rtl[2] +data_rows = cur.execute("SELECT region, tech, vintage, lifetime FROM LifetimeProcess").fetchall() +for rtvl in data_rows: lifetime_process[rtvl[0:3]] = rtvl[3] # Planning periods to add to period indices -time_optimize = [ +time_optimize: list[str] = [ p[0] for p in cur.execute('SELECT period FROM TimePeriod WHERE flag == "f"').fetchall() ] time_optimize = sorted(time_optimize)[0:-1] # Exclude horizon end @@ -270,24 +274,28 @@ def column_check(old_name: str, new_name: str) -> bool: old_columns = [c[1] for c in con_old.execute(f"PRAGMA table_info({old_name});").fetchall()] new_columns = [c[1] for c in con_new.execute(f"PRAGMA table_info({new_name});").fetchall()] cols = [c for c in new_columns if c in old_columns] - data = pd.read_sql_query(f'SELECT {str(cols)[1:-1].replace("'","")} FROM {old_name}', con_old) - if len(data) == 0: + # Use pandas for complex logic + df_data: pd.DataFrame = pd.read_sql_query( + f'SELECT {str(cols)[1:-1].replace("'","")} FROM {old_name}', con_old + ) + + if len(df_data) == 0: print("No data for: " + old_name) continue # This insanity collects the viable periods for each table if "vintage" in cols: - data["periods"] = [ + df_data["periods"] = [ ( p for p in time_optimize if v <= p < v+lifetime_process[r, t, v] ) - for r, t, v in data[["region","tech","vintage"]] + for r, t, v in df_data[["region","tech","vintage"]].values ] elif "tech" in cols: - periods = {} - for r, t in data[["region","tech"]].drop_duplicates().values: + periods: dict[tuple[str, str], list[str]] = {} + for r, t in df_data[["region","tech"]].drop_duplicates().values: periods[r, t] = [ p for p in time_optimize if any( @@ -300,16 +308,16 @@ def column_check(old_name: str, new_name: str) -> bool: ] ) ] - data["periods"] = [ + df_data["periods"] = [ periods[r, t] - for (r, t) in data[["region","tech"]].values + for (r, t) in df_data[["region","tech"]].values ] else: - data["periods"] = [time_optimize for i in data.index] + df_data["periods"] = [time_optimize for i in df_data.index] - data_new = [] + data_new: list[tuple[Any, ...]] = [] for p in time_optimize: - for _idx, row in data.iterrows(): + for _idx, row in df_data.iterrows(): if p not in row["periods"]: continue if old_name[0:5] == "TimeS": # horrible but covers TimeSeason and TimeSegmentFraction @@ -323,6 +331,8 @@ def column_check(old_name: str, new_name: str) -> bool: cols = [cols[0],"period",*cols[1::]] # construct the query with correct number of placeholders + if not data_new: + continue num_placeholders = len(data_new[0]) placeholders = ",".join(["?" for _ in range(num_placeholders)]) query = ( @@ -360,15 +370,18 @@ def column_check(old_name: str, new_name: str) -> bool: # LoanLifetimeTech -> LoanLifetimeProcess try: - data = con_old.execute("SELECT region, tech, lifetime, notes FROM LoanLifetimeTech").fetchall() + data_rows = con_old.execute( + "SELECT region, tech, lifetime, notes FROM LoanLifetimeTech" + ).fetchall() except sqlite3.OperationalError: print("TABLE NOT FOUND: LoanLifetimeTech") + data_rows = [] -if not data: +if not data_rows: print("No data for: LoanLifetimeTech") else: - new_data = [] - for row in data: + new_data: list[Any] = [] + for row in data_rows: vints = [ v[0] for v in con_old.execute( @@ -407,13 +420,13 @@ def column_check(old_name: str, new_name: str) -> bool: con_new.execute("VACUUM;") con_new.execute("PRAGMA FOREIGN_KEYS=1;") try: - data = con_new.execute("PRAGMA FOREIGN_KEY_CHECK;").fetchall() - if not data: + data_rows = con_new.execute("PRAGMA FOREIGN_KEY_CHECK;").fetchall() + if not data_rows: print("No Foreign Key Failures. (Good news!)") else: print("\nFK check fails (MUST BE FIXED):") print("(Table, Row ID, Reference Table, (fkid) )") - for row in data: + for row in data_rows: print(row) except sqlite3.OperationalError as e: print("Foreign Key Check FAILED on new DB. Something may be wrong with schema.") diff --git a/temoa/utilities/graph_utils.py b/temoa/utilities/graph_utils.py index a1541f0a..46fc2abb 100644 --- a/temoa/utilities/graph_utils.py +++ b/temoa/utilities/graph_utils.py @@ -37,10 +37,16 @@ else: # At runtime, use the base types which are not subscripted. # The TypeVar still enforces that the graph type is one of these. - GraphType = TypeVar('GraphType', nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph) + GraphType = TypeVar( + 'GraphType', + nx.Graph, + nx.DiGraph, + nx.MultiGraph, + nx.MultiDiGraph, + ) -def convert_graph_to_json[GraphType: (nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)]( +def convert_graph_to_json( # noqa: UP047 # avoiding runtime type error with networkx generics nx_graph: GraphType, override_node_properties: dict[str, Any] | None, override_edge_properties: dict[str, Any] | None, diff --git a/temoa/utilities/run_all_v4_migrations.py b/temoa/utilities/run_all_v4_migrations.py index dbac5920..4d28eadd 100644 --- a/temoa/utilities/run_all_v4_migrations.py +++ b/temoa/utilities/run_all_v4_migrations.py @@ -25,7 +25,7 @@ def run_command( cmd: list[str], cwd: Path | None = None, capture_output: bool = True -) -> subprocess.CompletedProcess: +) -> subprocess.CompletedProcess[str]: """Helper to run shell commands.""" print(f'Executing: {" ".join(cmd)}') result = subprocess.run(cmd, cwd=cwd, capture_output=capture_output, text=True, check=False) diff --git a/temoa/utilities/sql_migration_v3_1_to_v4.py b/temoa/utilities/sql_migration_v3_1_to_v4.py index f888d28e..992917bd 100644 --- a/temoa/utilities/sql_migration_v3_1_to_v4.py +++ b/temoa/utilities/sql_migration_v3_1_to_v4.py @@ -24,6 +24,7 @@ import re import sqlite3 import sys +from typing import Any # ------------------ Mapping configuration (mirror sqlite migrator) ------------------ CUSTOM_MAP: dict[str, str] = { @@ -116,14 +117,14 @@ def map_column_name(col: str) -> str: return mapped -def get_table_info(conn: sqlite3.Connection, table: str) -> list[tuple]: +def get_table_info(conn: sqlite3.Connection, table: str) -> list[tuple[Any, ...]]: try: return conn.execute(f'PRAGMA table_info({table});').fetchall() except sqlite3.OperationalError: return [] -def migrate_dump_to_sqlite(args) -> None: +def migrate_dump_to_sqlite(args: argparse.Namespace) -> None: # --- 1. Load v3.1 SQL dump into a temporary in-memory DB --- print(f'Loading v3.1 SQL dump from {args.input} into in-memory DB...') con_old_in_memory = sqlite3.connect(':memory:') diff --git a/temoa/utilities/unit_cost_explorer.py b/temoa/utilities/unit_cost_explorer.py index 89cbabcb..5520a3bf 100644 --- a/temoa/utilities/unit_cost_explorer.py +++ b/temoa/utilities/unit_cost_explorer.py @@ -3,12 +3,27 @@ of storage capacity """ +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + from pyomo.environ import value from temoa.components.costs import total_cost_rule from temoa.components.storage import storage_energy_upper_bound_constraint from temoa.core.model import TemoaModel +if TYPE_CHECKING: + from temoa.types.core_types import ( + Period, + Region, + Season, + Technology, + TimeOfDay, + Vintage, + ) + + # Written by: J. F. Hyink # jeff@westernspark.us # https://westernspark.us @@ -29,8 +44,17 @@ # indices -rtv = ('A', 'battery', 2020) # rtv -rptv = ('A', 2020, 'battery', 2020) # rptv +rtv: tuple[Region, Technology, Vintage] = ( + cast('Region', 'A'), + cast('Technology', 'battery'), + cast('Vintage', 2020), +) # rtv +rptv: tuple[Region, Period, Technology, Vintage] = ( + cast('Region', 'A'), + cast('Period', 2020), + cast('Technology', 'battery'), + cast('Vintage', 2020), +) # rptv model.time_future.construct([2020, 2025, 2030]) # needs to go 1 period beyond optimize horizon model.time_optimize.construct([2020, 2025]) model.period_length.construct() @@ -60,35 +84,38 @@ model.lifetime_tech.construct(data={('A', 'battery'): 20}) model.lifetime_process.construct(data={rtv: 40}) # M.ModelProcessLife.construct(data={rptv: 20}) +# make/fix VARS model.global_discount_rate.construct(data={None: 0.05}) model.is_survival_curve_process[rtv] = False # make/fix VARS model.v_new_capacity.construct() -model.v_new_capacity[rtv].set_value(1) +v_new_capacity = model.v_new_capacity +v_new_capacity[rtv].set_value(1) model.v_capacity.construct() -model.v_capacity[rptv].set_value(1) +v_capacity = model.v_capacity +v_capacity[rptv].set_value(1) # run the total cost rule on our "model": tot_cost_expr = total_cost_rule(model) -total_cost = value(tot_cost_expr) +total_cost: float = value(tot_cost_expr) print() print(f'Total cost for building 1 capacity unit of storage: ${total_cost:0.2f} [$M]') print('The total cost expression:') print(tot_cost_expr) # how much storage achieved for 1 unit of capacity? -storage_cap = 1 # unit -storage_dur = 4 # hr -c2a = 31.536 # PJ/GW-yr -c = 1 / 8760 # yr/hr -storage = storage_cap * storage_dur * c2a * c -PJ_to_kwh = 1 / 3600000 * 1e15 +storage_cap: float = 1 # unit +storage_dur: float = 4 # hr +c2a: float = 31.536 # PJ/GW-yr +c: float = 1 / 8760 # yr/hr +storage: float = storage_cap * storage_dur * c2a * c +PJ_to_kwh: float = 1 / 3600000 * 1e15 print() print(f'storage built: {storage:0.4f} [PJ] / {(storage * PJ_to_kwh):0.2f} [kWh]') -price_per_kwh = total_cost * 1e6 / (storage * PJ_to_kwh) +price_per_kwh: float = total_cost * 1e6 / (storage * PJ_to_kwh) print(f'price_per_kwh: ${price_per_kwh: 0.2f}\n') # let's look at the constraint for storage level @@ -98,7 +125,7 @@ model.time_season_all.construct(['winter', 'summer']) model.time_season.construct(data={2020: {'winter', 'summer'}, 2025: {'winter', 'summer'}}) model.days_per_period.construct(data={None: 365}) -tod_slices = 2 +tod_slices: int = 2 model.time_of_day.construct(data=range(1, tod_slices + 1)) model.tech_storage.construct(data=['battery']) model.process_life_frac_rptv.construct(data=[rptv]) @@ -116,7 +143,7 @@ # More PARAMS model.capacity_to_activity.construct(data={('A', 'battery'): 31.536}) model.storage_duration.construct(data={('A', 'battery'): 4}) -seasonal_fractions = {'winter': 0.4, 'summer': 0.6} +seasonal_fractions: dict[str, float] = {'winter': 0.4, 'summer': 0.6} model.segment_fraction.construct( data={ (p, s, d): seasonal_fractions[s] / tod_slices @@ -133,12 +160,20 @@ model.v_storage_level.construct() model.segment_fraction_per_season.construct() -model.is_seasonal_storage['battery'] = False -upper_limit = storage_energy_upper_bound_constraint(model, 'A', 2020, 'winter', 1, 'battery', 2020) +model.is_seasonal_storage[cast('Technology', 'battery')] = False +upper_limit = storage_energy_upper_bound_constraint( + model, + cast('Region', 'A'), + cast('Period', 2020), + cast('Season', 'winter'), + cast('TimeOfDay', '1'), + cast('Technology', 'battery'), + cast('Vintage', 2020), +) print('The storage level constraint for the single period in the "super day":\n', upper_limit) # cross-check the multiplier... -mulitplier = ( +mulitplier: float = ( storage_dur * model.segment_fraction_per_season[2020, 'winter'] * model.days_per_period