Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ python_version = "3.12"
mypy_path = "stubs"

# Exclude specific directories from type checking will try to add them back gradually
exclude = "(?x)(^temoa/extensions/|^temoa/utilities/|^stubs/)"
exclude = "(?x)(^temoa/extensions/|^stubs/)"

# Strict typing for our own code
disallow_untyped_defs = true
Expand Down
1 change: 1 addition & 0 deletions stubs/pyomo/core/base/component.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ class ComponentData(ComponentBase):
def __idiv__(self, other: typingAny) -> typingAny: ...
def __itruediv__(self, other: typingAny) -> typingAny: ...
def __ipow__(self, other: typingAny) -> typingAny: ...
def set_value(self, val, skip_validation: bool = False) -> None: ...

class ActiveComponentData(ComponentData):
def __init__(self, component) -> None: ...
Expand Down
38 changes: 17 additions & 21 deletions temoa/utilities/capacity_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
"""

import itertools
import os.path
import sqlite3
from typing import Any

from definitions import PROJECT_ROOT
from matplotlib import pyplot as plt

# Written by: J. F. Hyink
Expand All @@ -17,34 +16,31 @@

# Created on: 7/18/23

# filename of db to analyze...
db = 'US_9R_8D_CT500.sqlite'
# filepath of db to analyze...
source_db_file = 'US_9R_8D_CT500.sqlite'


source_db_file = os.path.join(PROJECT_ROOT, 'data_files', 'untracked_data', db)
print(source_db_file)
res = []
res: list[Any] = []
try:
con = sqlite3.connect(source_db_file)
cur = con.cursor()
cur.execute('SELECT max_cap FROM max_capacity')
for row in cur:
res.append(row)
with sqlite3.connect(source_db_file) as con:
cur: sqlite3.Cursor = con.cursor()
cur.execute('SELECT max_cap FROM max_capacity')
for row in cur:
res.append(row)

except sqlite3.Error as e:
print(e)

finally:
con.close()

# chain them together into a list
caps = list(itertools.chain(*res))
caps: list[float] = list(itertools.chain(*res))

cutoff = 1 # GW : An arbitrary cutoff between big and small capacity systems.
small_cap_sources = [c for c in caps if c <= cutoff]
large_cap_sources = [c for c in caps if c > cutoff]
small_cap_sources: list[float] = [c for c in caps if c <= cutoff]
large_cap_sources: list[float] = [c for c in caps if c > cutoff]

aggregate_small_cap = sum(small_cap_sources)
aggregate_large_cap = sum(large_cap_sources)
aggregate_small_cap: float = sum(small_cap_sources)
aggregate_large_cap: float = sum(large_cap_sources)

print(f'{len(small_cap_sources)} small cap sources account for: {aggregate_small_cap: 0.1f} GW')
print(f'{len(large_cap_sources)} large cap sources account for: {aggregate_large_cap: 0.1f} GW')
Expand All @@ -56,8 +52,8 @@
# make a cumulative contribution plot, and find a 5% cutoff
cutoff_num_sources = 0
caps.sort()
total_cap = sum(caps)
cumulative_caps = [
total_cap: float = sum(caps)
cumulative_caps: list[float] = [
caps[0] / total_cap,
]
Comment on lines +55 to 58
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Potential division by zero if caps is empty.

If caps is empty after filtering, total_cap will be 0, causing a ZeroDivisionError on line 57.

Suggested guard
 caps.sort()
 total_cap: float = sum(caps)
+if not caps or total_cap == 0:
+    print('No capacity data to analyze')
+    sys.exit(0)
 cumulative_caps: list[float] = [
     caps[0] / total_cap,
 ]
🤖 Prompt for AI Agents
In @temoa/utilities/capacity_analyzer.py around lines 55-58, guard against an
empty or all-zero caps list before computing normalized cumulative capacities:
check if caps is empty or if total_cap := sum(caps) is 0 and handle by returning
an appropriate empty/zero result (e.g., return [] or cumulative_caps filled with
zeros) instead of performing caps[0] / total_cap; update the code around the
variables caps, total_cap, and cumulative_caps to perform this early-return or
conditional logic and only compute the normalized cumulative list when total_cap
> 0.

for i, cap in enumerate(caps[1:]):
Expand Down
8 changes: 4 additions & 4 deletions temoa/utilities/clear_db_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
from pathlib import Path

basic_output_tables = [
basic_output_tables: list[str] = [
'output_built_capacity',
'output_cost',
'output_curtailment',
Expand All @@ -19,15 +19,15 @@
'output_objective',
'output_retired_capacity',
]
optional_output_tables = ['output_flow_out_summary', 'myopic_efficiency']
optional_output_tables: list[str] = ['output_flow_out_summary', 'myopic_efficiency']

if len(sys.argv) != 2:
print('this utility file expects a CLA for the path to the database to clear')
sys.exit(-1)

target_db_str = sys.argv[1]
target_db_str: str = sys.argv[1]

proceed = input('This will clear ALL output tables in ' + target_db_str + '? (y/n): ')
proceed: str = input('This will clear ALL output tables in ' + target_db_str + '? (y/n): ')
if proceed == 'y':
target_db = Path(target_db_str)
if not target_db.exists():
Expand Down
17 changes: 11 additions & 6 deletions temoa/utilities/database_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
periods, and regions.
"""

