Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
396 changes: 271 additions & 125 deletions notebooks/21_dc-resistivity-inversion-w-beta-cooling.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/inversion_ideas/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .conditions import Condition
from .directive import Directive
from .minimizer import Minimizer
from .minimizer import Minimizer, MinimizerResult
from .objective_function import Combo, Objective, Scaled
from .simulation import Simulation

Expand All @@ -13,6 +13,7 @@
"Condition",
"Directive",
"Minimizer",
"MinimizerResult",
"Objective",
"Scaled",
"Simulation",
Expand Down
38 changes: 36 additions & 2 deletions src/inversion_ideas/base/minimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,50 @@
"""

from abc import ABC, abstractmethod
from collections.abc import Generator
from collections.abc import Callable, Generator

from ..typing import Model
from .objective_function import Objective


class MinimizerResult(dict):
"""
Dictionary to store results of a single minimization iteration.

This class is a child of ``dict``, but allows to access the values through
attributes.

Notes
-----
Inspired in the :class:`scipy.optimize.OptimizeResult`.
"""

def __getattr__(self, name):
try:
return self[name]
except KeyError as e:
raise AttributeError(name) from e

__setattr__ = dict.__setitem__ # type: ignore[assignment]
__delattr__ = dict.__delitem__ # type: ignore[assignment]

def __dir__(self):
return list(self.keys())


class Minimizer(ABC):
"""
Base class to represent minimizers as generators.
"""

@abstractmethod
def __call__(self, objective: Objective, initial_model: Model) -> Generator[Model]:
def __call__(
self,
objective: Objective,
initial_model: Model,
*,
callback: Callable[[MinimizerResult], None] | None = None,
) -> Generator[Model]:
"""
Minimize objective function.

Expand All @@ -25,6 +56,9 @@ def __call__(self, objective: Objective, initial_model: Model) -> Generator[Mode
Objective function to be minimized.
initial_model : (n_params) array
Initial model used to start the minimization.
callback : callable, optional
Callable that gets called after each iteration.
Takes a :class:`inversion_ideas.base.MinimizerResult` as argument.

Returns
-------
Expand Down
76 changes: 66 additions & 10 deletions src/inversion_ideas/inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,15 @@
import typing
from collections.abc import Callable

from rich.console import Group, RenderableType
from rich.live import Live
from rich.panel import Panel
from rich.spinner import Spinner
from rich.tree import Tree

from .base import Condition, Directive, Minimizer, Objective
from .inversion_log import InversionLog, InversionLogRich
from .typing import Model
from .inversion_log import InversionLog, InversionLogRich, MinimizerLog
from .typing import Log, Model
from .utils import get_logger


Expand Down Expand Up @@ -40,11 +46,15 @@ class Inversion:
no limit on the total amount of iterations.
cache_models : bool, optional
Whether to cache each model after each iteration.
log : InversionLog or bool, optional
Instance of :class:`InversionLog` to store information about the inversion.
log : Log or bool, optional
Instance of :class:`InversionLog` to store information about the inversion,
or any object that follows the :class:`inversion_ideas.typing.Log` protocol.
If `True`, a default :class:`InversionLog` is going to be used.
If `False`, no log will be assigned to the inversion, and :attr:`Inversion.log`
will be ``None``.
log_minimizers : bool, optional
Whether to log the minimizers or not. Logging minimizers is only possible when
the ``minimizer`` is an instance of :class:`inversion_ideas.base.Minimizer``.
minimizer_kwargs : dict, optional
Extra arguments that will be passed to the ``minimizer`` when called.
"""
Expand All @@ -59,7 +69,8 @@ def __init__(
stopping_criteria: Condition | Callable[[Model], bool],
max_iterations: int | None = None,
cache_models=False,
log: "InversionLog | bool" = True,
log: Log | InversionLog | bool = True,
log_minimizers: bool = True,
minimizer_kwargs: dict | None = None,
):
self.objective_function = objective_function
Expand All @@ -72,6 +83,7 @@ def __init__(
if minimizer_kwargs is None:
minimizer_kwargs = {}
self.minimizer_kwargs = minimizer_kwargs
self.log_minimizers = log_minimizers

# Assign log
if log is False:
Expand All @@ -86,6 +98,11 @@ def __init__(
# Assign model as a copy of the initial model
self.model = initial_model.copy()

# TODO: Support for handling custom callbacks for the minimizer
if log is not None and "callback" in self.minimizer_kwargs:
msg = "Passing a custom callback for the minimizer is not yet supported."
raise NotImplementedError(msg)

def __next__(self):
"""
Run next iteration in the inversion.
Expand Down Expand Up @@ -137,10 +154,20 @@ def __next__(self):
directive(self.model, self.counter)

# Minimize objective function
# ---------------------------
if isinstance(self.minimizer, Minimizer):
# Keep only the last model of the minimizer iterator
# Generate a new minimizer log for this iteration
minimizer_kwargs = self.minimizer_kwargs.copy()
if self.log is not None:
minimizer_log = MinimizerLog()
self.minimizer_logs.append(minimizer_log)
minimizer_kwargs["callback"] = minimizer_log.update

# Unpack the generator and keep only the last model
*_, model = self.minimizer(
self.objective_function, self.model, **self.minimizer_kwargs
self.objective_function,
self.model,
**minimizer_kwargs,
)
else:
model = self.minimizer(
Expand Down Expand Up @@ -185,6 +212,17 @@ def models(self) -> list:
self._models = [self.initial_model]
return self._models

@property
def minimizer_logs(self) -> list[None | MinimizerLog] | None:
"""
Logs of minimizers.
"""
if not self.log_minimizers:
return None
if not hasattr(self, "_minimizer_logs"):
self._minimizer_logs = [None]
return self._minimizer_logs

def run(self, show_log=True) -> Model:
"""
Run the inversion.
Expand All @@ -195,11 +233,29 @@ def run(self, show_log=True) -> Model:
Whether to show the ``log`` (if it's defined) during the inversion.
"""
if show_log and self.log is not None:
if not hasattr(self.log, "live"):
if not isinstance(self.log, RenderableType):
raise NotImplementedError()
with self.log.live() as live:

spinner = Spinner(
name="dots", text="Starting inversion...", style="green", speed=1
)
log = Tree(self.log) if self.log_minimizers else self.log
group = Group(log, spinner)

with Live(group, refresh_per_second=10) as live:
for _ in self:
live.refresh()
if self.minimizer_logs is not None:
minimizer_log = self.minimizer_logs[self.counter]
if minimizer_log is not None:
renderable = minimizer_log.__rich__()
renderable.title = (
f"Minimizer log for iteration {self.counter}"
)
log.add(renderable)
spinner.text = f"Running iteration {self.counter + 1}..."
group.renderables.pop(-1)
live.refresh()

else:
for _ in self:
pass
Expand Down
Loading
Loading