12 changes: 11 additions & 1 deletion configs/config_all.yaml
@@ -936,7 +936,7 @@ process:
redis_address: 'redis://localhost:6379' # the address of redis server
lowercase: false # whether to convert text to lower case
ignore_non_character: false # whether to ignore non-alphabet characters, including whitespace, digits, and punctuation
- ray_bts_minhash_deduplicator: # the document deduplicator that can run on multi-nodes using minhashLSH algorithm
- ray_bts_minhash_deduplicator: # the document deduplicator that can run on multi-nodes using minhashLSH algorithm
tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece]
window_size: 5 # window size of shingling
num_permutations: 256 # number of permutations in minhash computing
@@ -956,6 +956,16 @@ process:
tmp_file_name: './outputs/ray-dedup-tmp/' # the temporary folder name for deduplication.

# Selector ops
- domain_diversity_selector: # selector to select samples based on the data's domain diversity
api_or_hf_model: 'text-embedding-v3' # API or HuggingFace embedding model name
is_hf_model: false # whether the model is loaded from HuggingFace (otherwise it is called via API)
api_endpoint: '/embeddings' # embedding URL endpoint for the API
response_path: 'data.0.embedding' # path to extract the embedding from the API response
model_params: {} # parameters for initializing the API model
select_ratio: # the ratio of samples to select from each domain cluster
init_k: 3 # the number of clusters k in the k-means algorithm
ebd_dim: 512 # the embedding dimension requested from the API
strategy: 'inter' # selection strategy across domains. One of ['inter', 'intra', 'global']
- frequency_specified_field_selector: # selector to select samples based on the sorted frequency of specified field value
field_key: '' # the target key; multi-level field keys are separated by '.'
top_ratio: # ratio of selected top specified field value
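For orientation, here is a minimal, hypothetical sketch of driving the new operator directly, outside a full data-juicer pipeline. The toy dataset and the embedding model name are placeholders, not part of this PR; the parameter names mirror the constructor in the diff below, and the input is assumed to be a HuggingFace `datasets.Dataset` with a `text` column, which provides the `.select()` call the operator relies on.

```python
# Hypothetical standalone usage of the new selector (illustrative only).
from datasets import Dataset

from data_juicer.ops.selector import DomainDiversitySelector

# Toy corpus with two rough "domains": cooking and law.
ds = Dataset.from_dict(
    {"text": ["a recipe for sourdough bread", "a court ruling on contracts",
              "how to knead dough properly", "a brief on tort liability"]}
)

op = DomainDiversitySelector(
    api_or_hf_model="sentence-transformers/all-MiniLM-L6-v2",  # placeholder local model
    is_hf_model=True,      # embed locally instead of calling an API
    select_ratio=0.5,      # keep half of each domain cluster (at least one sample)
    init_k=2,              # expect two domains in the toy data
    strategy="inter",      # prefer samples most distinct from other domains
)
subset = op.process(ds)
print(len(subset))  # how many samples survive depends on the k-means split
```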
2 changes: 2 additions & 0 deletions data_juicer/ops/selector/__init__.py
@@ -1,10 +1,12 @@
from .domain_diversity_selector import DomainDiversitySelector
from .frequency_specified_field_selector import FrequencySpecifiedFieldSelector
from .random_selector import RandomSelector
from .range_specified_field_selector import RangeSpecifiedFieldSelector
from .tags_specified_field_selector import TagsSpecifiedFieldSelector
from .topk_specified_field_selector import TopkSpecifiedFieldSelector

__all__ = [
"DomainDiversitySelector",
"FrequencySpecifiedFieldSelector",
"RandomSelector",
"RangeSpecifiedFieldSelector",
187 changes: 187 additions & 0 deletions data_juicer/ops/selector/domain_diversity_selector.py
@@ -0,0 +1,187 @@
from typing import Dict, Optional

import numpy as np
from pydantic import Field, PositiveInt
from sklearn.cluster import KMeans
from tqdm import tqdm
from typing_extensions import Annotated

from data_juicer.ops.base_op import OPERATORS, Selector
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

torch = LazyLoader("torch")


@OPERATORS.register_module("domain_diversity_selector")
class DomainDiversitySelector(Selector):
"""Selector to select samples based on the data's domain diversity."""

_accelerator = "cuda"

def __init__(
self,
api_or_hf_model: str = "text-embedding-v3",
is_hf_model: bool = False,
api_endpoint: str = "/embeddings",
response_path: str = "data.0.embedding",
model_params: Dict = {},
select_ratio: Optional[Annotated[float, Field(ge=0, le=1)]] = None,
init_k: PositiveInt = 3,
ebd_dim: PositiveInt = 512,
strategy: str = "inter",
*args,
**kwargs,
):
"""
Initialization method.

:param api_or_hf_model: API or huggingface embedding model name.
:param is_hf_model: Indicates if the model is from HuggingFace.
:param api_endpoint: Embedding URL endpoint for the API.
:param response_path: Path to extract content from the API response.
Defaults to 'data.0.embedding' for embedding model.
:param model_params: Parameters for initializing the API model.
        :param select_ratio: The ratio of samples to select from each
            cluster; if None, the dataset is returned unchanged.
        :param init_k: The number of clusters k in the k-means algorithm.
        :param ebd_dim: The embedding dimension requested from the API.
:param strategy: 'inter' - Domain's inter diversity,
'intra' - Domain's intra diversity,
'global' - Diversity to global centroid.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.api_or_hf_model = api_or_hf_model
self.is_hf_model = is_hf_model
self.api_endpoint = api_endpoint
self.response_path = response_path
self.select_ratio = select_ratio
self.init_k = init_k
self.ebd_dim = ebd_dim
self.strategy = strategy

if is_hf_model:
self.model_key = prepare_model(
model_type="embedding", model_path=api_or_hf_model, trust_remote_code=True, **model_params
)
else:
self.model_key = prepare_model(
model_type="api",
model=api_or_hf_model,
endpoint=self.api_endpoint,
response_path=self.response_path,
**model_params,
)

def dataset_embedding(self, dataset, rank=None):
embeddings = []
model = get_model(self.model_key, rank, self.use_cuda())

if self.is_hf_model:
            # Extract embeddings via the local model
for sample in tqdm(dataset, desc="Embedding", unit="sample"):
text = sample["text"]
with torch.no_grad():
embedding = model.encode(text)
embeddings.append(embedding)
else:
            # Extract embeddings via the API
for sample in tqdm(dataset, desc="Embedding", unit="sample"):
text = sample["text"]
embedding = model(text, dimensions=self.ebd_dim, encoding_format="float")
embeddings.append(embedding)

embeddings = np.array(embeddings)
return embeddings

    def domain_diversity_status(self, dataset):
        data_status = []

embeddings_array = self.dataset_embedding(dataset)
global_centroid = np.mean(embeddings_array, axis=0)

        # Cluster embeddings with k-means (cap k at the dataset size)
        kmeans = KMeans(n_clusters=min(self.init_k, len(dataset)), random_state=42)
        labels = kmeans.fit_predict(embeddings_array)

centroid_embeddings = []
for label in np.unique(labels):
label_embeddings = embeddings_array[labels == label]
centroid = np.mean(label_embeddings, axis=0)
centroid_embeddings.append(centroid)

centroid_embeddings = np.array(centroid_embeddings)

        # Per-sample cosine similarity to the other centroids
        for i, entry in tqdm(enumerate(dataset), total=len(dataset), desc="Calculating similarity"):
current_embedding = embeddings_array[i]
current_label = int(labels[i])

similarities = []
for j, centroid in enumerate(centroid_embeddings):
if j != current_label:
similarity = torch.nn.functional.cosine_similarity(
torch.tensor(current_embedding).unsqueeze(0), torch.tensor(centroid).unsqueeze(0)
).item()
similarities.append(similarity)

own_centroid_similarity = torch.nn.functional.cosine_similarity(
torch.tensor(current_embedding).unsqueeze(0),
torch.tensor(centroid_embeddings[current_label]).unsqueeze(0),
).item()

global_centroid_similarity = torch.nn.functional.cosine_similarity(
torch.tensor(current_embedding).unsqueeze(0), torch.tensor(global_centroid).unsqueeze(0)
).item()
total_similarity = sum(similarities)

data_status.append(
{
"text": entry["text"],
"label": current_label,
"similarity_with_other_centroids": similarities,
"total_similarity": total_similarity,
"similarity_with_own_centroid": own_centroid_similarity,
"global_centroid_similarity": global_centroid_similarity,
"original_index": i,
}
)

return data_status, labels

def diversity_process(self, dataset):
data_status, labels = self.domain_diversity_status(dataset)
select_indices = []

for label in np.unique(labels):
label_data_status = [item for item in data_status if item["label"] == label]

            # Rank samples within this cluster according to the strategy:
            #   'inter'  - ascending total similarity to the other centroids
            #              (favor samples most distinct from other domains)
            #   'intra'  - descending similarity to the own centroid
            #              (favor the most representative samples)
            #   'global' - ascending similarity to the global centroid
            #              (favor samples far from the overall mean)
            if self.strategy == "inter":
                label_data_status.sort(key=lambda x: x["total_similarity"])
            elif self.strategy == "intra":
                label_data_status.sort(key=lambda x: x["similarity_with_own_centroid"], reverse=True)
            elif self.strategy == "global":
                label_data_status.sort(key=lambda x: x["global_centroid_similarity"])
            else:
                raise ValueError("Invalid strategy. Use 'inter', 'intra' or 'global'.")

num_to_select = max(1, int(self.select_ratio * len(label_data_status)))
selected_indices = [item["original_index"] for item in label_data_status[:num_to_select]]
select_indices.extend(selected_indices)

select_dataset = dataset.select(select_indices)

return select_dataset

    def process(self, dataset):
        # Trivial datasets or a missing ratio: return the input unchanged.
        if len(dataset) <= 1 or self.select_ratio is None:
            return dataset

select_dataset = self.diversity_process(dataset)
return select_dataset
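
The per-sample torch cosine similarities in `domain_diversity_status` can also be written as one vectorized computation. Below is an illustrative sketch (not part of this PR) using `sklearn.metrics.pairwise.cosine_similarity`, which is already available via the scikit-learn dependency; it assumes the same `embeddings_array`, `centroid_embeddings`, `global_centroid`, and `labels` produced above.

```python
# Vectorized equivalent of the per-sample similarity loop (illustrative only).
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity


def similarity_summary(embeddings_array, centroid_embeddings, global_centroid, labels):
    # (n_samples, n_clusters): similarity of every sample to every centroid.
    sims = cosine_similarity(embeddings_array, centroid_embeddings)
    # Similarity of each sample to its own cluster centroid.
    own = sims[np.arange(len(labels)), labels]
    # Sum of similarities to all *other* centroids ('total_similarity' in the op).
    total_other = sims.sum(axis=1) - own
    # Similarity of each sample to the global centroid.
    to_global = cosine_similarity(embeddings_array, global_centroid.reshape(1, -1)).ravel()
    return total_other, own, to_global
```

This reproduces `total_similarity`, `similarity_with_own_centroid`, and `global_centroid_similarity` without the per-sample tensor round-trips, which may matter for larger datasets.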