from __future__ import annotations

import os
import re
import sqlite3
from os import PathLike
from typing import Any, cast

import deprecated
import pandas as pd
Expand Down Expand Up @@ -55,9 +58,9 @@ def __init__(self, database_path: str | PathLike[str], scenario: str | None = No

def close(self) -> None:
"""Closes the database cursor and connection."""
if self.cur:
if hasattr(self, 'cur') and self.cur:
self.cur.close()
if self.con:
if hasattr(self, 'con') and self.con:
self.con.close()

@staticmethod
Expand Down Expand Up @@ -115,6 +118,8 @@ def get_time_peridos_for_flags(self, flags: list[str] | None = None) -> set[int]
query = f'SELECT period FROM time_period WHERE flag IN ({in_clause})'

self.cur.execute(query)
# cast to int because sqlite might return strings or ints depending on how data was inserted
# but type hint says set[int]
return {int(row[0]) for row in self.cur}

def get_technologies_for_flags(self, flags: list[str] | None = None) -> set[str]:
Expand All @@ -125,7 +130,7 @@ def get_technologies_for_flags(self, flags: list[str] | None = None) -> set[str]
in_clause = ', '.join(f"'{flag}'" for flag in flags)
query = f'SELECT tech FROM Technology WHERE flag IN ({in_clause})'

return {row[0] for row in self.cur.execute(query)}
return {cast('str', row[0]) for row in self.cur.execute(query)}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

LGTM, but consider removing quotes around type.

The cast is correct for type narrowing. Since the project requires Python 3.12+, you can simplify to cast(str, row[0]) without the string literal.

🤖 Prompt for AI Agents
In @temoa/utilities/database_util.py at line 133, The cast usage in the return
expression of the method that does "return {cast('str', row[0]) for row in
self.cur.execute(query)}" should use a real type object instead of a string
literal; change the cast call to cast(str, row[0]) (no quotes) in that set
comprehension so it uses the actual type for Python 3.12+ type narrowing while
leaving the rest of the expression (self.cur.execute(query)) unchanged.


