Skip to content

Commit ecad794

Browse files
authored
Refactor/namings (#76)
* `module_type` -> `module_name` * `from_datasets` -> `from_hub` * stage progress on `prediction` -> `decision` * stage progress on `Predictor` -> `Decision` * finish renaming to decision * stage progress on `retrieval` -> `embedding`
1 parent 3fcf43c commit ecad794

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+444
-456
lines changed
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
# TODO: make up a better and more versatile config
2-
- node_type: retrieval
2+
- node_type: embedding
33
metric: retrieval_hit_rate
44
search_space:
5-
- module_type: vector_db
5+
- module_name: retrieval
66
k: [10]
77
embedder_name:
88
- avsolatorio/GIST-small-Embedding-v0
99
- infgrad/stella-base-en-v2
1010
- node_type: scoring
1111
metric: scoring_roc_auc
1212
search_space:
13-
- module_type: knn
13+
- module_name: knn
1414
k: [1, 3, 5, 10]
1515
weights: ["uniform", "distance", "closest"]
16-
- module_type: linear
17-
- module_type: dnnc
16+
- module_name: linear
17+
- module_name: dnnc
1818
cross_encoder_name:
1919
- BAAI/bge-reranker-base
2020
- cross-encoder/ms-marco-MiniLM-L-6-v2
2121
k: [1, 3, 5, 10]
22-
- node_type: prediction
23-
metric: prediction_accuracy
22+
- node_type: decision
23+
metric: decision_accuracy
2424
search_space:
25-
- module_type: threshold
25+
- module_name: threshold
2626
thresh: [0.5]
27-
- module_type: argmax
27+
- module_name: argmax
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
# TODO: make up a better and more versatile config
2-
- node_type: retrieval
2+
- node_type: embedding
33
metric: retrieval_hit_rate_intersecting
44
search_space:
5-
- module_type: vector_db
5+
- module_name: retrieval
66
k: [10]
77
embedder_name:
88
- deepvk/USER-bge-m3
99
- node_type: scoring
1010
metric: scoring_roc_auc
1111
search_space:
12-
- module_type: knn
12+
- module_name: knn
1313
k: [3]
1414
weights: ["uniform", "distance", "closest"]
15-
- module_type: linear
16-
- node_type: prediction
17-
metric: prediction_accuracy
15+
- module_name: linear
16+
- node_type: decision
17+
metric: decision_accuracy
1818
search_space:
19-
- module_type: threshold
19+
- module_name: threshold
2020
thresh: [0.5]
21-
- module_type: adaptive
21+
- module_name: adaptive
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
- node_type: retrieval
2-
module_type: vector_db
1+
- node_type: embedding
2+
module_name: retrieval
33
module_config:
44
k: 10
55
model_name: infgrad/stella-base-en-v2
66
load_path: .
77
- node_type: scoring
8-
module_type: knn
8+
module_name: knn
99
module_config:
1010
k: 10
1111
weights: uniform
1212
load_path: .
13-
- node_type: prediction
14-
module_type: threshold
13+
- node_type: decision
14+
module_name: threshold
1515
module_config:
1616
thresh: 0.5
1717
load_path: .

autointent/_dataset/_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def from_dict(cls, mapping: dict[str, Any]) -> "Dataset":
9898
return DictReader().read(mapping)
9999

100100
@classmethod
101-
def from_datasets(cls, repo_id: str) -> "Dataset":
101+
def from_hub(cls, repo_id: str) -> "Dataset":
102102
"""
103103
Load a dataset from a Hugging Face repository.
104104

autointent/_pipeline/_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
179179
raise RuntimeError(msg)
180180

181181
scores = self.nodes[NodeType.scoring].module.predict(utterances) # type: ignore[union-attr]
182-
return self.nodes[NodeType.prediction].module.predict(scores) # type: ignore[union-attr]
182+
return self.nodes[NodeType.decision].module.predict(scores) # type: ignore[union-attr]
183183

184184
def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput:
185185
"""
@@ -193,7 +193,7 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu
193193
raise RuntimeError(msg)
194194

195195
scores, scores_metadata = self.nodes[NodeType.scoring].module.predict_with_metadata(utterances) # type: ignore[union-attr]
196-
predictions = self.nodes[NodeType.prediction].module.predict(scores) # type: ignore[union-attr]
196+
predictions = self.nodes[NodeType.decision].module.predict(scores) # type: ignore[union-attr]
197197
regexp_predictions, regexp_predictions_metadata = None, None
198198
if NodeType.regexp in self.nodes:
199199
regexp_predictions, regexp_predictions_metadata = self.nodes[NodeType.regexp].module.predict_with_metadata( # type: ignore[union-attr]

autointent/configs/_inference_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class InferenceNodeConfig:
1212

1313
node_type: NodeType
1414
"""Type of the node. Should be one of the NODE_TYPES"""
15-
module_type: str
15+
module_name: str
1616
"""Type of the module. Should be one of the Module"""
1717
module_config: dict[str, Any]
1818
"""Configuration of the module"""

autointent/context/_context.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ._utils import NumpyEncoder, load_data
1919
from .data_handler import DataHandler
2020
from .optimization_info import OptimizationInfo
21-
from .vector_index_client import VectorIndex, VectorIndexClient
21+
from .vector_index_client import VectorIndexClient
2222

2323

2424
class Context:
@@ -96,15 +96,6 @@ def set_dataset(self, dataset: Dataset, force_multilabel: bool = False) -> None:
9696
random_seed=self.seed,
9797
)
9898

99-
def get_best_index(self) -> VectorIndex:
100-
"""
101-
Retrieve the best vector index based on optimization results.
102-
103-
:return: Best vector index object.
104-
"""
105-
model_name = self.optimization_info.get_best_embedder()
106-
return self.vector_index_client.get_index(model_name)
107-
10899
def get_inference_config(self) -> dict[str, Any]:
109100
"""
110101
Generate configuration settings for inference.
@@ -237,5 +228,5 @@ def has_saved_modules(self) -> bool:
237228
238229
:return: True if there are saved modules, False otherwise.
239230
"""
240-
node_types = ["regexp", "retrieval", "scoring", "prediction"]
231+
node_types = ["regexp", "embedding", "scoring", "decision"]
241232
return any(len(self.optimization_info.modules.get(nt)) > 0 for nt in node_types)

autointent/context/_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ def load_data(filepath: str | Path) -> Dataset:
5454
:return: A `Dataset` object containing the loaded data.
5555
"""
5656
if filepath == "default-multiclass":
57-
return Dataset.from_datasets("AutoIntent/clinc150_subset")
57+
return Dataset.from_hub("AutoIntent/clinc150_subset")
5858
if filepath == "default-multilabel":
59-
return Dataset.from_datasets("AutoIntent/clinc150_subset").to_multilabel().encode_labels()
59+
return Dataset.from_hub("AutoIntent/clinc150_subset").to_multilabel().encode_labels()
6060
if not Path(filepath).exists():
61-
return Dataset.from_datasets(str(filepath))
61+
return Dataset.from_hub(str(filepath))
6262
return Dataset.from_json(filepath)
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._data_models import Artifact, PredictorArtifact, RetrieverArtifact, ScorerArtifact
1+
from ._data_models import Artifact, DecisionArtifact, RetrieverArtifact, ScorerArtifact
22
from ._optimization_info import OptimizationInfo
33

4-
__all__ = ["Artifact", "OptimizationInfo", "PredictorArtifact", "RetrieverArtifact", "ScorerArtifact"]
4+
__all__ = ["Artifact", "DecisionArtifact", "OptimizationInfo", "RetrieverArtifact", "ScorerArtifact"]

autointent/context/optimization_info/_data_models.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ class RegexpArtifact(Artifact):
2323

2424
class RetrieverArtifact(Artifact):
2525
"""
26-
Artifact containing details from the retrieval node.
26+
Artifact containing details from the embedding node.
2727
28-
Name of the embedding model chosen after retrieval optimization.
28+
Name of the embedding model chosen after embedding optimization.
2929
"""
3030

3131
embedder_name: str
@@ -48,7 +48,7 @@ class ScorerArtifact(Artifact):
4848
)
4949

5050

51-
class PredictorArtifact(Artifact):
51+
class DecisionArtifact(Artifact):
5252
"""
5353
Artifact containing outputs from the predictor node.
5454
@@ -68,9 +68,9 @@ def validate_node_name(value: str) -> str:
6868
:return: Validated node type string.
6969
:raises ValueError: If the node type is invalid.
7070
"""
71-
if value in [NodeType.retrieval, NodeType.scoring, NodeType.prediction, NodeType.regexp]:
71+
if value in [NodeType.embedding, NodeType.scoring, NodeType.decision, NodeType.regexp]:
7272
return value
73-
msg = f"Unknown node_type: {value}. Expected one of ['regexp', 'retrieval', 'scoring', 'prediction']"
73+
msg = f"Unknown node_type: {value}. Expected one of ['regexp', 'embedding', 'scoring', 'decision']"
7474
raise ValueError(msg)
7575

7676

@@ -84,9 +84,9 @@ class Artifacts(BaseModel):
8484
model_config = ConfigDict(arbitrary_types_allowed=True)
8585

8686
regexp: list[RegexpArtifact] = []
87-
retrieval: list[RetrieverArtifact] = []
87+
embedding: list[RetrieverArtifact] = []
8888
scoring: list[ScorerArtifact] = []
89-
prediction: list[PredictorArtifact] = []
89+
decision: list[DecisionArtifact] = []
9090

9191
def add_artifact(self, node_type: str, artifact: Artifact) -> None:
9292
"""
@@ -120,7 +120,7 @@ def get_best_artifact(self, node_type: str, idx: int) -> Artifact:
120120
class Trial(BaseModel):
121121
"""Representation of an individual optimization trial."""
122122

123-
module_type: str
123+
module_name: str
124124
"""Type of the module being optimized."""
125125
module_params: dict[str, Any]
126126
"""Parameters of the module for the trial."""
@@ -136,9 +136,9 @@ class Trials(BaseModel):
136136
"""Container for managing optimization trials for pipeline nodes."""
137137

138138
regexp: list[Trial] = []
139-
retrieval: list[Trial] = []
139+
embedding: list[Trial] = []
140140
scoring: list[Trial] = []
141-
prediction: list[Trial] = []
141+
decision: list[Trial] = []
142142

143143
def get_trial(self, node_type: str, idx: int) -> Trial:
144144
"""
@@ -174,12 +174,12 @@ class TrialsIds(BaseModel):
174174

175175
regexp: int | None = None
176176
"""Best trial index for the regexp node."""
177-
retrieval: int | None = None
178-
"""Best trial index for the retrieval node."""
177+
embedding: int | None = None
178+
"""Best trial index for the embedding node."""
179179
scoring: int | None = None
180180
"""Best trial index for the scoring"""
181-
prediction: int | None = None
182-
"""Best trial index for the prediction node."""
181+
decision: int | None = None
182+
"""Best trial index for the decision node."""
183183

184184
def get_best_trial_idx(self, node_type: str) -> int | None:
185185
"""

0 commit comments

Comments
 (0)