Skip to content

Commit 7f0fdc2

Browse files
committed
auto-generate problem
1 parent f12ff0b commit 7f0fdc2

File tree

3 files changed

+102
-6
lines changed

3 files changed

+102
-6
lines changed

autoemulate/compare.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from autoemulate.plotting import _plot_model
2121
from autoemulate.printing import _print_setup
2222
from autoemulate.save import ModelSerialiser
23+
from autoemulate.sensitivity_analysis import plot_sensitivity_analysis
24+
from autoemulate.sensitivity_analysis import sensitivity_analysis
2325
from autoemulate.utils import _ensure_2d
2426
from autoemulate.utils import _get_full_model_name
2527
from autoemulate.utils import _redirect_warnings
@@ -523,3 +525,64 @@ def plot_eval(
523525
)
524526

525527
return fig
528+
529+
def sensitivity_analysis(
530+
self, model=None, problem=None, N=1024, conf_level=0.95, as_df=True
531+
):
532+
"""Perform Sobol sensitivity analysis on a fitted emulator.
533+
534+
Parameters
535+
----------
536+
model : object, optional
537+
Fitted model. If None, uses the best model from cross-validation.
538+
problem : dict, optional
539+
The problem definition, including 'num_vars', 'names', and 'bounds', and optionally 'output_names'.
540+
If None, the problem is generated from X using minimum and maximum values of the features as bounds.
541+
542+
Example:
543+
```python
544+
problem = {
545+
"num_vars": 2,
546+
"names": ["x1", "x2"],
547+
"bounds": [[0, 1], [0, 1]],
548+
}
549+
```
550+
N : int, optional
551+
Number of samples to generate. Default is 1024.
552+
conf_level : float, optional
553+
Confidence level for the confidence intervals. Default is 0.95.
554+
as_df : bool, optional
555+
If True, return a long-format pandas DataFrame (default is True).
556+
"""
557+
if model is None:
558+
if not hasattr(self, "best_model"):
559+
raise RuntimeError("Must run compare() before sensitivity_analysis()")
560+
model = self.best_model
561+
self.logger.info(
562+
f"No model provided, using best model {get_model_name(model)} from cross-validation for sensitivity analysis"
563+
)
564+
565+
Si = sensitivity_analysis(model, problem, self.X, N, conf_level, as_df)
566+
return Si
567+
568+
def plot_sensitivity_analysis(self, results, index="S1", n_cols=None, figsize=None):
569+
"""
570+
Plot the sensitivity analysis results.
571+
572+
Parameters:
573+
-----------
574+
results : pd.DataFrame
575+
The results from sobol_results_to_df.
576+
index : str, default "S1"
577+
The type of sensitivity index to plot.
578+
- "S1": first-order indices
579+
- "S2": second-order/interaction indices
580+
- "ST": total-order indices
581+
n_cols : int, optional
582+
The number of columns in the plot. Defaults to 3 if there are 3 or more outputs,
583+
otherwise the number of outputs.
584+
figsize : tuple, optional
585+
Figure size as (width, height) in inches. If None, it is calculated automatically.
586+
587+
"""
588+
return plot_sensitivity_analysis(results, index, n_cols, figsize)

autoemulate/sensitivity_analysis.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from autoemulate.utils import _ensure_2d
88

99

10-
def sensitivity_analysis(model, problem, N=1024, conf_level=0.95, as_df=True):
10+
def sensitivity_analysis(
11+
model, problem=None, X=None, N=1024, conf_level=0.95, as_df=True
12+
):
1113
"""Perform Sobol sensitivity analysis on a fitted emulator.
1214
1315
Parameters:
@@ -39,7 +41,7 @@ def sensitivity_analysis(model, problem, N=1024, conf_level=0.95, as_df=True):
3941
containing the Sobol indices keys ‘S1’, ‘S1_conf’, ‘ST’, and ‘ST_conf’, where each entry
4042
is a list of length corresponding to the number of parameters.
4143
"""
42-
Si = sobol_analysis(model, problem, N, conf_level)
44+
Si = sobol_analysis(model, problem, X, N, conf_level)
4345

