Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pyomo/contrib/solver/solvers/knitro/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

from typing import Optional

from pyomo.common.dependencies import attempt_import

knitro, KNITRO_AVAILABLE = attempt_import("knitro")


def get_version() -> Optional[str]:
def get_version() -> str | None:
if not KNITRO_AVAILABLE:
return None
return knitro.__version__
7 changes: 3 additions & 4 deletions pyomo/contrib/solver/solvers/knitro/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import datetime
import time
from io import StringIO
from typing import Optional

from pyomo.common.collections import ComponentMap
from pyomo.common.errors import ApplicationError, DeveloperError, PyomoException
Expand Down Expand Up @@ -56,7 +55,7 @@ class KnitroSolverBase(SolutionProvider, PackageChecker, SolverBase):
_engine: Engine
_model_data: KnitroModelData
_stream: StringIO
_saved_var_values: dict[int, Optional[float]]
_saved_var_values: dict[int, float | None]

def __init__(self, **kwds) -> None:
PackageChecker.__init__(self)
Expand Down Expand Up @@ -182,10 +181,10 @@ def get_values(
self,
item_type: type[ItemType],
value_type: ValueType,
items: Optional[Sequence[ItemType]] = None,
items: Sequence[ItemType] | None = None,
*,
exists: bool,
solution_id: Optional[int] = None,
solution_id: int | None = None,
) -> Mapping[ItemType, float]:
error_type = self._get_error_type(item_type, value_type)
if not exists:
Expand Down
4 changes: 2 additions & 2 deletions pyomo/contrib/solver/solvers/knitro/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# ___________________________________________________________________________

from collections.abc import Callable
from typing import Any, Optional, Protocol
from typing import Any, Protocol

from pyomo.contrib.solver.solvers.knitro.typing import (
Callback,
Expand Down Expand Up @@ -85,7 +85,7 @@ def hess(self, req: CallbackRequest, res: CallbackResult) -> int:
return 0


def build_callback_handler(function: Function, idx: Optional[int]) -> CallbackHandler:
def build_callback_handler(function: Function, idx: int | None) -> CallbackHandler:
if idx is None:
return ObjectiveCallbackHandler(function)
return ConstraintCallbackHandler(idx, function)
28 changes: 13 additions & 15 deletions pyomo/contrib/solver/solvers/knitro/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence
from types import MappingProxyType
from typing import Any, Optional, TypeVar
from typing import Any, TypeVar

from pyomo.common.enums import ObjectiveSense
from pyomo.common.errors import DeveloperError
Expand Down Expand Up @@ -92,7 +92,7 @@ def api_set_param(param_type: int) -> Callable[..., None]:

def api_get_values(
item_type: type[ItemType], value_type: ValueType
) -> Callable[..., Optional[list[float]]]:
) -> Callable[..., list[float] | None]:
if item_type is VarData:
if value_type == ValueType.PRIMAL:
return knitro.KN_get_var_primal_values
Expand All @@ -108,7 +108,7 @@ def api_get_values(
)


def api_add_items(item_type: type[ItemType]) -> Callable[..., Optional[list[int]]]:
def api_add_items(item_type: type[ItemType]) -> Callable[..., list[int] | None]:
if item_type is VarData:
return knitro.KN_add_vars
elif item_type is ConstraintData:
Expand Down Expand Up @@ -169,11 +169,11 @@ class Engine:

has_objective: bool
maps: Mapping[type[ItemData], MutableMapping[int, int]]
nonlinear_map: MutableMapping[Optional[int], NonlinearExpressionData]
nonlinear_map: MutableMapping[int | None, NonlinearExpressionData]
nonlinear_diff_order: int

_kc: Optional[Any]
_status: Optional[int]
_kc: Any | None
_status: int | None

def __init__(self, *, nonlinear_diff_order: int = 2) -> None:
self.has_objective = False
Expand Down Expand Up @@ -242,7 +242,7 @@ def set_options(self, **options) -> None:
for param, val in options.items():
self.set_option(param, val)

def set_outlev(self, level: Optional[int] = None) -> None:
def set_outlev(self, level: int | None = None) -> None:
if level is None:
level = knitro.KN_OUTLEV_ALL
self.set_options(outlev=level)
Expand Down Expand Up @@ -289,7 +289,7 @@ def get_num_solutions(self) -> int:
def get_solve_time(self) -> float:
return self.execute(knitro.KN_get_solve_time_real)

def get_obj_value(self) -> Optional[float]:
def get_obj_value(self) -> float | None:
if not self.has_objective:
return None
if self._status not in {
Expand All @@ -303,7 +303,7 @@ def get_obj_value(self) -> Optional[float]:
return None
return self.execute(knitro.KN_get_obj_value)

def get_obj_bound(self) -> Optional[float]:
def get_obj_bound(self) -> float | None:
if not self.has_objective:
return None
return self.execute(knitro.KN_get_mip_relaxation_bnd)
Expand All @@ -319,7 +319,7 @@ def get_values(
item_type: type[ItemType],
value_type: ValueType,
items: Iterable[ItemType],
) -> Optional[list[float]]:
) -> list[float] | None:
func = api_get_values(item_type, value_type)
idxs = self.get_idxs(item_type, items)
return self.execute(func, idxs)
Expand Down Expand Up @@ -367,7 +367,7 @@ def set_con_structures(self, cons: Iterable[ConstraintData]) -> None:
def set_obj_structures(self, obj: ObjectiveData) -> None:
self.add_structures(None, obj.expr)

