-
Notifications
You must be signed in to change notification settings - Fork 568
Rework SolutionLoader #3701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Rework SolutionLoader #3701
Changes from all commits
b57ab07
710807b
438b9b5
ac42345
70ca6e7
5ec0421
2885f42
c62a7b3
a4e2b81
1750fc5
413d63d
0bbdd70
a96b518
ce4e77c
c3f2d48
ba4b29c
0792800
f3370f3
3a44486
a125456
b5d16d9
9473f29
7557036
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,11 +9,34 @@ | |
| # This software is distributed under the 3-clause BSD License. | ||
| # ___________________________________________________________________________ | ||
|
|
||
| from typing import Sequence, Dict, Optional, Mapping | ||
| from __future__ import annotations | ||
|
|
||
| from typing import Sequence, Dict, Optional, Mapping, List, Any | ||
|
|
||
| from pyomo.core.base.constraint import ConstraintData | ||
| from pyomo.core.base.var import VarData | ||
| from pyomo.core.staleflag import StaleFlagManager | ||
| from pyomo.core.base.suffix import Suffix | ||
|
|
||
|
|
||
| def load_import_suffixes( | ||
| pyomo_model, solution_loader: SolutionLoaderBase, solution_id=None | ||
| ): | ||
| dual_suffix = None | ||
| rc_suffix = None | ||
| for suffix in pyomo_model.component_objects(Suffix, descend_into=True, active=True): | ||
| if not suffix.import_enabled(): | ||
| continue | ||
| if suffix.local_name == 'dual': | ||
| dual_suffix = suffix | ||
| elif suffix.local_name == 'rc': | ||
| rc_suffix = suffix | ||
| if dual_suffix is not None: | ||
| for k, v in solution_loader.get_duals(solution_id=solution_id).items(): | ||
| dual_suffix[k] = v | ||
| if rc_suffix is not None: | ||
| for k, v in solution_loader.get_reduced_costs(solution_id=solution_id).items(): | ||
| rc_suffix[k] = v | ||
|
|
||
|
|
||
| class SolutionLoaderBase: | ||
|
|
@@ -23,24 +46,70 @@ class SolutionLoaderBase: | |
| Intent of this class and its children is to load the solution back into the model. | ||
| """ | ||
|
|
||
| def load_vars(self, vars_to_load: Optional[Sequence[VarData]] = None) -> None: | ||
| def get_solution_ids(self) -> List[Any]: | ||
| """ | ||
| If there are multiple solutions available, this will return a | ||
| list of the solution ids which can then be used with other | ||
| methods like `load_solution`. If only one solution is | ||
| available, this will return [None]. If no solutions | ||
| are available, this will return None | ||
|
|
||
| Returns | ||
| ------- | ||
| solutions_ids: List[Any] | ||
| The identifiers for multiple solutions | ||
| """ | ||
| return NotImplemented | ||
|
|
||
| def get_number_of_solutions(self) -> int: | ||
| """ | ||
| Returns | ||
| ------- | ||
| num_solutions: int | ||
| Indicates the number of solutions found | ||
| """ | ||
| return NotImplemented | ||
|
|
||
| def load_solution(self, solution_id=None): | ||
| """ | ||
| Load the solution of the primal variables into the value attribute of the variables. | ||
| Load the solution (everything that can be) back into the model | ||
|
|
||
| Parameters | ||
| ---------- | ||
| solution_id: Optional[Any] | ||
| If there are multiple solutions, this specifies which solution | ||
| should be loaded. If None, the default solution will be used. | ||
| """ | ||
| # this should load everything it can | ||
| self.load_vars(solution_id=solution_id) | ||
| self.load_import_suffixes(solution_id=solution_id) | ||
|
|
||
| def load_vars( | ||
| self, vars_to_load: Optional[Sequence[VarData]] = None, solution_id=None | ||
| ) -> None: | ||
| """ | ||
| Load the solution of the primal variables into the value attribute | ||
| of the variables. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| vars_to_load: list | ||
| The minimum set of variables whose solution should be loaded. If vars_to_load | ||
| is None, then the solution to all primal variables will be loaded. Even if | ||
| vars_to_load is specified, the values of other variables may also be | ||
| loaded depending on the interface. | ||
| The minimum set of variables whose solution should be loaded. If | ||
| vars_to_load is None, then the solution to all primal variables | ||
| will be loaded. Even if vars_to_load is specified, the values of | ||
| other variables may also be loaded depending on the interface. | ||
| solution_id: Optional[Any] | ||
| If there are multiple solutions, this specifies which solution | ||
| should be loaded. If None, the default solution will be used. | ||
| """ | ||
| for var, val in self.get_primals(vars_to_load=vars_to_load).items(): | ||
| for var, val in self.get_vars( | ||
| vars_to_load=vars_to_load, solution_id=solution_id | ||
| ).items(): | ||
| var.set_value(val, skip_validation=True) | ||
| StaleFlagManager.mark_all_as_stale(delayed=True) | ||
|
|
||
| def get_primals( | ||
| self, vars_to_load: Optional[Sequence[VarData]] = None | ||
| def get_vars( | ||
| self, vars_to_load: Optional[Sequence[VarData]] = None, solution_id=None | ||
| ) -> Mapping[VarData, float]: | ||
| """ | ||
| Returns a ComponentMap mapping variable to var value. | ||
|
|
@@ -50,18 +119,21 @@ def get_primals( | |
| vars_to_load: list | ||
| A list of the variables whose solution value should be retrieved. If vars_to_load | ||
| is None, then the values for all variables will be retrieved. | ||
| solution_id: Optional[Any] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of |
||
| If there are multiple solutions, this specifies which solution | ||
| should be retrieved. If None, the default solution will be used. | ||
|
|
||
| Returns | ||
| ------- | ||
| primals: ComponentMap | ||
| Maps variables to solution values | ||
| """ | ||
| raise NotImplementedError( | ||
| f"Derived class {self.__class__.__name__} failed to implement required method 'get_primals'." | ||
| f"Derived class {self.__class__.__name__} failed to implement required method 'get_vars'." | ||
| ) | ||
|
|
||
| def get_duals( | ||
| self, cons_to_load: Optional[Sequence[ConstraintData]] = None | ||
| self, cons_to_load: Optional[Sequence[ConstraintData]] = None, solution_id=None | ||
| ) -> Dict[ConstraintData, float]: | ||
| """ | ||
| Returns a dictionary mapping constraint to dual value. | ||
|
|
@@ -71,16 +143,19 @@ def get_duals( | |
| cons_to_load: list | ||
| A list of the constraints whose duals should be retrieved. If cons_to_load | ||
| is None, then the duals for all constraints will be retrieved. | ||
| solution_id: Optional[Any] | ||
| If there are multiple solutions, this specifies which solution | ||
| should be retrieved. If None, the default solution will be used. | ||
|
|
||
| Returns | ||
| ------- | ||
| duals: dict | ||
| Maps constraints to dual values | ||
| """ | ||
| raise NotImplementedError(f'{type(self)} does not support the get_duals method') | ||
| return NotImplemented | ||
|
|
||
| def get_reduced_costs( | ||
| self, vars_to_load: Optional[Sequence[VarData]] = None | ||
| self, vars_to_load: Optional[Sequence[VarData]] = None, solution_id=None | ||
| ) -> Mapping[VarData, float]: | ||
| """ | ||
| Returns a ComponentMap mapping variable to reduced cost. | ||
|
|
@@ -90,45 +165,70 @@ def get_reduced_costs( | |
| vars_to_load: list | ||
| A list of the variables whose reduced cost should be retrieved. If vars_to_load | ||
| is None, then the reduced costs for all variables will be loaded. | ||
| solution_id: Optional[Any] | ||
| If there are multiple solutions, this specifies which solution | ||
| should be retrieved. If None, the default solution will be used. | ||
|
|
||
| Returns | ||
| ------- | ||
| reduced_costs: ComponentMap | ||
| Maps variables to reduced costs | ||
| """ | ||
| raise NotImplementedError( | ||
| f'{type(self)} does not support the get_reduced_costs method' | ||
| ) | ||
| return NotImplemented | ||
|
|
||
| def load_import_suffixes(self, solution_id=None): | ||
| """ | ||
| Parameters | ||
| ---------- | ||
| solution_id: Optional[Any] | ||
| If there are multiple solutions, this specifies which solution | ||
| should be loaded. If None, the default solution will be used. | ||
| """ | ||
| return NotImplemented | ||
|
|
||
|
|
||
| class PersistentSolutionLoader(SolutionLoaderBase): | ||
| """ | ||
| Loader for persistent solvers | ||
| """ | ||
|
|
||
| def __init__(self, solver): | ||
| def __init__(self, solver, pyomo_model): | ||
| self._solver = solver | ||
| self._valid = True | ||
| self._pyomo_model = pyomo_model | ||
|
|
||
| def _assert_solution_still_valid(self): | ||
| if not self._valid: | ||
| raise RuntimeError('The results in the solver are no longer valid.') | ||
|
|
||
| def get_primals(self, vars_to_load=None): | ||
| def get_solution_ids(self) -> List[Any]: | ||
| self._assert_solution_still_valid() | ||
| return super().get_solution_ids() | ||
|
|
||
| def get_number_of_solutions(self) -> int: | ||
| self._assert_solution_still_valid() | ||
| return self._solver._get_primals(vars_to_load=vars_to_load) | ||
| return super().get_number_of_solutions() | ||
|
|
||
| def get_vars(self, vars_to_load=None, solution_id=None): | ||
| self._assert_solution_still_valid() | ||
| return self._solver._get_primals( | ||
| vars_to_load=vars_to_load, solution_id=solution_id | ||
| ) | ||
|
|
||
| def get_duals( | ||
| self, cons_to_load: Optional[Sequence[ConstraintData]] = None | ||
| self, cons_to_load: Optional[Sequence[ConstraintData]] = None, solution_id=None | ||
| ) -> Dict[ConstraintData, float]: | ||
| self._assert_solution_still_valid() | ||
| return self._solver._get_duals(cons_to_load=cons_to_load) | ||
|
|
||
| def get_reduced_costs( | ||
| self, vars_to_load: Optional[Sequence[VarData]] = None | ||
| self, vars_to_load: Optional[Sequence[VarData]] = None, solution_id=None | ||
| ) -> Mapping[VarData, float]: | ||
| self._assert_solution_still_valid() | ||
| return self._solver._get_reduced_costs(vars_to_load=vars_to_load) | ||
|
|
||
| def load_import_suffixes(self, solution_id=None): | ||
| load_import_suffixes(self._pyomo_model, self, solution_id=solution_id) | ||
|
|
||
| def invalidate(self): | ||
| self._valid = False | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,7 @@ | |
| # ___________________________________________________________________________ | ||
|
|
||
| import io | ||
| from typing import Sequence, Optional, Mapping | ||
| from typing import Sequence, Optional, Mapping, List, Any | ||
|
|
||
| from pyomo.common.collections import ComponentMap | ||
| from pyomo.common.errors import MouseTrap | ||
|
|
@@ -27,7 +27,10 @@ | |
| SolutionStatus, | ||
| TerminationCondition, | ||
| ) | ||
| from pyomo.contrib.solver.common.solution_loader import SolutionLoaderBase | ||
| from pyomo.contrib.solver.common.solution_loader import ( | ||
| SolutionLoaderBase, | ||
| load_import_suffixes, | ||
| ) | ||
|
|
||
|
|
||
| class ASLSolFileData: | ||
|
|
@@ -55,16 +58,34 @@ class ASLSolFileSolutionLoader(SolutionLoaderBase): | |
| Loader for solvers that create ASL .sol files (e.g., ipopt) | ||
| """ | ||
|
|
||
| def __init__(self, sol_data: ASLSolFileData, nl_info: NLWriterInfo) -> None: | ||
| def __init__( | ||
| self, sol_data: ASLSolFileData, nl_info: NLWriterInfo, pyomo_model | ||
| ) -> None: | ||
| self._sol_data = sol_data | ||
| self._nl_info = nl_info | ||
| self._pyomo_model = pyomo_model | ||
|
|
||
| def get_number_of_solutions(self) -> int: | ||
| if self._nl_info is None: | ||
| return 0 | ||
| return 1 | ||
|
|
||
| def get_solution_ids(self) -> List[Any]: | ||
| return [None] | ||
|
|
||
| def load_import_suffixes(self, solution_id=None): | ||
| load_import_suffixes(self._pyomo_model, self, solution_id=solution_id) | ||
|
Comment on lines
+76
to
+77
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This only loads
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you give me an example of what
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes: Suffix data comes back from the ASL flagged for "variables", "constraints", "objectives", and the "problem". We blindly parse that into the corresponding attributes (
|
||
|
|
||
| def load_vars(self, vars_to_load: Optional[Sequence[VarData]] = None) -> None: | ||
| def load_vars( | ||
| self, vars_to_load: Optional[Sequence[VarData]] = None, solution_id=None | ||
| ) -> None: | ||
| if solution_id is not None: | ||
| raise ValueError(f'{self.__class__.__name__} does not support solution_id') | ||
| if vars_to_load is not None: | ||
| # If we are given a list of variables to load, it is easiest | ||
| # to use the filtering in get_primals and then just set | ||
| # to use the filtering in get_vars and then just set | ||
| # those values. | ||
| for var, val in self.get_primals(vars_to_load).items(): | ||
| for var, val in self.get_vars(vars_to_load).items(): | ||
| var.set_value(val, skip_validation=True) | ||
| StaleFlagManager.mark_all_as_stale(delayed=True) | ||
| return | ||
|
|
@@ -92,9 +113,11 @@ def load_vars(self, vars_to_load: Optional[Sequence[VarData]] = None) -> None: | |
|
|
||
| StaleFlagManager.mark_all_as_stale(delayed=True) | ||
|
|
||
| def get_primals( | ||
| self, vars_to_load: Optional[Sequence[VarData]] = None | ||
| def get_vars( | ||
| self, vars_to_load: Optional[Sequence[VarData]] = None, solution_id=None | ||
| ) -> Mapping[VarData, float]: | ||
| if solution_id is not None: | ||
| raise ValueError(f'{self.__class__.__name__} does not support solution_id') | ||
| result = ComponentMap() | ||
| if not self._sol_data.primals: | ||
| # SOL file contained no primal values | ||
|
|
@@ -139,8 +162,10 @@ def get_primals( | |
| return result | ||
|
|
||
| def get_duals( | ||
| self, cons_to_load: Optional[Sequence[ConstraintData]] = None | ||
| self, cons_to_load: Optional[Sequence[ConstraintData]] = None, solution_id=None | ||
| ) -> dict[ConstraintData, float]: | ||
| if solution_id is not None: | ||
| raise ValueError(f'{self.__class__.__name__} does not support solution_id') | ||
| if len(self._nl_info.eliminated_vars) > 0: | ||
| raise MouseTrap( | ||
| 'Complete duals are not available when variables have ' | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some questions: