Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,11 @@ python -m chebifier predict --help
You can also use the package programmatically:

```python
from chebifier.ensemble.base_ensemble import BaseEnsemble
import yaml
from chebifier import BaseEnsemble

# Load configuration from YAML file
with open('configs/example_config.yml', 'r') as f:
config = yaml.safe_load(f)

# Instantiate ensemble model
ensemble = BaseEnsemble(config)
# Instantiate ensemble model. If desired, can pass
# a path to a configuration, like 'configs/example_config.yml'
ensemble = BaseEnsemble()

# Make predictions
smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C(=O)O"]
Expand Down
9 changes: 7 additions & 2 deletions chebifier/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Note: The top-level package __init__.py runs only once,
# even if multiple subpackages are imported later.

from ._custom_cache import PerSmilesPerModelLRUCache
from ._custom_cache import PerSmilesPerModelLRUCache, modelwise_smiles_lru_cache
from .ensemble.base_ensemble import BaseEnsemble

modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100)
__all__ = [
"BaseEnsemble",
"PerSmilesPerModelLRUCache",
"modelwise_smiles_lru_cache",
]
8 changes: 8 additions & 0 deletions chebifier/_custom_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from functools import wraps
from typing import Any, Callable

__all__ = [
"PerSmilesPerModelLRUCache",
"modelwise_smiles_lru_cache",
]


class PerSmilesPerModelLRUCache:
"""
Expand Down Expand Up @@ -206,3 +211,6 @@ def _load_cache(self) -> None:
self._cache = loaded
except Exception as e:
print(f"[Cache Load Error] {e}")


modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100)
38 changes: 1 addition & 37 deletions chebifier/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import importlib.resources

import click
import yaml

from chebifier.model_registry import ENSEMBLES

Expand Down Expand Up @@ -72,43 +69,10 @@ def predict(
resolve_inconsistencies=True,
):
"""Predict ChEBI classes for SMILES strings using an ensemble model."""
# Load configuration from YAML file
if not ensemble_config:
print("Using default ensemble configuration")
with (
importlib.resources.files("chebifier")
.joinpath("ensemble.yml")
.open("r") as f
):
config = yaml.safe_load(f)
else:
print(f"Loading ensemble configuration from {ensemble_config}")
with open(ensemble_config, "r") as f:
config = yaml.safe_load(f)

with (
importlib.resources.files("chebifier")
.joinpath("model_registry.yml")
.open("r") as f
):
model_registry = yaml.safe_load(f)

new_config = {}
for model_name, entry in config.items():
if "load_model" in entry:
if entry["load_model"] not in model_registry:
raise ValueError(
f"Model {entry['load_model']} not found in model registry. "
f"Available models are: {','.join(model_registry.keys())}."
)
new_config[model_name] = {**model_registry[entry["load_model"]], **entry}
else:
new_config[model_name] = entry
config = new_config

# Instantiate ensemble model
ensemble = ENSEMBLES[ensemble_type](
config,
ensemble_config,
chebi_version=chebi_version,
resolve_inconsistencies=resolve_inconsistencies,
)
Expand Down
34 changes: 31 additions & 3 deletions chebifier/ensemble/base_ensemble.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,60 @@
import importlib
import os
import time
from pathlib import Path
from typing import Union

import torch
import tqdm
import yaml

from chebifier.check_env import check_package_installed
from chebifier.hugging_face import download_model_files
from chebifier.inconsistency_resolution import PredictionSmoother
from chebifier.prediction_models.base_predictor import BasePredictor
from chebifier.utils import get_disjoint_files, load_chebi_graph
from chebifier.utils import (
get_default_configs,
get_disjoint_files,
load_chebi_graph,
process_config,
)


class BaseEnsemble:
def __init__(
self,
model_configs: dict,
model_configs: Union[str, Path, dict, None] = None,
chebi_version: int = 241,
resolve_inconsistencies: bool = True,
):
# Deferred Import: To avoid circular import error
from chebifier.model_registry import MODEL_TYPES

# Load configuration from YAML file
if not model_configs:
config = get_default_configs()
elif isinstance(model_configs, dict):
config = model_configs
else:
print(f"Loading ensemble configuration from {model_configs}")
with open(model_configs, "r") as f:
config = yaml.safe_load(f)

with (
importlib.resources.files("chebifier")
.joinpath("model_registry.yml")
.open("r") as f
):
model_registry = yaml.safe_load(f)

processed_configs = process_config(config, model_registry)

self.chebi_graph = load_chebi_graph()
self.disjoint_files = get_disjoint_files()

self.models = []
self.positive_prediction_threshold = 0.5
for model_name, model_config in model_configs.items():
for model_name, model_config in processed_configs.items():
model_cls = MODEL_TYPES[model_config["type"]]
if "hugging_face" in model_config:
hugging_face_kwargs = download_model_files(model_config["hugging_face"])
Expand Down
3 changes: 2 additions & 1 deletion chebifier/inconsistency_resolution.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import csv
import os
import torch
from pathlib import Path

import torch


def get_disjoint_groups(disjoint_files):
if disjoint_files is None:
Expand Down
4 changes: 2 additions & 2 deletions chebifier/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
WMVwithPPVNPVEnsemble,
)
from chebifier.prediction_models import (
ChEBILookupPredictor,
ChemlogPeptidesPredictor,
ElectraPredictor,
ResGatedPredictor,
ChEBILookupPredictor,
)
from chebifier.prediction_models.c3p_predictor import C3PPredictor
from chebifier.prediction_models.chemlog_predictor import (
ChemlogXMolecularEntityPredictor,
ChemlogOrganoXCompoundPredictor,
ChemlogXMolecularEntityPredictor,
)

ENSEMBLES = {
Expand Down
2 changes: 1 addition & 1 deletion chebifier/prediction_models/base_predictor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from abc import ABC

from chebifier import modelwise_smiles_lru_cache
from .._custom_cache import modelwise_smiles_lru_cache


class BasePredictor(ABC):
Expand Down
2 changes: 1 addition & 1 deletion chebifier/prediction_models/chemlog_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import tqdm

from .base_predictor import BasePredictor
from .. import modelwise_smiles_lru_cache
from .base_predictor import BasePredictor

AA_DICT = {
"A": "L-alanine",
Expand Down
37 changes: 29 additions & 8 deletions chebifier/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import importlib.resources
import os
import pickle

import fastobo
import networkx as nx
import requests
import fastobo
import yaml

from chebifier.hugging_face import download_model_files
import pickle


def load_chebi_graph(filename=None):
Expand Down Expand Up @@ -123,9 +126,27 @@ def get_disjoint_files():
return disjoint_files


if __name__ == "__main__":
# chebi_graph = build_chebi_graph(chebi_version=241)
# save the graph to a file
# pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb"))
chebi_graph = load_chebi_graph()
print(chebi_graph)
def get_default_configs():
    """Load and return the bundled default ensemble configuration.

    Reads ``ensemble.yml`` shipped inside the ``chebifier`` package via
    :mod:`importlib.resources` and parses it with ``yaml.safe_load``.

    Returns:
        The parsed YAML content (typically a dict of model configs).
    """
    default_config_name = "ensemble.yml"
    print(f"Using default ensemble configuration from {default_config_name}")
    # Resolve the packaged resource first, then open it for reading.
    resource = importlib.resources.files("chebifier").joinpath(default_config_name)
    with resource.open("r") as f:
        return yaml.safe_load(f)


def process_config(config, model_registry):
    """Expand ``load_model`` references in *config* using *model_registry*.

    Each entry that contains a ``load_model`` key is merged with the
    registry defaults for that model; keys given explicitly in the entry
    take precedence over the registry values. Entries without
    ``load_model`` are passed through unchanged.

    Args:
        config: Mapping of model name -> per-model settings dict.
        model_registry: Mapping of registered model name -> default settings.

    Returns:
        A new dict with all ``load_model`` references resolved.

    Raises:
        ValueError: If an entry references a model name that is not
            present in *model_registry*.
    """
    resolved = {}
    for name, settings in config.items():
        if "load_model" not in settings:
            # No registry reference: keep the entry as-is.
            resolved[name] = settings
            continue
        template = settings["load_model"]
        if template not in model_registry:
            raise ValueError(
                f"Model {template} not found in model registry. "
                f"Available models are: {','.join(model_registry.keys())}."
            )
        # Entry-level settings override the registry defaults.
        resolved[name] = {**model_registry[template], **settings}
    return resolved