def add_structures(self, i: Optional[int], expr) -> None:
def add_structures(self, i: int | None, expr) -> None:
repn = generate_standard_repn(expr)
if repn is None:
return
Expand Down Expand Up @@ -408,7 +408,7 @@ def add_structures(self, i: Optional[int], expr) -> None:
)

def add_callback(
self, i: Optional[int], expr: NonlinearExpressionData, callback: Callback
self, i: int | None, expr: NonlinearExpressionData, callback: Callback
) -> None:
is_obj = i is None
idx_cons = [i] if not is_obj else None
Expand Down Expand Up @@ -440,8 +440,6 @@ def register_callbacks(self) -> None:
for i, expr in self.nonlinear_map.items():
self.register_callback(i, expr)

def register_callback(
self, i: Optional[int], expr: NonlinearExpressionData
) -> None:
def register_callback(self, i: int | None, expr: NonlinearExpressionData) -> None:
callback = build_callback_handler(expr, idx=i).expand()
self.add_callback(i, expr, callback)
7 changes: 3 additions & 4 deletions pyomo/contrib/solver/solvers/knitro/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# ___________________________________________________________________________

import io
from typing import Optional

from pyomo.common.tee import TeeStream, capture_output
from pyomo.contrib.solver.common.base import Availability
Expand Down Expand Up @@ -61,7 +60,7 @@ def create_context():
return knitro.KN_new_lm(lmc)

@staticmethod
def get_version() -> Optional[tuple[int, int, int]]:
def get_version() -> tuple[int, int, int] | None:
"""Get the version of the KNITRO solver as a tuple.

Returns:
Expand Down Expand Up @@ -97,7 +96,7 @@ def check_availability() -> Availability:


class PackageChecker:
_available_cache: Optional[Availability]
_available_cache: Availability | None

def __init__(self) -> None:
self._available_cache = None
Expand All @@ -107,5 +106,5 @@ def available(self) -> Availability:
self._available_cache = Package.check_availability()
return self._available_cache

def version(self) -> Optional[tuple[int, int, int]]:
def version(self) -> tuple[int, int, int] | None:
return Package.get_version()
18 changes: 12 additions & 6 deletions pyomo/contrib/solver/solvers/knitro/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# ___________________________________________________________________________

from collections.abc import Mapping, Sequence
from typing import Optional, Protocol
from typing import Protocol

from pyomo.contrib.solver.common.solution_loader import SolutionLoaderBase
from pyomo.contrib.solver.solvers.knitro.typing import ItemType, ValueType
Expand All @@ -25,10 +25,10 @@ def get_values(
self,
item_type: type[ItemType],
value_type: ValueType,
items: Optional[Sequence[ItemType]] = None,
items: Sequence[ItemType] | None = None,
*,
exists: bool,
solution_id: Optional[int] = None,
solution_id: int | None = None,
) -> Mapping[ItemType, float]: ...


Expand Down Expand Up @@ -60,7 +60,9 @@ def get_primals(self, vars_to_load=None):
return self.get_vars(vars_to_load)

def get_vars(
self, vars_to_load: Optional[Sequence[VarData]] = None, solution_id=None
self,
vars_to_load: Sequence[VarData] | None = None,
solution_id: int | None = None,
) -> Mapping[VarData, float]:
return self._provider.get_values(
VarData,
Expand All @@ -71,7 +73,9 @@ def get_vars(
)

def get_reduced_costs(
self, vars_to_load: Optional[Sequence[VarData]] = None, solution_id=None
self,
vars_to_load: Sequence[VarData] | None = None,
solution_id: int | None = None,
) -> Mapping[VarData, float]:
return self._provider.get_values(
VarData,
Expand All @@ -82,7 +86,9 @@ def get_reduced_costs(
)

def get_duals(
self, cons_to_load: Optional[Sequence[ConstraintData]] = None, solution_id=None
self,
cons_to_load: Sequence[ConstraintData] | None = None,
solution_id: int | None = None,
) -> Mapping[ConstraintData, float]:
return self._provider.get_values(
ConstraintData,
Expand Down
4 changes: 2 additions & 2 deletions pyomo/contrib/solver/solvers/knitro/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# ___________________________________________________________________________

from collections.abc import Callable
from typing import Any, NamedTuple, Protocol, TypeVar, Union
from typing import Any, NamedTuple, Protocol, TypeVar

from pyomo.common.enums import Enum
from pyomo.core.base.constraint import ConstraintData
Expand Down Expand Up @@ -38,7 +38,7 @@ def sign(self) -> float:
return -1.0 if self == ValueType.DUAL else 1.0


ItemData = Union[VarData, ConstraintData]
ItemData = VarData | ConstraintData
ItemType = TypeVar("ItemType", bound=ItemData)


Expand Down
5 changes: 2 additions & 3 deletions pyomo/contrib/solver/solvers/knitro/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# ___________________________________________________________________________

from collections.abc import Iterable, Mapping, MutableSet, Sequence
from typing import Optional

from pyomo.common.collections import ComponentMap, ComponentSet
from pyomo.common.numeric_types import value
Expand Down Expand Up @@ -75,11 +74,11 @@ class KnitroModelData:
variables: list[VarData]
_vars: MutableSet[VarData]

def __init__(self, block: Optional[BlockData] = None) -> None:
def __init__(self, block: BlockData | None = None) -> None:
"""Initialize a Problem instance.

Args:
block (Optional[BlockData]): Pyomo block to initialize from. If None,
block (BlockData | None): Pyomo block to initialize from. If None,
creates an empty problem that can be populated later.

"""
Expand Down