diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 73f39e2bd3..4c059a5bca 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -936,7 +936,7 @@ process: redis_address: 'redis://localhost:6379' # the address of redis server lowercase: false # whether to convert text to lower case ignore_non_character: false # whether to ignore non-alphabet characters, including whitespaces, digits, and punctuations - - ray_bts_minhash_deduplicator: # the document deduplicator that can run on multi-nodes using minhashLSH algorithm + - ray_bts_minhash_deduplicator: # the document deduplicator that can run on multi-nodes using minhashLSH algorithm tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece] window_size: 5 # window size of shingling num_permutations: 256 # number of permutations in minhash computing @@ -956,6 +956,16 @@ process: tmp_file_name: './outputs/ray-dedup-tmp/' # the temporary folder name for deduplication. # Selector ops + - domain_diversity_selector: # selector to select samples based on the data's domain diversity + api_or_hf_model: 'text-embedding-v3' # API or huggingface embedding model name + is_hf_model: False # indicates if the model is from HuggingFace + api_endpoint: '/embeddings' # embedding URL endpoint for the API + response_path: 'data.0.embedding' # path to extract content from the API response + model_params: {} # parameters for initializing the embedding model + select_ratio: # the ratio of samples to select + init_k: 3 # the value of k in the k-means algorithm + ebd_dim: 512 # the dimension of embeddings returned by the API + strategy: 'inter' # the selection strategy across domains. One of ['inter', 'intra', 'global'] - frequency_specified_field_selector: # selector to select samples based on the sorted frequency of specified field value field_key: '' # the target keys corresponding to multi-level field information need to be separated by '.'
top_ratio: # ratio of selected top specified field value diff --git a/data_juicer/ops/selector/__init__.py b/data_juicer/ops/selector/__init__.py index adaa1a3419..572eb38057 100644 --- a/data_juicer/ops/selector/__init__.py +++ b/data_juicer/ops/selector/__init__.py @@ -1,3 +1,4 @@ +from .domain_diversity_selector import DomainDiversitySelector from .frequency_specified_field_selector import FrequencySpecifiedFieldSelector from .random_selector import RandomSelector from .range_specified_field_selector import RangeSpecifiedFieldSelector @@ -5,6 +6,7 @@ from .topk_specified_field_selector import TopkSpecifiedFieldSelector __all__ = [ + "DomainDiversitySelector", "FrequencySpecifiedFieldSelector", "RandomSelector", "RangeSpecifiedFieldSelector", diff --git a/data_juicer/ops/selector/domain_diversity_selector.py b/data_juicer/ops/selector/domain_diversity_selector.py new file mode 100644 index 0000000000..e0de693ef8 --- /dev/null +++ b/data_juicer/ops/selector/domain_diversity_selector.py @@ -0,0 +1,187 @@ +from typing import Dict, Optional + +import numpy as np +from pydantic import Field, PositiveInt +from sklearn.cluster import KMeans +from tqdm import tqdm +from typing_extensions import Annotated + +from data_juicer.ops.base_op import OPERATORS, Selector +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import get_model, prepare_model + +torch = LazyLoader("torch") + + +@OPERATORS.register_module("domain_diversity_selector") +class DomainDiversitySelector(Selector): + """Selector to select samples based on the data's domain diversity.""" + + _accelerator = "cuda" + + def __init__( + self, + api_or_hf_model: str = "text-embedding-v3", + is_hf_model: bool = False, + api_endpoint: str = "/embeddings", + response_path: str = "data.0.embedding", + model_params: Dict = {}, + select_ratio: Optional[Annotated[float, Field(ge=0, le=1)]] = None, + init_k: PositiveInt = 3, + ebd_dim: PositiveInt = 512, + strategy: str = "inter", + *args, + **kwargs, + ): + """ + Initialization method. + + :param api_or_hf_model: API or huggingface embedding model name. + :param is_hf_model: Indicates if the model is from HuggingFace. + :param api_endpoint: Embedding URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'data.0.embedding' for embedding models. + :param model_params: Parameters for initializing the embedding model. + :param select_ratio: The ratio of samples to select. + :param init_k: The value of k in the k-means algorithm. + :param ebd_dim: The dimension of embeddings returned by the API. + :param strategy: 'inter' - inter-domain diversity, + 'intra' - intra-domain diversity, + 'global' - diversity relative to the global centroid.
:param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.api_or_hf_model = api_or_hf_model + self.is_hf_model = is_hf_model + self.api_endpoint = api_endpoint + self.response_path = response_path + self.select_ratio = select_ratio + self.init_k = init_k + self.ebd_dim = ebd_dim + self.strategy = strategy + + if is_hf_model: + self.model_key = prepare_model( + model_type="embedding", model_path=api_or_hf_model, trust_remote_code=True, **model_params + ) + else: + self.model_key = prepare_model( + model_type="api", + model=api_or_hf_model, + endpoint=self.api_endpoint, + response_path=self.response_path, + **model_params, + ) + + def dataset_embedding(self, dataset, rank=None): + embeddings = [] + model = get_model(self.model_key, rank, self.use_cuda()) + + if self.is_hf_model: + # Extract embeddings via the local HuggingFace model + for sample in tqdm(dataset, desc="Embedding", unit="sample"): + text = sample["text"] + with torch.no_grad(): + embedding = model.encode(text) + embeddings.append(embedding) + else: + # Extract embeddings via the API + for sample in tqdm(dataset, desc="Embedding", unit="sample"): + text = sample["text"] + embedding = model(text, dimensions=self.ebd_dim, encoding_format="float") + embeddings.append(embedding) + + embeddings = np.array(embeddings) + return embeddings + + def domain_diversity_status(self, dataset): + + data_status = [] + + embeddings_array = self.dataset_embedding(dataset) + global_centroid = np.mean(embeddings_array, axis=0) + + # K-means clustering + kmeans = KMeans(n_clusters=self.init_k, random_state=42) + labels = kmeans.fit_predict(embeddings_array) + + centroid_embeddings = [] + for label in np.unique(labels): + label_embeddings = embeddings_array[labels == label] + centroid = np.mean(label_embeddings, axis=0) + centroid_embeddings.append(centroid) + + centroid_embeddings = np.array(centroid_embeddings) + + # Sample-level cosine similarity to the other centroids + for i, entry in tqdm(enumerate(dataset), total=len(dataset), desc="Calculating similarity"): + current_embedding = embeddings_array[i] + current_label = int(labels[i]) + + similarities = [] + for j, centroid in enumerate(centroid_embeddings): + if j != current_label: + similarity = torch.nn.functional.cosine_similarity( + torch.tensor(current_embedding).unsqueeze(0), torch.tensor(centroid).unsqueeze(0) + ).item() + similarities.append(similarity) + + own_centroid_similarity = torch.nn.functional.cosine_similarity( + torch.tensor(current_embedding).unsqueeze(0), + torch.tensor(centroid_embeddings[current_label]).unsqueeze(0), + ).item() + + global_centroid_similarity = torch.nn.functional.cosine_similarity( + torch.tensor(current_embedding).unsqueeze(0), torch.tensor(global_centroid).unsqueeze(0) + ).item() + total_similarity = sum(similarities) + + data_status.append( + { + "text": entry["text"], + "label": current_label, + "similarity_with_other_centroids": similarities, + "total_similarity": total_similarity, + "similarity_with_own_centroid": own_centroid_similarity, + "global_centroid_similarity": global_centroid_similarity, + "original_index": i, + } + ) + + return data_status, labels + + def diversity_process(self, dataset): + data_status, labels = self.domain_diversity_status(dataset) + select_indices = [] + + for label in np.unique(labels): + label_data_status = [item for item in data_status if item["label"] == label] + + # Rank samples within each cluster according to the selection strategy + if self.strategy == "inter": + label_data_status.sort(key=lambda x: x["total_similarity"]) + elif self.strategy == "intra": + label_data_status.sort(key=lambda x: x["similarity_with_own_centroid"], reverse=True) + elif self.strategy == "global": + label_data_status.sort(key=lambda x: x["global_centroid_similarity"]) + else: + raise ValueError("Invalid strategy. Use 'inter', 'intra' or 'global'.") + + num_to_select = max(1, int(self.select_ratio * len(label_data_status))) + selected_indices = [item["original_index"] for item in label_data_status[:num_to_select]] + select_indices.extend(selected_indices) + + select_dataset = dataset.select(select_indices) + + return select_dataset + + def process(self, dataset): + + if len(dataset) <= 1: + return dataset + if self.select_ratio is None: + return dataset + + select_dataset = self.diversity_process(dataset) + return select_dataset diff --git a/docs/Operators.md b/docs/Operators.md index d1abbbb782..6a6b5b500f 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -38,7 +38,7 @@ Data-Juicer 中的算子分为以下 7 种类型。 | [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 | | [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 | | [mapper](#mapper) | 81 | Edits and transforms samples. 对数据样本进行编辑和转换。 | -| [selector](#selector) | 5 | Selects top samples based on ranking. 基于排序选取高质量样本。 | +| [selector](#selector) | 6 | Selects top samples based on ranking. 基于排序选取高质量样本。 | All the specific operators are listed below, each featured with several capability tags. 下面列出所有具体算子,每种算子都通过多个标签来注明其主要功能。 @@ -97,7 +97,7 @@ All the specific operators are listed below, each featured with several capabili | flagged_words_filter | 🔤Text 💻CPU 🟢Stable | Filter to keep samples with flagged-word ratio less than a specific max value. 过滤以保持标记词比率小于特定最大值的样本。 | [code](../data_juicer/ops/filter/flagged_words_filter.py) | [tests](../tests/ops/filter/test_flagged_words_filter.py) | | general_field_filter | 💻CPU 🟡Beta | Filter to keep samples based on a general field filter condition. 根据常规字段筛选条件保留样本。 | [code](../data_juicer/ops/filter/general_field_filter.py) | [tests](../tests/ops/filter/test_general_field_filter.py) | | image_aesthetics_filter | 🏞Image 💻CPU 🧩HF 🟢Stable | Filter to keep samples with aesthetics scores within a specific range. 过滤以保持美学分数在特定范围内的样品。 | [code](../data_juicer/ops/filter/image_aesthetics_filter.py) | [tests](../tests/ops/filter/test_image_aesthetics_filter.py) | -| image_aspect_ratio_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with image aspect ratio within a specific range. 过滤器,以保持特定范围内的图像长宽比的样本。 | [code](../data_juicer/ops/filter/image_aspect_ratio_filter.py) | [tests](../tests/ops/filter/test_image_aspect_ratio_filter.py) | +| image_aspect_ratio_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with image aspect ratio within a specific range. 过滤器,以保持样本的图像纵横比在特定范围内。 | [code](../data_juicer/ops/filter/image_aspect_ratio_filter.py) | [tests](../tests/ops/filter/test_image_aspect_ratio_filter.py) | | image_face_count_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with the number of faces within a specific range. 过滤以保持样本的面数在特定范围内。 | [code](../data_juicer/ops/filter/image_face_count_filter.py) | [tests](../tests/ops/filter/test_image_face_count_filter.py) | | image_face_ratio_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with face area ratios within a specific range.
过滤以保持面面积比在特定范围内的样本。 | [code](../data_juicer/ops/filter/image_face_ratio_filter.py) | [tests](../tests/ops/filter/test_image_face_ratio_filter.py) | | image_nsfw_filter | 🏞Image 💻CPU 🧩HF 🟢Stable | Filter to keep samples whose images have low nsfw scores. 过滤器保留图像具有低nsfw分数的样本。 | [code](../data_juicer/ops/filter/image_nsfw_filter.py) | [tests](../tests/ops/filter/test_image_nsfw_filter.py) | @@ -148,7 +148,7 @@ All the specific operators are listed below, each featured with several capabili | local_formatter | 🟢Stable | The class is used to load a dataset from local files or local directory. 类用于从本地文件或本地目录加载数据集。 | [code](../data_juicer/format/formatter.py) | [tests](../tests/format/test_unify_format.py) | | parquet_formatter | 🟢Stable | The class is used to load and format parquet-type files. 该类用于加载和格式化镶木地板类型的文件。 | [code](../data_juicer/format/parquet_formatter.py) | [tests](../tests/format/test_parquet_formatter.py) | | remote_formatter | 🟢Stable | The class is used to load a dataset from repository of huggingface hub. 该类用于从huggingface hub的存储库加载数据集。 | [code](../data_juicer/format/formatter.py) | [tests](../tests/format/test_unify_format.py) | -| text_formatter | 🔴Alpha | The class is used to load and format text-type files. 类用于加载和格式化文本类型文件。 | [code](../data_juicer/format/text_formatter.py) | - | +| text_formatter | 🔴Alpha | The class is used to load and format text-type files. 类用于加载和格式化文本类型的文件。 | [code](../data_juicer/format/text_formatter.py) | - | | tsv_formatter | 🟢Stable | The class is used to load and format tsv-type files. 该类用于加载和格式化tsv类型的文件。 | [code](../data_juicer/format/tsv_formatter.py) | [tests](../tests/format/test_tsv_formatter.py) | ## grouper @@ -165,7 +165,7 @@ All the specific operators are listed below, each featured with several capabili |----------|------|-------------|-------------|------------| | audio_add_gaussian_noise_mapper | 📣Audio 💻CPU 🟡Beta | Mapper to add gaussian noise to audio. 映射器向音频添加高斯噪声。 | [code](../data_juicer/ops/mapper/audio_add_gaussian_noise_mapper.py) | [tests](../tests/ops/mapper/test_audio_add_gaussian_noise_mapper.py) | | audio_ffmpeg_wrapped_mapper | 📣Audio 💻CPU 🟢Stable | Simple wrapper for FFmpeg audio filters. FFmpeg音频滤波器的简单包装。 | [code](../data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py) | [tests](../tests/ops/mapper/test_audio_ffmpeg_wrapped_mapper.py) | -| calibrate_qa_mapper | 🔤Text 💻CPU 🔗API 🟢Stable | Mapper to calibrate question-answer pairs based on reference text. 映射器基于参考文本校准问题-答案对。 | [code](../data_juicer/ops/mapper/calibrate_qa_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_qa_mapper.py) | +| calibrate_qa_mapper | 🔤Text 💻CPU 🔗API 🟢Stable | Mapper to calibrate question-answer pairs based on reference text. 映射器基于参考文本校准问答对。 | [code](../data_juicer/ops/mapper/calibrate_qa_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_qa_mapper.py) | | calibrate_query_mapper | 💻CPU 🟢Stable | Mapper to calibrate query in question-answer pairs based on reference text. 映射器基于参考文本校准问答对中的查询。 | [code](../data_juicer/ops/mapper/calibrate_query_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_query_mapper.py) | | calibrate_response_mapper | 💻CPU 🟢Stable | Mapper to calibrate response in question-answer pairs based on reference text. 映射器基于参考文本校准问答对中的响应。 | [code](../data_juicer/ops/mapper/calibrate_response_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_response_mapper.py) | | chinese_convert_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji. 
映射器在繁体中文,简体中文和日语汉字之间转换中文。 | [code](../data_juicer/ops/mapper/chinese_convert_mapper.py) | [tests](../tests/ops/mapper/test_chinese_convert_mapper.py) | @@ -196,7 +196,7 @@ All the specific operators are listed below, each featured with several capabili | image_diffusion_mapper | 🔮Multimodal 💻CPU 🧩HF 🟢Stable | Generate image by diffusion model. 通过扩散模型生成图像。 | [code](../data_juicer/ops/mapper/image_diffusion_mapper.py) | [tests](../tests/ops/mapper/test_image_diffusion_mapper.py) | | image_face_blur_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to blur faces detected in images. 映射器模糊图像中检测到的人脸。 | [code](../data_juicer/ops/mapper/image_face_blur_mapper.py) | [tests](../tests/ops/mapper/test_image_face_blur_mapper.py) | | image_remove_background_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to remove background of images. 映射器删除图像的背景。 | [code](../data_juicer/ops/mapper/image_remove_background_mapper.py) | [tests](../tests/ops/mapper/test_image_remove_background_mapper.py) | -| image_segment_mapper | 🏞Image 💻CPU 🟢Stable | Perform segment-anything on images and return the bounding boxes. 在图像上执行segment-anything并返回边界框。 | [code](../data_juicer/ops/mapper/image_segment_mapper.py) | [tests](../tests/ops/mapper/test_image_segment_mapper.py) | +| image_segment_mapper | 🏞Image 💻CPU 🟢Stable | Perform segment-anything on images and return the bounding boxes. 对图像执行segment-anything并返回边界框。 | [code](../data_juicer/ops/mapper/image_segment_mapper.py) | [tests](../tests/ops/mapper/test_image_segment_mapper.py) | | image_tagging_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to generate image tags. 映射器生成图像标签。 | [code](../data_juicer/ops/mapper/image_tagging_mapper.py) | [tests](../tests/ops/mapper/test_image_tagging_mapper.py) | | imgdiff_difference_area_generator_mapper | 💻CPU 🟡Beta | A fused operator for OPs that is used to run sequential OPs on the same batch to allow fine-grained control on data processing. OPs的融合操作符,用于在同一批次上运行顺序OPs,以实现对数据处理的细粒度控制。 | [code](../data_juicer/ops/mapper/imgdiff_difference_area_generator_mapper.py) | [tests](../tests/ops/mapper/test_imgdiff_difference_area_generator_mapper.py) | | imgdiff_difference_caption_generator_mapper | 💻CPU 🟡Beta | A fused operator for OPs that is used to run sequential OPs on the same batch to allow fine-grained control on data processing. OPs的融合操作符,用于在同一批次上运行顺序OPs,以实现对数据处理的细粒度控制。 | [code](../data_juicer/ops/mapper/imgdiff_difference_caption_generator_mapper.py) | [tests](../tests/ops/mapper/test_imgdiff_difference_caption_generator_mapper.py) | @@ -249,6 +249,7 @@ All the specific operators are listed below, each featured with several capabili | Operator 算子 | Tags 标签 | Description 描述 | Source code 源码 | Unit tests 单测样例 | |----------|------|-------------|-------------|------------| +| domain_diversity_selector | 💻CPU 🔗API 🟡Beta | Selector to select samples based on the data's domain diversity. 选择器根据数据的域多样性选择样本。 | [code](../data_juicer/ops/selector/domain_diversity_selector.py) | [tests](../tests/ops/selector/test_domain_diversity_selector.py) | | frequency_specified_field_selector | 💻CPU 🟢Stable | Selector to select samples based on the sorted frequency of specified field. 选择器根据指定字段的排序频率选择样本。 | [code](../data_juicer/ops/selector/frequency_specified_field_selector.py) | [tests](../tests/ops/selector/test_frequency_specified_field_selector.py) | | random_selector | 💻CPU 🟢Stable | Selector to random select samples. 
选择器来随机选择样本。 | [code](../data_juicer/ops/selector/random_selector.py) | [tests](../tests/ops/selector/test_random_selector.py) | | range_specified_field_selector | 💻CPU 🟢Stable | Selector to select a range of samples based on the sorted specified field value from smallest to largest. 选择器根据从最小到最大的排序指定字段值选择样本范围。 | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) | diff --git a/tests/ops/selector/test_domain_diversity_selector.py b/tests/ops/selector/test_domain_diversity_selector.py new file mode 100644 index 0000000000..f4609a9e99 --- /dev/null +++ b/tests/ops/selector/test_domain_diversity_selector.py @@ -0,0 +1,145 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.selector.domain_diversity_selector import DomainDiversitySelector +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, FROM_FORK + + +@unittest.skipIf(FROM_FORK, "Skipping the test because running from a fork repo") +class DomainDiversitySelectorTest(DataJuicerTestCaseBase): + + def _run_domain_diversity_selector(self, dataset: Dataset, target_num, op): + dataset = op.process(dataset) + res_list = dataset.to_list() + self.assertEqual(len(res_list), target_num) + + def test_ratio_select(self): + ds_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 5 + } + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }, { + 'text': '中文也是一个字算一个长度', + 'count': 162, + 'meta': { + 'suffix': '.txt', + 'key1': { + 'key2': { + 'count': None + }, + 'count': 23 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }, { + 'text': '他的英文名字叫Harry Potter', + 'count': 88, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 551 + }, + 'count': 78 + } + } + }, { + 'text': '这是一个测试', + 'count': None, + 'meta': { + 'suffix': '.py', + 'key1': { + 'key2': { + 'count': 89 + }, + 'count': 3 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }, { + 'text': 'emoji表情测试下😊,😸31231\n', + 'count': 2, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 32 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 33 + }, + 'count': 33 + } + } + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言', + 'count': 666, + 'meta': { + 'suffix': '.xml', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }] + tgt_num = 3 + dataset = Dataset.from_list(ds_list) + op = DomainDiversitySelector(select_ratio=0.2) + self._run_domain_diversity_selector(dataset, tgt_num, op) + + +if __name__ == '__main__': + unittest.main()
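
---

Usage sketch (reviewer note, not part of the diff): the unit test above exercises the operator end-to-end, and the snippet below mirrors it for quick local experimentation. By default the operator embeds text via the 'text-embedding-v3' API, so running this assumes valid API credentials are configured; the toy corpus and ratio values here are illustrative only.

```python
from datasets import Dataset

from data_juicer.ops.selector.domain_diversity_selector import DomainDiversitySelector

# Toy corpus with a few loose "domains" (weather, code, Chinese text).
ds = Dataset.from_list([{'text': t} for t in [
    'Today is sunny', 'It will rain tomorrow',
    'a = 1\nb = a + 2', 'def f(x):\n    return x * x',
    '这是一个测试', '使用片段分词器对每个页面进行分词',
]])

# Embeds every sample, clusters the embeddings with k-means (init_k clusters),
# then keeps the top select_ratio of each cluster (at least one sample per
# cluster) ranked by the chosen strategy. 'inter' prefers samples least
# similar to the other clusters' centroids.
op = DomainDiversitySelector(
    select_ratio=0.5,
    init_k=3,
    strategy='inter',
)
selected = op.process(ds)
print(selected.to_list())

# Local alternative without API calls: pass a HuggingFace sentence-embedding
# model instead. The model name below is a placeholder, not pinned by this PR.
# op = DomainDiversitySelector(
#     api_or_hf_model='sentence-transformers/all-MiniLM-L6-v2',
#     is_hf_model=True,
#     select_ratio=0.5,
# )
```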