From 7d0027a2efff9bccb3f51b8c1c0c11fe78cf67f4 Mon Sep 17 00:00:00 2001 From: eminyouskn Date: Tue, 27 Jan 2026 13:40:21 -0500 Subject: [PATCH 1/2] improve type hints --- pyomo/contrib/solver/solvers/knitro/api.py | 4 +-- pyomo/contrib/solver/solvers/knitro/base.py | 7 +++-- .../contrib/solver/solvers/knitro/callback.py | 4 +-- pyomo/contrib/solver/solvers/knitro/engine.py | 26 +++++++++---------- .../contrib/solver/solvers/knitro/package.py | 7 +++-- .../contrib/solver/solvers/knitro/solution.py | 12 ++++----- pyomo/contrib/solver/solvers/knitro/typing.py | 4 +-- pyomo/contrib/solver/solvers/knitro/utils.py | 5 ++-- 8 files changed, 32 insertions(+), 37 deletions(-) diff --git a/pyomo/contrib/solver/solvers/knitro/api.py b/pyomo/contrib/solver/solvers/knitro/api.py index 89c3f1392b4..06b0e8d24f0 100644 --- a/pyomo/contrib/solver/solvers/knitro/api.py +++ b/pyomo/contrib/solver/solvers/knitro/api.py @@ -9,14 +9,12 @@ # This software is distributed under the 3-clause BSD License. # ___________________________________________________________________________ -from typing import Optional - from pyomo.common.dependencies import attempt_import knitro, KNITRO_AVAILABLE = attempt_import("knitro") -def get_version() -> Optional[str]: +def get_version() -> str | None: if not KNITRO_AVAILABLE: return None return knitro.__version__ diff --git a/pyomo/contrib/solver/solvers/knitro/base.py b/pyomo/contrib/solver/solvers/knitro/base.py index feb6595126f..a08aa65a311 100644 --- a/pyomo/contrib/solver/solvers/knitro/base.py +++ b/pyomo/contrib/solver/solvers/knitro/base.py @@ -14,7 +14,6 @@ import datetime import time from io import StringIO -from typing import Optional from pyomo.common.collections import ComponentMap from pyomo.common.errors import ApplicationError, DeveloperError, PyomoException @@ -56,7 +55,7 @@ class KnitroSolverBase(SolutionProvider, PackageChecker, SolverBase): _engine: Engine _model_data: KnitroModelData _stream: StringIO - _saved_var_values: dict[int, Optional[float]] + _saved_var_values: dict[int, float | None] def __init__(self, **kwds) -> None: PackageChecker.__init__(self) @@ -182,10 +181,10 @@ def get_values( self, item_type: type[ItemType], value_type: ValueType, - items: Optional[Sequence[ItemType]] = None, + items: Sequence[ItemType] | None = None, *, exists: bool, - solution_id: Optional[int] = None, + solution_id: int | None = None, ) -> Mapping[ItemType, float]: error_type = self._get_error_type(item_type, value_type) if not exists: diff --git a/pyomo/contrib/solver/solvers/knitro/callback.py b/pyomo/contrib/solver/solvers/knitro/callback.py index 6a9ddb71eb0..fcf391f3e94 100644 --- a/pyomo/contrib/solver/solvers/knitro/callback.py +++ b/pyomo/contrib/solver/solvers/knitro/callback.py @@ -10,7 +10,7 @@ # ___________________________________________________________________________ from collections.abc import Callable -from typing import Any, Optional, Protocol +from typing import Any, Protocol from pyomo.contrib.solver.solvers.knitro.typing import ( Callback, @@ -85,7 +85,7 @@ def hess(self, req: CallbackRequest, res: CallbackResult) -> int: return 0 -def build_callback_handler(function: Function, idx: Optional[int]) -> CallbackHandler: +def build_callback_handler(function: Function, idx: int | None) -> CallbackHandler: if idx is None: return ObjectiveCallbackHandler(function) return ConstraintCallbackHandler(idx, function) diff --git a/pyomo/contrib/solver/solvers/knitro/engine.py b/pyomo/contrib/solver/solvers/knitro/engine.py index 325d778ae2a..9c0620ecef0 100644 --- a/pyomo/contrib/solver/solvers/knitro/engine.py +++ b/pyomo/contrib/solver/solvers/knitro/engine.py @@ -11,7 +11,7 @@ from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence from types import MappingProxyType -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar from pyomo.common.enums import ObjectiveSense from pyomo.common.errors import DeveloperError @@ -92,7 +92,7 @@ def api_set_param(param_type: int) -> Callable[..., None]: def api_get_values( item_type: type[ItemType], value_type: ValueType -) -> Callable[..., Optional[list[float]]]: + ) -> Callable[..., list[float] | None]: if item_type is VarData: if value_type == ValueType.PRIMAL: return knitro.KN_get_var_primal_values @@ -108,7 +108,7 @@ def api_get_values( ) -def api_add_items(item_type: type[ItemType]) -> Callable[..., Optional[list[int]]]: +def api_add_items(item_type: type[ItemType]) -> Callable[..., list[int] | None]: if item_type is VarData: return knitro.KN_add_vars elif item_type is ConstraintData: @@ -169,11 +169,11 @@ class Engine: has_objective: bool maps: Mapping[type[ItemData], MutableMapping[int, int]] - nonlinear_map: MutableMapping[Optional[int], NonlinearExpressionData] + nonlinear_map: MutableMapping[int | None, NonlinearExpressionData] nonlinear_diff_order: int - _kc: Optional[Any] - _status: Optional[int] + _kc: Any | None + _status: int | None def __init__(self, *, nonlinear_diff_order: int = 2) -> None: self.has_objective = False @@ -242,7 +242,7 @@ def set_options(self, **options) -> None: for param, val in options.items(): self.set_option(param, val) - def set_outlev(self, level: Optional[int] = None) -> None: + def set_outlev(self, level: int | None = None) -> None: if level is None: level = knitro.KN_OUTLEV_ALL self.set_options(outlev=level) @@ -289,7 +289,7 @@ def get_num_solutions(self) -> int: def get_solve_time(self) -> float: return self.execute(knitro.KN_get_solve_time_real) - def get_obj_value(self) -> Optional[float]: + def get_obj_value(self) -> float | None: if not self.has_objective: return None if self._status not in { @@ -303,7 +303,7 @@ def get_obj_value(self) -> Optional[float]: return None return self.execute(knitro.KN_get_obj_value) - def get_obj_bound(self) -> Optional[float]: + def get_obj_bound(self) -> float | None: if not self.has_objective: return None return self.execute(knitro.KN_get_mip_relaxation_bnd) @@ -319,7 +319,7 @@ def get_values( item_type: type[ItemType], value_type: ValueType, items: Iterable[ItemType], - ) -> Optional[list[float]]: + ) -> list[float] | None: func = api_get_values(item_type, value_type) idxs = self.get_idxs(item_type, items) return self.execute(func, idxs) @@ -367,7 +367,7 @@ def set_con_structures(self, cons: Iterable[ConstraintData]) -> None: def set_obj_structures(self, obj: ObjectiveData) -> None: self.add_structures(None, obj.expr) - def add_structures(self, i: Optional[int], expr) -> None: + def add_structures(self, i: int | None, expr) -> None: repn = generate_standard_repn(expr) if repn is None: return @@ -408,7 +408,7 @@ def add_structures(self, i: Optional[int], expr) -> None: ) def add_callback( - self, i: Optional[int], expr: NonlinearExpressionData, callback: Callback + self, i: int | None, expr: NonlinearExpressionData, callback: Callback ) -> None: is_obj = i is None idx_cons = [i] if not is_obj else None @@ -441,7 +441,7 @@ def register_callbacks(self) -> None: self.register_callback(i, expr) def register_callback( - self, i: Optional[int], expr: NonlinearExpressionData + self, i: int | None, expr: NonlinearExpressionData ) -> None: callback = build_callback_handler(expr, idx=i).expand() self.add_callback(i, expr, callback) diff --git a/pyomo/contrib/solver/solvers/knitro/package.py b/pyomo/contrib/solver/solvers/knitro/package.py index 2298bb93071..d668fce3b0a 100644 --- a/pyomo/contrib/solver/solvers/knitro/package.py +++ b/pyomo/contrib/solver/solvers/knitro/package.py @@ -10,7 +10,6 @@ # ___________________________________________________________________________ import io -from typing import Optional from pyomo.common.tee import TeeStream, capture_output from pyomo.contrib.solver.common.base import Availability @@ -61,7 +60,7 @@ def create_context(): return knitro.KN_new_lm(lmc) @staticmethod - def get_version() -> Optional[tuple[int, int, int]]: + def get_version() -> tuple[int, int, int] | None: """Get the version of the KNITRO solver as a tuple. Returns: @@ -97,7 +96,7 @@ def check_availability() -> Availability: class PackageChecker: - _available_cache: Optional[Availability] + _available_cache: Availability | None def __init__(self) -> None: self._available_cache = None @@ -107,5 +106,5 @@ def available(self) -> Availability: self._available_cache = Package.check_availability() return self._available_cache - def version(self) -> Optional[tuple[int, int, int]]: + def version(self) -> tuple[int, int, int] | None: return Package.get_version() diff --git a/pyomo/contrib/solver/solvers/knitro/solution.py b/pyomo/contrib/solver/solvers/knitro/solution.py index 1222446a7a5..b1a4a8bb7e2 100644 --- a/pyomo/contrib/solver/solvers/knitro/solution.py +++ b/pyomo/contrib/solver/solvers/knitro/solution.py @@ -10,7 +10,7 @@ # ___________________________________________________________________________ from collections.abc import Mapping, Sequence -from typing import Optional, Protocol +from typing import Protocol from pyomo.contrib.solver.common.solution_loader import SolutionLoaderBase from pyomo.contrib.solver.solvers.knitro.typing import ItemType, ValueType @@ -25,10 +25,10 @@ def get_values( self, item_type: type[ItemType], value_type: ValueType, - items: Optional[Sequence[ItemType]] = None, + items: Sequence[ItemType] | None = None, *, exists: bool, - solution_id: Optional[int] = None, + solution_id: int | None = None, ) -> Mapping[ItemType, float]: ... @@ -60,7 +60,7 @@ def get_primals(self, vars_to_load=None): return self.get_vars(vars_to_load) def get_vars( - self, vars_to_load: Optional[Sequence[VarData]] = None, solution_id=None + self, vars_to_load: Sequence[VarData] | None = None, solution_id: int | None = None ) -> Mapping[VarData, float]: return self._provider.get_values( VarData, @@ -71,7 +71,7 @@ def get_vars( ) def get_reduced_costs( - self, vars_to_load: Optional[Sequence[VarData]] = None, solution_id=None + self, vars_to_load: Sequence[VarData] | None = None, solution_id: int | None = None ) -> Mapping[VarData, float]: return self._provider.get_values( VarData, @@ -82,7 +82,7 @@ def get_reduced_costs( ) def get_duals( - self, cons_to_load: Optional[Sequence[ConstraintData]] = None, solution_id=None + self, cons_to_load: Sequence[ConstraintData] | None = None, solution_id: int | None = None ) -> Mapping[ConstraintData, float]: return self._provider.get_values( ConstraintData, diff --git a/pyomo/contrib/solver/solvers/knitro/typing.py b/pyomo/contrib/solver/solvers/knitro/typing.py index 11575ab4467..414affb8b4d 100644 --- a/pyomo/contrib/solver/solvers/knitro/typing.py +++ b/pyomo/contrib/solver/solvers/knitro/typing.py @@ -10,7 +10,7 @@ # ___________________________________________________________________________ from collections.abc import Callable -from typing import Any, NamedTuple, Protocol, TypeVar, Union +from typing import Any, NamedTuple, Protocol, TypeVar from pyomo.common.enums import Enum from pyomo.core.base.constraint import ConstraintData @@ -38,7 +38,7 @@ def sign(self) -> float: return -1.0 if self == ValueType.DUAL else 1.0 -ItemData = Union[VarData, ConstraintData] +ItemData = VarData | ConstraintData ItemType = TypeVar("ItemType", bound=ItemData) diff --git a/pyomo/contrib/solver/solvers/knitro/utils.py b/pyomo/contrib/solver/solvers/knitro/utils.py index bba08f73216..f3733e6adb8 100644 --- a/pyomo/contrib/solver/solvers/knitro/utils.py +++ b/pyomo/contrib/solver/solvers/knitro/utils.py @@ -10,7 +10,6 @@ # ___________________________________________________________________________ from collections.abc import Iterable, Mapping, MutableSet, Sequence -from typing import Optional from pyomo.common.collections import ComponentMap, ComponentSet from pyomo.common.numeric_types import value @@ -75,11 +74,11 @@ class KnitroModelData: variables: list[VarData] _vars: MutableSet[VarData] - def __init__(self, block: Optional[BlockData] = None) -> None: + def __init__(self, block: BlockData | None = None) -> None: """Initialize a Problem instance. Args: - block (Optional[BlockData]): Pyomo block to initialize from. If None, + block (BlockData | None): Pyomo block to initialize from. If None, creates an empty problem that can be populated later. """ From cdd219170cb7384a17afee36cf35cfbd1fc0f301 Mon Sep 17 00:00:00 2001 From: eminyouskn Date: Tue, 27 Jan 2026 13:40:58 -0500 Subject: [PATCH 2/2] black --- pyomo/contrib/solver/solvers/knitro/engine.py | 6 ++---- pyomo/contrib/solver/solvers/knitro/solution.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/pyomo/contrib/solver/solvers/knitro/engine.py b/pyomo/contrib/solver/solvers/knitro/engine.py index 9c0620ecef0..62d7f989ca6 100644 --- a/pyomo/contrib/solver/solvers/knitro/engine.py +++ b/pyomo/contrib/solver/solvers/knitro/engine.py @@ -92,7 +92,7 @@ def api_set_param(param_type: int) -> Callable[..., None]: def api_get_values( item_type: type[ItemType], value_type: ValueType - ) -> Callable[..., list[float] | None]: +) -> Callable[..., list[float] | None]: if item_type is VarData: if value_type == ValueType.PRIMAL: return knitro.KN_get_var_primal_values @@ -440,8 +440,6 @@ def register_callbacks(self) -> None: for i, expr in self.nonlinear_map.items(): self.register_callback(i, expr) - def register_callback( - self, i: int | None, expr: NonlinearExpressionData - ) -> None: + def register_callback(self, i: int | None, expr: NonlinearExpressionData) -> None: callback = build_callback_handler(expr, idx=i).expand() self.add_callback(i, expr, callback) diff --git a/pyomo/contrib/solver/solvers/knitro/solution.py b/pyomo/contrib/solver/solvers/knitro/solution.py index b1a4a8bb7e2..100a48185d2 100644 --- a/pyomo/contrib/solver/solvers/knitro/solution.py +++ b/pyomo/contrib/solver/solvers/knitro/solution.py @@ -60,7 +60,9 @@ def get_primals(self, vars_to_load=None): return self.get_vars(vars_to_load) def get_vars( - self, vars_to_load: Sequence[VarData] | None = None, solution_id: int | None = None + self, + vars_to_load: Sequence[VarData] | None = None, + solution_id: int | None = None, ) -> Mapping[VarData, float]: return self._provider.get_values( VarData, @@ -71,7 +73,9 @@ def get_vars( ) def get_reduced_costs( - self, vars_to_load: Sequence[VarData] | None = None, solution_id: int | None = None + self, + vars_to_load: Sequence[VarData] | None = None, + solution_id: int | None = None, ) -> Mapping[VarData, float]: return self._provider.get_values( VarData, @@ -82,7 +86,9 @@ def get_reduced_costs( ) def get_duals( - self, cons_to_load: Sequence[ConstraintData] | None = None, solution_id: int | None = None + self, + cons_to_load: Sequence[ConstraintData] | None = None, + solution_id: int | None = None, ) -> Mapping[ConstraintData, float]: return self._provider.get_values( ConstraintData,