4446
if as_df:
4547
return sobol_results_to_df(Si)
@@ -85,7 +87,21 @@ def _get_output_names(problem, num_outputs):
8587
return output_names
8688

8789

88-
def sobol_analysis(model, problem, N=1024, conf_level=0.95):
90+
def _generate_problem(X):
    """
    Generate a problem definition from a design matrix.

    Each column of ``X`` is treated as one input variable. Variables are
    named ``x1``, ``x2``, ... in column order and the bounds of each variable
    are the minimum and maximum observed in its column.

    Parameters
    ----------
    X : array-like
        Two-dimensional design matrix with one row per sample and one column
        per input variable. Anything ``numpy.asarray`` can convert is accepted.

    Returns
    -------
    dict
        Problem definition with the keys ``"num_vars"`` (int), ``"names"``
        (list of str) and ``"bounds"`` (list of ``[min, max]`` pairs, one per
        column).

    Raises
    ------
    ValueError
        If ``X`` is not two-dimensional or contains no samples.
    """
    # Local import keeps this function self-contained regardless of how the
    # enclosing module imports numpy.
    import numpy as np

    X = np.asarray(X)
    # Reject every non-2D shape, not only 1D: a 0-D input would otherwise fail
    # with an unrelated IndexError and a 3-D input would silently yield bounds
    # computed over 2-D slices.
    if X.ndim != 2:
        raise ValueError("X must be a 2D array.")
    if X.shape[0] == 0:
        raise ValueError("X must contain at least one sample.")

    n_vars = X.shape[1]
    # One vectorized reduction per bound instead of two per column.
    lower = X.min(axis=0)
    upper = X.max(axis=0)

    return {
        "num_vars": n_vars,
        "names": [f"x{i + 1}" for i in range(n_vars)],
        "bounds": [[lo, hi] for lo, hi in zip(lower, upper)],
    }
102+
103+
104+
def sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95):
89105
"""
90106
Perform Sobol sensitivity analysis on a fitted emulator.
91107
@@ -105,8 +121,13 @@ def sobol_analysis(model, problem, N=1024, conf_level=0.95):
105121
containing the Sobol indices keys ‘S1’, ‘S1_conf’, ‘ST’, and ‘ST_conf’, where each entry
106122
is a list of length corresponding to the number of parameters.
107123
"""
108-
# correctly defined?
109-
problem = _check_problem(problem)
124+
# get problem
125+
if problem is not None:
126+
problem = _check_problem(problem)
127+
elif X is not None:
128+
problem = _generate_problem(X)
129+
else:
130+
raise ValueError("Either problem or X must be provided.")
110131

111132
# saltelli sampling
112133
param_values = sample(problem, N)
@@ -240,7 +261,7 @@ def plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None):
240261
Figure size as (width, height) in inches. If None, it is calculated automatically.
241262
242263
"""
243-
with plt.style.context("seaborn-v0_8-whitegrid"):
264+
with plt.style.context("fast"):
244265
# prepare data
245266
results = _validate_input(results, index)
246267
unique_outputs = results["output"].unique()

tests/test_sensitivity_analysis.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from autoemulate.experimental_design import LatinHypercube
88
from autoemulate.sensitivity_analysis import _calculate_layout
99
from autoemulate.sensitivity_analysis import _check_problem
10+
from autoemulate.sensitivity_analysis import _generate_problem
1011
from autoemulate.sensitivity_analysis import _get_output_names
1112
from autoemulate.sensitivity_analysis import _validate_input
1213
from autoemulate.sensitivity_analysis import sobol_analysis
@@ -192,3 +193,14 @@ def test_calculate_layout_custom():
192193
n_rows, n_cols = _calculate_layout(3, 2)
193194
assert n_rows == 2
194195
assert n_cols == 2
196+
197+
198+
# test _generate_problem -----------------------------------------------------
199+
200+
201+
def test_generate_problem():
202+
X = np.array([[0, 0], [1, 1], [2, 2]])
203+
problem = _generate_problem(X)
204+
assert problem["num_vars"] == 2
205+
assert problem["names"] == ["x1", "x2"]
206+
assert problem["bounds"] == [[0, 2], [0, 2]]

0 commit comments

Comments
 (0)