def get_commodities_and_tech(
self, inp_comm: str | None, inp_tech: str | None, region: str | None
Expand Down Expand Up @@ -171,7 +176,7 @@ def get_commodities_for_flags(self, flags: list[str] | None = None) -> set[str]:
in_clause = ', '.join(f"'{flag}'" for flag in flags)
query = f'SELECT name FROM Commodity WHERE flag IN ({in_clause})'

return {row[0] for row in self.cur.execute(query)}
return {cast('str', row[0]) for row in self.cur.execute(query)}

def get_commodities_by_technology(
self, region: str | None, comm_type: str = 'input'
Expand All @@ -187,11 +192,11 @@ def get_commodities_by_technology(
if region:
query += f" WHERE region LIKE '%{region}%'"

return {tuple(row) for row in self.cur.execute(query)}
return {cast('tuple[str, str]', row) for row in self.cur.execute(query)}

def get_capacity_for_tech_and_period(
self, tech: str | None = None, period: int | None = None, region: str | None = None
) -> pd.DataFrame | pd.Series:
) -> pd.DataFrame | pd.Series[Any]:
"""Retrieves capacity data, aggregated by technology."""
if not self.scenario:
raise ValueError('A scenario must be set for output-related queries')
Expand Down
30 changes: 18 additions & 12 deletions temoa/utilities/db_migration_to_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any

parser = argparse.ArgumentParser()
parser.add_argument(
Expand All @@ -36,9 +37,9 @@
new_db_name = legacy_db.stem + '_v3.sqlite'
new_db_path = Path(legacy_db.parent, new_db_name)

con_old = sqlite3.connect(legacy_db)
con_new = sqlite3.connect(new_db_path)
cur = con_new.cursor()
con_old: sqlite3.Connection = sqlite3.connect(legacy_db)
con_new: sqlite3.Connection = sqlite3.connect(new_db_path)
cur: sqlite3.Cursor = con_new.cursor()

# bring in the new schema and execute
with open(schema_file) as src:
Expand All @@ -50,7 +51,7 @@

# table mapping for DIRECT transfers
# fmt: off
direct_transfer_tables = [
direct_transfer_tables: list[tuple[str, str]] = [
("", "CapacityCredit"),
("", "CapacityFactorProcess"),
("", "CapacityFactorTech"),
Expand Down Expand Up @@ -105,14 +106,14 @@
("SegFrac", "TimeSegmentFraction"),
]

units_added_tables = [
units_added_tables: list[tuple[str, str]] = [
("", "MaxActivityGroup"),
("", "MaxCapacityGroup"),
("", "MinCapacityGroup"),
("", "MinActivityGroup"),
]

sequence_added_tables = [
sequence_added_tables: list[tuple[str, str]] = [
("time_season", "TimeSeason"),
("time_periods", "time_period"),
("time_of_day", "TimeOfDay"),
Expand All @@ -126,12 +127,16 @@
if old_name == '':
old_name = new_name

new_columns = [c[1] for c in con_new.execute(f'PRAGMA table_info({new_name});').fetchall()]
old_columns = [c[1] for c in con_old.execute(f'PRAGMA table_info({old_name});').fetchall()]
new_columns: list[str] = [
c[1] for c in con_new.execute(f'PRAGMA table_info({new_name});').fetchall()
]
old_columns: list[str] = [
c[1] for c in con_old.execute(f'PRAGMA table_info({old_name});').fetchall()
]
cols = str(old_columns[0 : len(new_columns)])[1:-1].replace("'", '')

try:
data = con_old.execute(f'SELECT {cols} FROM {old_name}').fetchall()
data: list[Any] = con_old.execute(f'SELECT {cols} FROM {old_name}').fetchall()
except sqlite3.OperationalError:
print('TABLE NOT FOUND: ' + old_name)
data = []
Expand Down Expand Up @@ -222,10 +227,11 @@
# let's ensure all the non-global entries are consistent (same techs in each region)
skip_rps = False
try:
rps_entries = con_old.execute('SELECT * FROM tech_rps').fetchall()
rps_entries: list[tuple[str, str, str]] = con_old.execute('SELECT * FROM tech_rps').fetchall()
except sqlite3.OperationalError:
print('source does not appear to include RPS techs...skipping')
skip_rps = True
rps_entries = []
if not skip_rps:
for region, tech, _notes in rps_entries:
groups[region].add(tech)
Expand All @@ -239,7 +245,7 @@
for group, techs in groups.items():
print(f'group: {group} mismatches: {common ^ techs}')
if group != 'global':
techs_common &= not common ^ techs
techs_common &= not (common ^ techs)
if not techs_common:
print(
'combining RPS techs failed. Some regions are not same. Must be done '
Expand Down Expand Up @@ -357,7 +363,7 @@
data = con_old.execute(read_qry).fetchall()
if unlim_cap_present:
# need to convert null -> 0 for unlim_cap to match new schema that does not allow null
new_data = []
new_data: list[Any] = []
for row in data:
new_row = list(row)
if new_row[4] is None:
Expand Down
5 changes: 3 additions & 2 deletions temoa/utilities/db_migration_v3_1_to_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re
import sqlite3
from pathlib import Path
from typing import Any

# ---------- Mapping configuration ----------
CUSTOM_MAP: dict[str, str] = {
Expand Down Expand Up @@ -93,7 +94,7 @@ def map_token_no_cascade(token: str) -> str:
return to_snake_case(token)


def get_table_info(conn: sqlite3.Connection, table: str) -> list[tuple]:
def get_table_info(conn: sqlite3.Connection, table: str) -> list[tuple[Any, ...]]:
try:
return conn.execute(f'PRAGMA table_info({table});').fetchall()
except sqlite3.OperationalError:
Expand Down Expand Up @@ -131,7 +132,7 @@ def migrate_direct_table(
return len(filtered)


def migrate_all(args) -> None:
def migrate_all(args: argparse.Namespace) -> None:
src = Path(args.source)
schema = Path(args.schema)
out = Path(args.out) if args.out else src.with_suffix('.v4.sqlite')
Expand Down
Loading
Loading