35 changes: 33 additions & 2 deletions README.md
@@ -123,9 +123,16 @@ mattergen-generate $RESULTS_PATH --pretrained-name=$MODEL_NAME --batch_size=16 -

Once you have generated a list of structures contained in `$RESULTS_PATH` (either using MatterGen or another method), you can relax the structures using the default MatterSim machine learning force field (see [repository](https://github.com/microsoft/mattersim)) and compute novelty, uniqueness, stability (using energy estimated by MatterSim), and other metrics via the following command:
```bash
git lfs pull -I data-release/alex-mp/reference_MP2020correction.gz --exclude="" # first download the reference dataset from Git LFS
git lfs pull -I data-release/alex-mp/reference_MP2020correction.gz --exclude="" # first download the MP2020 reference dataset from Git LFS
mattergen-evaluate --structures_path=$RESULTS_PATH --relax=True --structure_matcher='disordered' --save_as="$RESULTS_PATH/metrics.json"
```

If you want to use the reference dataset while applying the TRI2024 correction scheme (recommended), instead run the following:
```bash
git lfs pull -I data-release/alex-mp/reference_TRI2024correction.gz --exclude="" # first download the TRI2024 reference dataset from Git LFS
mattergen-evaluate --structures_path=$RESULTS_PATH --relax=True --structure_matcher='disordered' --save_as="$RESULTS_PATH/metrics.json" --energy_correction_scheme="TRI2024" --reference_dataset_path="data-release/alex-mp/reference_TRI2024correction.gz"
```

This script will write `metrics.json` containing the metric results to `$RESULTS_PATH` and will print it to your console.
> [!IMPORTANT]
> The evaluation script in this repository uses [MatterSim](https://github.com/microsoft/mattersim), a machine-learning force field (MLFF), to relax structures and to assess their stability via MatterSim's predicted energies. This is orders of magnitude faster than evaluation via density functional theory (DFT), requires no license, and is typically highly accurate, but there are important caveats: (1) in the MatterGen publication we use DFT to evaluate structures generated by all models and baselines; (2) DFT is more accurate and reliable, particularly in less common chemical systems. Results obtained with this evaluation code may therefore differ from DFT results, and we recommend confirming MLFF-based results with DFT before drawing conclusions.
@@ -146,6 +153,13 @@ This script will try to read structures from disk in the following precedence or

Here, we expect `energies.npy` to be a numpy array with the entries being `float` energies in the same order as the structures read from `$RESULTS_PATH`.

> [!IMPORTANT]
> For any task beyond benchmarking against existing literature, we recommend using the TRI2024 correction scheme and reference dataset. To do so, run:
```bash
git lfs pull -I data-release/alex-mp/reference_TRI2024correction.gz --exclude="" # first download the reference dataset from Git LFS
mattergen-evaluate --structures_path=$RESULTS_PATH --energies_path='energies.npy' --relax=False --structure_matcher='disordered' --save_as='metrics' --energy_correction_scheme="TRI2024" --reference_dataset_path="data-release/alex-mp/reference_TRI2024correction.gz"
```

If you want to save the relaxed structures, together with their energies, forces, and stresses, add `--structures_output_path=YOUR_PATH` to the script call, like so:
```bash
mattergen-evaluate --structures_path=$RESULTS_PATH --relax=True --structure_matcher='disordered' --save_as='metrics' --structures_output_path="relaxed_structures.extxyz"
@@ -190,7 +204,8 @@ LMDBGZSerializer().serialize(reference_dataset, "path_to_file.gz")
where `entries` is a list of `pymatgen.entries.computed_entries.ComputedStructureEntry` objects containing structure-energy pairs for each structure.

By default, we apply the MaterialsProject2020Compatibility energy correction scheme to all input structures during evaluation, and assume that the reference dataset
has already been pre-processed using the same compatibility scheme. Therefore, unless you have already done this, you should obtain the `entries` object for
has already been pre-processed using the same compatibility scheme.
Therefore, unless you have already done this, you should obtain the `entries` object for
your custom reference dataset in the following way:

``` python
@@ -205,6 +220,22 @@ for structure, energy in zip(structures, energies)
))
```

> [!NOTE]
> Because of known issues with the MaterialsProject2020Compatibility scheme, we recommend using the TRI2024 reference dataset together with the `TRI110Compatibility2024` correction scheme when evaluating the stability of materials outside of benchmarks.
To do so, run:
``` python
from mattergen.evaluation.utils.vasprunlike import VasprunLike
from mattergen.evaluation.reference.correction_schemes import TRI110Compatibility2024

entries = []
for structure, energy in zip(structures, energies):
    vasprun_like = VasprunLike(structure=structure, energy=energy)
    entries.append(vasprun_like.get_computed_entry(
        inc_structure=True, energy_correction_scheme=TRI110Compatibility2024()
    ))
```


## Train MatterGen yourself
Before we can train MatterGen from scratch, we have to unpack and preprocess the dataset files.

1 change: 1 addition & 0 deletions data-release/alex-mp/.gitattributes
@@ -1,2 +1,3 @@
alex_mp_20.zip filter=lfs diff=lfs merge=lfs -text
reference_MP2020correction.gz filter=lfs diff=lfs merge=lfs -text
reference_TRI2024correction.gz filter=lfs diff=lfs merge=lfs -text
3 changes: 3 additions & 0 deletions data-release/alex-mp/reference_TRI2024correction.gz
Git LFS file not shown
4 changes: 4 additions & 0 deletions mattergen/evaluation/evaluate.py
@@ -2,6 +2,7 @@
# Licensed under the MIT License.

from pymatgen.core.structure import Structure
from pymatgen.entries.compatibility import Compatibility, MaterialsProject2020Compatibility

from mattergen.common.utils.globals import get_device
from mattergen.evaluation.metrics.evaluator import MetricsEvaluator
@@ -26,6 +27,7 @@ def evaluate(
potential_load_path: str | None = None,
device: str = str(get_device()),
structures_output_path: str | None = None,
energy_correction_scheme: Compatibility = MaterialsProject2020Compatibility(),
) -> dict[str, float | int]:
"""Evaluate the structures against a reference dataset.

@@ -39,6 +41,7 @@ def evaluate(
potential_load_path: Path to the Machine Learning potential to use for relaxation.
device: Device to use for relaxation.
structures_output_path: Path to save the relaxed structures.
energy_correction_scheme: Energy correction scheme to use for computing energy-based metrics. Must be compatible with the reference dataset used (e.g., MP2020correction reference dataset requires MP2020 energy correction scheme).

Returns:
metrics: a dictionary of metrics and their values.
@@ -57,6 +60,7 @@ def evaluate(
original_structures=structures,
reference=reference,
structure_matcher=structure_matcher,
energy_correction_scheme=energy_correction_scheme
)
return evaluator.compute_metrics(
metrics=evaluator.available_metrics,
15 changes: 15 additions & 0 deletions mattergen/evaluation/metrics/energy.py
@@ -9,10 +9,12 @@
import numpy.typing
from pandas import DataFrame
from pymatgen.analysis.phase_diagram import PhaseDiagram
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from tqdm import tqdm

from mattergen.evaluation.metrics.core import BaseAggregateMetric, BaseMetric, BaseMetricsCapability
from mattergen.evaluation.metrics.structure import StructureMetricsCapability
from mattergen.evaluation.reference.correction_schemes import TRI110Compatibility2024
from mattergen.evaluation.reference.reference_dataset import ReferenceDataset
from mattergen.evaluation.utils.globals import DEFAULT_STABILITY_THRESHOLD
from mattergen.evaluation.utils.logging import logger
@@ -116,6 +118,7 @@ def __init__(
self.warn_missing_data(missing_terminals)
raise MissingTerminalsError(self.missing_terminals_error_str)
super().__init__(structure_summaries=structure_summaries, n_failed_jobs=n_failed_jobs)
check_energy_correction_scheme_compatibility(reference_dataset, structure_summaries)
self.reference_dataset = reference_dataset
self.stability_threshold = stability_threshold

@@ -356,3 +359,15 @@ def compute_pre_aggregation_values(self) -> numpy.typing.NDArray:
& self.structure_capability.is_unique
& self.energy_capability.is_stable
)


def check_energy_correction_scheme_compatibility(
    reference_dataset: ReferenceDataset, structure_summaries: list[MetricsStructureSummary]
) -> None:
    # Names of all energy adjustments applied to the input structures.
    energy_correction_schemes = {
        e.name for summary in structure_summaries for e in summary.entry.energy_adjustments
    }
    if reference_dataset.name == "MP2020correction":
        assert all("MP2020" in s for s in energy_correction_schemes), "Input structures contain energy corrections that are not compatible with the MP2020correction reference dataset."
    elif reference_dataset.name == "TRI2024correction":
        assert all("TRI" in s for s in energy_correction_schemes), "Input structures contain energy corrections that are not compatible with the TRI2024correction reference dataset."
    else:
        logger.warning("Using a custom reference dataset. Make sure that the energy corrections applied to the input structures are compatible with those of the reference dataset.")
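
The name-based check above can be illustrated without pymatgen or MatterGen. The following is a minimal sketch, not the PR's implementation: the helper `check_scheme_names` and the table `EXPECTED_MARKER` are hypothetical names introduced here, and the sketch raises `ValueError` instead of using `assert`. Only the dataset names and the `MP2020` / `TRI` substrings are taken from the function above.

```python
# Substring expected in every adjustment name, keyed by reference dataset name.
# Dataset names and substrings mirror the function above; everything else is illustrative.
EXPECTED_MARKER = {
    "MP2020correction": "MP2020",
    "TRI2024correction": "TRI",
}


def check_scheme_names(reference_name: str, adjustment_names: set[str]) -> bool:
    """Hypothetical stand-in for the compatibility check.

    Returns True if the check ran and passed, False if it was skipped
    (custom reference dataset), and raises ValueError on a mismatch.
    """
    marker = EXPECTED_MARKER.get(reference_name)
    if marker is None:
        # Custom dataset: nothing to verify automatically.
        return False

    incompatible = sorted(name for name in adjustment_names if marker not in name)
    if incompatible:
        raise ValueError(
            f"Adjustments {incompatible} are not compatible with "
            f"reference dataset {reference_name!r}."
        )
    return True


# TRI adjustment names as emitted by TRI110Compatibility2024 in this PR.
print(check_scheme_names("TRI2024correction", {"TRI110PBE", "TRI110PBE_U"}))
```

One possible reason to prefer an explicit exception over `assert` is that assertions are skipped when Python runs with `-O`, which would silently disable the check.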
16 changes: 12 additions & 4 deletions mattergen/evaluation/metrics/evaluator.py
@@ -24,7 +24,11 @@
from mattergen.evaluation.metrics.energy import EnergyMetricsCapability, MissingTerminalsError
from mattergen.evaluation.metrics.property import PropertyMetricsCapability
from mattergen.evaluation.metrics.structure import StructureMetricsCapability
from mattergen.evaluation.reference.presets import ReferenceMP2020Correction
from mattergen.evaluation.reference.correction_schemes import TRI110Compatibility2024
from mattergen.evaluation.reference.presets import (
ReferenceMP2020Correction,
ReferenceTRI2024Correction,
)
from mattergen.evaluation.reference.reference_dataset import ReferenceDataset
from mattergen.evaluation.utils.globals import DEFAULT_STABILITY_THRESHOLD
from mattergen.evaluation.utils.logging import logger
@@ -103,8 +107,12 @@ def from_structures_and_energies(
) -> Self:

if reference is None:
print("No reference dataset provided. Using MP2020 correction as reference.")
reference = ReferenceMP2020Correction()
if isinstance(energy_correction_scheme, TRI110Compatibility2024):
print("No reference dataset provided, but TRI correction scheme detected. Using TRI2024 corrected dataset as reference.")
reference = ReferenceTRI2024Correction()
else:
print("No reference dataset provided. Using MP2020 corrected dataset as reference.")
reference = ReferenceMP2020Correction()

structure_summaries = get_metrics_structure_summaries(
structures=structures,
@@ -136,7 +144,7 @@ def from_structure_summaries(
) -> Self:

if reference is None:
print("No reference dataset provided. Using MP2020 correction as reference.")
print("No reference dataset provided. Using MP2020 corrected dataset as reference.")
reference = ReferenceMP2020Correction()

capabilities: list[BaseMetricsCapability] = []
103 changes: 103 additions & 0 deletions mattergen/evaluation/reference/correction_schemes.py
@@ -0,0 +1,103 @@
from pymatgen.core import Element
from pymatgen.entries.compatibility import Compatibility, CompatibilityError
from pymatgen.entries.computed_entries import (
ComputedEntry,
ComputedStructureEntry,
EnergyAdjustment,
)


class IdentityCorrectionScheme(Compatibility):
"""Perform no energy correction."""

def get_adjustments(
self, entry: ComputedEntry | ComputedStructureEntry
) -> list[EnergyAdjustment]:
return []


class TRI110Compatibility2024(Compatibility):
"""This is an implementation of the correction scheme defined in

A Simple Linear Relation Solves Unphysical DFT Energy Corrections
B. A. Rohr, S. K. Suram, J. S. Bakander, ChemRxiv, 10.26434/chemrxiv-2024-q5058, (2024)

https://chemrxiv.org/engage/chemrxiv/article-details/67252d617be152b1d0b2c1ef
"""

# Compatibility.name needed for compatibility with CorrectedEntriesBuilder.process_item.
name: str = "TRI110Compatibility2024"

# See Section 2.1 of
# https://chemrxiv.org/engage/api-gateway/chemrxiv/assets/orp/resource/item/672533a35a82cea2fac0b474/original/supplemental-information-a-simple-linear-relation-solves-unphysical-dft-energy-corrections.pdf
PBE_CORRECTION: float = 1.108

# See Table 1 of
# https://chemrxiv.org/engage/api-gateway/chemrxiv/assets/orp/resource/item/672533a35a82cea2fac0b474/original/supplemental-information-a-simple-linear-relation-solves-unphysical-dft-energy-corrections.pdf
U_CORRECTION = {
Element("Co"): -2.275,
Element("Cr"): -2.707,
Element("Fe"): -3.189,
Element("Mn"): -2.28,
Element("Mo"): -4.93,
Element("Ni"): -3.361,
Element("V"): -2.774,
Element("W"): -6.261,
}

def get_adjustments(
self, entry: ComputedEntry | ComputedStructureEntry
) -> list[EnergyAdjustment]:
"""Get the energy adjustments for a ComputedEntry.

This method must generate a list of EnergyAdjustment objects
of the appropriate type (constant, composition-based, or temperature-based)
to be applied to the ComputedEntry, and must raise a CompatibilityError
if the entry is not compatible.

Args:
entry: A ComputedEntry object.

Returns:
list[EnergyAdjustment]: A list of EnergyAdjustment to be applied to the
Entry, which are evaluated in ComputedEntry.correction. Note that
the latter implements a linear sum of corrections.

Raises:
CompatibilityError if the entry is not compatible

"""
if entry.parameters.get("run_type") not in ("GGA", "GGA+U"):
raise CompatibilityError(
f"Entry {entry.entry_id} has invalid run type {entry.parameters.get('run_type')}. "
f"Must be GGA or GGA+U. Discarding."
)

adjustments = []

if entry.parameters.get("run_type") in ["GGA", "GGA+U"]:
# multiplicative adjustment for all PBE or PBE+U calculations
# energy adjustments are applied additively in downstream pymatgen code, so
# refactor the multiplicative factor as an addition to the uncorrected energy
adjustments.append(
EnergyAdjustment(value=entry.energy * (self.PBE_CORRECTION - 1.0), name="TRI110PBE")
)

if entry.parameters.get("run_type") == "GGA+U":
u_elements = [el for el in entry.composition if el in self.U_CORRECTION]

# number of atoms of each element
composition_dict: dict[str, float] = entry.composition.as_dict()

# eV
# EnergyAdjustment(value) expects the total energy, so we multiply the
# correction per U atom by the number of atoms of that type and not the
# fractional composition.
u_correction = sum(
[composition_dict[el.name] * self.U_CORRECTION[el] for el in u_elements]
)

# EnergyAdjustment(value) assumes total energy
adjustments.append(EnergyAdjustment(value=u_correction, name="TRI110PBE_U"))

return adjustments
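
To make the arithmetic in `get_adjustments` concrete, here is a minimal, dependency-free sketch that reproduces the two terms for a single entry. The helper name `tri_adjustments`, the Fe2O3 composition, and the energy value are illustrative assumptions and are not part of the PR; only the constants are copied from the class above.

```python
# Constants copied from TRI110Compatibility2024 above.
PBE_CORRECTION = 1.108
U_CORRECTION = {  # eV per atom of the listed element
    "Co": -2.275, "Cr": -2.707, "Fe": -3.189, "Mn": -2.28,
    "Mo": -4.93, "Ni": -3.361, "V": -2.774, "W": -6.261,
}


def tri_adjustments(
    energy: float, composition: dict[str, float], run_type: str
) -> dict[str, float]:
    """Return the additive adjustments (eV, total-energy basis) for one entry.

    Hypothetical helper mirroring the logic of get_adjustments without pymatgen.
    `composition` maps element symbols to atom counts in the cell.
    """
    if run_type not in ("GGA", "GGA+U"):
        raise ValueError(f"Invalid run type {run_type!r}; must be GGA or GGA+U.")

    # Scaling E by 1.108 is the same as adding E * (1.108 - 1) to E.
    adjustments = {"TRI110PBE": energy * (PBE_CORRECTION - 1.0)}

    if run_type == "GGA+U":
        # Per-atom U correction multiplied by the number of atoms of that element.
        adjustments["TRI110PBE_U"] = sum(
            n_atoms * U_CORRECTION[el]
            for el, n_atoms in composition.items()
            if el in U_CORRECTION
        )
    return adjustments


# Illustrative numbers only: an Fe2O3 cell with a made-up uncorrected energy.
adj = tri_adjustments(energy=-33.0, composition={"Fe": 2, "O": 3}, run_type="GGA+U")
corrected_energy = -33.0 + sum(adj.values())
print(adj, corrected_energy)
```

Because the multiplicative term is expressed as an additive adjustment, the corrected energy is simply the uncorrected energy plus the sum of all adjustments, which matches how the docstring says corrections are combined.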
25 changes: 24 additions & 1 deletion mattergen/evaluation/reference/presets.py
@@ -1,8 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from pathlib import Path
from functools import cached_property
from pathlib import Path

from mattergen.evaluation.reference.reference_dataset import ReferenceDataset
from mattergen.evaluation.reference.reference_dataset_serializer import LMDBGZSerializer
@@ -29,3 +29,26 @@ def from_preset(cls) -> "ReferenceMP2020Correction":
def is_ordered(self) -> bool:
"""Returns True if all structures are ordered."""
return True # Setting it manually to avoid computation at runtime.


class ReferenceTRI2024Correction(ReferenceDataset):
"""Reference dataset using the TRI2024 Energy Correction scheme.
This dataset contains entries from the Materials Project [https://next-gen.materialsproject.org/]
and Alexandria [https://alexandria.icams.rub.de/].
All 845,997 structures are relaxed using the GGA-PBE functional and have energy corrections applied using the TRI2024 scheme.
"""

def __init__(self):
super().__init__("TRI2024correction", ReferenceTRI2024Correction.from_preset())

@classmethod
def from_preset(cls) -> "ReferenceTRI2024Correction":
current_dir = Path(__file__).parent
return LMDBGZSerializer().deserialize(
f"{current_dir}/../../../data-release/alex-mp/reference_TRI2024correction.gz"
)

@cached_property
def is_ordered(self) -> bool:
"""Returns True if all structures are ordered."""
return True # Setting it manually to avoid computation at runtime.
2 changes: 0 additions & 2 deletions mattergen/evaluation/utils/vasprunlike.py
@@ -128,5 +128,3 @@ def get_computed_entry(
energy_correction_scheme.process_entry(entry)

return entry
return entry
return entry
12 changes: 11 additions & 1 deletion mattergen/scripts/evaluate.py
@@ -7,10 +7,12 @@

import fire
import numpy as np
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility

from mattergen.common.utils.eval_utils import load_structures
from mattergen.common.utils.globals import get_device
from mattergen.evaluation.evaluate import evaluate
from mattergen.evaluation.reference.correction_schemes import TRI110Compatibility2024
from mattergen.evaluation.reference.reference_dataset_serializer import LMDBGZSerializer
from mattergen.evaluation.utils.structure_matcher import (
DefaultDisorderedStructureMatcher,
@@ -30,6 +32,7 @@ def main(
reference_dataset_path: str | None = None,
device: str = str(get_device()),
structures_output_path: str | None = None,
energy_correction_scheme: Literal["MP2020", "TRI2024"] = "MP2020",
):
structures = load_structures(Path(structures_path))
energies = np.load(energies_path) if energies_path else None
@@ -41,7 +44,13 @@ def main(
reference = None
if reference_dataset_path:
reference = LMDBGZSerializer().deserialize(reference_dataset_path)


match energy_correction_scheme:
case "MP2020":
energy_correction_scheme = MaterialsProject2020Compatibility()
case "TRI2024":
energy_correction_scheme = TRI110Compatibility2024()

metrics = evaluate(
structures=structures,
relax=relax,
@@ -52,6 +61,7 @@ def main(
reference=reference,
device=device,
structures_output_path=structures_output_path,
energy_correction_scheme=energy_correction_scheme,
)
print(json.dumps(metrics, indent=2))
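
The `match` statement in this hunk has no default case, so an unrecognized string would be passed to `evaluate` unchanged. The sketch below shows one possible way to map the CLI string to a scheme object with explicit validation. It is an illustration under stated assumptions: `MP2020Scheme` and `TRI2024Scheme` are hypothetical stand-ins for `MaterialsProject2020Compatibility` and `TRI110Compatibility2024`, and `resolve_scheme` is not a function in the repository.

```python
class MP2020Scheme:
    """Hypothetical stand-in for MaterialsProject2020Compatibility."""


class TRI2024Scheme:
    """Hypothetical stand-in for TRI110Compatibility2024."""


# Maps the CLI value to the class that implements the scheme.
SCHEMES = {
    "MP2020": MP2020Scheme,
    "TRI2024": TRI2024Scheme,
}


def resolve_scheme(name: str):
    """Return a scheme instance for `name`, or raise ValueError if it is unknown."""
    try:
        factory = SCHEMES[name]
    except KeyError:
        raise ValueError(
            f"Unknown energy correction scheme {name!r}; expected one of {sorted(SCHEMES)}."
        ) from None
    return factory()


scheme = resolve_scheme("TRI2024")
print(type(scheme).__name__)
```

A lookup table keeps the accepted values in one place, so the error message and the dispatch cannot drift apart when a new scheme is added.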

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -65,7 +65,7 @@ dependencies = [
"pytest",
"pytorch-lightning==2.0.6",
"seaborn>=0.13.2", # for plotting
"setuptools",
"setuptools<81",
"SMACT",
"sympy>=1.11.1",
"torch==2.2.1+cu118; sys_platform == 'linux'",
@@ -126,5 +126,5 @@ explicit = true


[build-system]
requires = ["setuptools >=77.0.3"]
requires = ["setuptools <81"]
build-backend = "setuptools.build_meta"