diff --git a/pyproject.toml b/pyproject.toml index 71e3414..6795eb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "DataDreamer" -version = "0.39.0" +version = "0.40.0" description = "Prompt. Generate Synthetic Data. Train & Align Models." license = "MIT" authors= [ @@ -73,7 +73,7 @@ ignore_missing_imports = true exclude = [".*/"] [tool.coverage.run] -omit = ["src/__main__.py", "src/project/*", "src/tests/*"] +omit = ["src/__main__.py", "src/project/*", "src/tests/*", "src/utils/hf_structured_decoding_utils.py"] [tool.coverage.report] exclude_lines = [ diff --git a/src/llms/hf_api_endpoint.py b/src/llms/hf_api_endpoint.py index 22ed24a..2842ad4 100644 --- a/src/llms/hf_api_endpoint.py +++ b/src/llms/hf_api_endpoint.py @@ -134,7 +134,7 @@ def get_generated_texts(self, kwargs, prompt) -> list[str]: func=self.client.text_generation, model=self.endpoint, prompt=prompt, - do_sample=True, + do_sample=kwargs.pop("do_sample", True), max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, return_full_text=False, diff --git a/src/llms/hf_transformers.py b/src/llms/hf_transformers.py index d790d3f..b2b6873 100644 --- a/src/llms/hf_transformers.py +++ b/src/llms/hf_transformers.py @@ -31,7 +31,7 @@ is_encoder_decoder, validate_quantization_config, ) -from ..utils.import_utils import ignore_transformers_warnings +from ..utils.import_utils import ignore_inference_warnings, ignore_transformers_warnings from .llm import ( DEFAULT_BATCH_SIZE, LLM, @@ -412,12 +412,22 @@ def _run_batch( # noqa: C901 stop=stop, prompts=prompts, tokenizer=cached_tokenizer ) stopping_criteria_list.append(sequence_stopping_criteria) - logits_processor = LogitsProcessorList(logits_processor_list) + logits_processor = LogitsProcessorList( + kwargs.pop( + "logits_processor_list", + ( + kwargs.pop("pre_logits_processor", []) + + logits_processor_list + + kwargs.pop("logits_processor", []) + + kwargs.pop("post_logits_processor", []) + ), + ) + ) stopping_criteria = StoppingCriteriaList(stopping_criteria_list) generation_kwargs = dict( max_new_tokens=max_new_tokens, pad_token_id=cached_tokenizer.eos_token_id, - do_sample=True, + do_sample=kwargs.pop("do_sample", True), temperature=temperature, top_p=top_p, logits_processor=logits_processor, @@ -562,36 +572,37 @@ def max_length_func( "cached_tokenizer": cached_tokenizer, } - results_generator = self._run_over_batches( - run_batch=self._run_batch, - get_max_input_length_function=partial( - get_max_length_function, self.tokenizer - ), - max_model_length=self.get_max_context_length, - inputs=prompts, - max_new_tokens=max_new_tokens, - temperature=temperature, - top_p=top_p, - n=n, - stop=stop, - repetition_penalty=repetition_penalty, - logit_bias=logit_bias, - batch_size=batch_size, - batch_scheduler_buffer_size=batch_scheduler_buffer_size, - adaptive_batch_size=adaptive_batch_size, - seed=seed, - progress_interval=progress_interval, - force=force, - cache_only=cache_only, - verbose=verbose, - log_level=log_level, - total_num_inputs=total_num_prompts, - **kwargs, - ) - if not return_generator: - return list(results_generator) - else: - return results_generator + with ignore_inference_warnings(): + results_generator = self._run_over_batches( + run_batch=self._run_batch, + get_max_input_length_function=partial( + get_max_length_function, self.tokenizer + ), + max_model_length=self.get_max_context_length, + inputs=prompts, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + n=n, + stop=stop, + repetition_penalty=repetition_penalty, + logit_bias=logit_bias, + batch_size=batch_size, + batch_scheduler_buffer_size=batch_scheduler_buffer_size, + adaptive_batch_size=adaptive_batch_size, + seed=seed, + progress_interval=progress_interval, + force=force, + cache_only=cache_only, + verbose=verbose, + log_level=log_level, + total_num_inputs=total_num_prompts, + **kwargs, + ) + if not return_generator: + return list(results_generator) + else: + return results_generator @cached_property def model_card(self) -> None | str: diff --git a/src/requirements.txt b/src/requirements.txt index b9579c6..73ab3d6 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -13,8 +13,8 @@ ring>=0.10.1,<1.0.0 psutil>=6.1.1 faiss-cpu>=1.9.0.post1,<2.0.0 evaluate>=0.4.3,<1.0.0 -tiktoken>=0.8.0,<1.0.0 -sentence-transformers>=3.3.1,<4.0.0 +tiktoken>=0.7.0,<1.0.0 +sentence-transformers>=3.4.0,<4.0.0 setfit>=1.1.1,<2.0.0 openai>=1.59.6,<2.0.0 datasets>=3.2.0,<4.0.0 @@ -23,8 +23,9 @@ bitsandbytes>=0.45.0,<1.0.0 huggingface-hub>=0.27.1,<1.0.0 optimum>=1.21.2,<2.0.0 accelerate>=1.3.0,<2.0.0 -transformers>=4.48.0,<4.50.0 +transformers>=4.48.1,<4.50.0 ctransformers>=0.2.27,<1.0.0 +outlines-core==0.1.27 Pyro5>=5.15 litellm==1.57.8 trl==0.9.6 diff --git a/src/steps/step_operations.py b/src/steps/step_operations.py index c529a00..53d130f 100644 --- a/src/steps/step_operations.py +++ b/src/steps/step_operations.py @@ -217,6 +217,7 @@ def __create_step_operation_step( # noqa: C901 run: Callable, no_save: bool = False, setup: None | Callable = None, + steps: "list[Step] | None" = None, **kwargs, ) -> "Step": from .step import LazyRows @@ -287,12 +288,22 @@ def run(self): wait(step) # Create the op step - step_op_step = _StepOpStep( - name=final_name, - inputs={ + if steps is not None and len(steps) > 0: + prev_step_outputs = {} + for s_idx, s in enumerate(steps): + each_prev_step_outputs = { + f"{s_idx}_{column_name}": s.output[column_name] + for column_name in s.output.column_names + } + prev_step_outputs.update(each_prev_step_outputs) + else: + prev_step_outputs = { column_name: step.output[column_name] for column_name in step.output.column_names - }, + } + step_op_step = _StepOpStep( + name=final_name, + inputs=prev_step_outputs, verbose=step.verbose, log_level=step.log_level, **kwargs, @@ -375,6 +386,7 @@ def run(self): ) kwargs["step"] = steps[0] + kwargs["steps"] = steps kwargs["no_save"] = lazy kwargs["args"] = {"fingerprint": [[step.fingerprint for step in steps], axis]} kwargs["run"] = run diff --git a/src/tests/llms/test_llms.py b/src/tests/llms/test_llms.py index b4d4020..88394ed 100644 --- a/src/tests/llms/test_llms.py +++ b/src/tests/llms/test_llms.py @@ -56,6 +56,9 @@ _chat_prompt_template_and_system_prompt_from_tokenizer, ) from ...utils.hf_model_utils import get_model_prompt_template, get_orig_model +from ...utils.hf_structured_decoding_utils import ( # type:ignore[attr-defined] + JSONLogitProcessor, +) from ...utils.import_utils import ( ignore_litellm_warnings, ignore_transformers_warnings, @@ -1749,6 +1752,52 @@ def _run_batch_mocked(*args, **kwargs): ] ) + def test_structured_json_decoding(self, create_datadreamer): + from typing import Literal + + from pydantic import BaseModel, Field + + class AnswerOutputSchema(BaseModel): + answer_to_question: Literal["blue", "green"] + rgb_red_value: float = Field(..., ge=0.0, le=1.0) + rgb_green_value: float = Field(..., ge=0.0, le=1.0) + rgb_blue_value: float = Field(..., ge=0.0, le=1.0) + item_being_asked_about: str + + with create_datadreamer(): + llm = HFTransformers("gpt2") + generated_texts = llm.run( + [ + "Question: What is the color of a UPS truck?\nAnswer in JSON Format:", + "Question: What is the color of the sky?\nAnswer in JSON Format:", + ], + max_new_tokens=100, + do_sample=False, + batch_size=2, + logits_processor=[ + JSONLogitProcessor( + tokenizer=llm.tokenizer, + json_spec=AnswerOutputSchema.model_json_schema(), + whitespace_pattern=r"[\n ]{0,1}", + ) + ], + ) + objs = [json.loads(t) for t in generated_texts] # type:ignore[arg-type] + assert set(objs[0].keys()) == { + "answer_to_question", + "rgb_red_value", + "rgb_green_value", + "rgb_blue_value", + "item_being_asked_about", + } + assert set(objs[1].keys()) == { + "answer_to_question", + "rgb_red_value", + "rgb_green_value", + "rgb_blue_value", + "item_being_asked_about", + } + class TestCTransformers: def test_init(self, create_datadreamer): diff --git a/src/tests/test_datadreamer.py b/src/tests/test_datadreamer.py index 44ff1a9..013c1fa 100644 --- a/src/tests/test_datadreamer.py +++ b/src/tests/test_datadreamer.py @@ -10,7 +10,7 @@ from .. import DataDreamer, __version__ from ..datasets import OutputDataset from ..errors import StepOutputError -from ..steps import DataCardType, LazyRows, Step +from ..steps import DataCardType, LazyRows, Step, zipped from ..utils.background_utils import check_if_fault_handler_is_setup from ..utils.fs_utils import dir_size @@ -461,6 +461,10 @@ def setup_2(self): self.register_data_card(DataCardType.CITATION, "citation2") self.register_data_card(DataCardType.URL, "http://example2.com") + def setup_4(self): + self.register_data_card(DataCardType.CITATION, "citation4") + self.register_data_card(DataCardType.URL, "http://example4.com") + with create_datadreamer(): step_1 = create_test_step( name="my-step-1", inputs=None, output_names=["out1"], setup=setup_1 @@ -506,6 +510,21 @@ def setup_2(self): assert captured.out.startswith( """{\n "my-step-1": {\n "Date & Time":""" ) + step_4 = create_test_step( + name="my-step-4", inputs=None, output_names=["out4"], setup=setup_4 + ) + step_4._set_output({"out4": ["a4", "b4", "c4"]}) + step_34 = zipped(step_4, step_3) + capsys.readouterr() + step_34.data_card() + captured = capsys.readouterr() + assert captured.out.startswith( + """{\n "my-step-4": {\n "Date & Time":""" + ) + assert "my-step-1" in captured.out + assert "my-step-2" in captured.out + assert "my-step-2 (shuffle)" in captured.out + assert "zipped(my-step-4, my-step-2 (shuffle))" in captured.out def test_num_shards( self, create_datadreamer, create_test_step: Callable[..., Step] diff --git a/src/trainers/train_hf_classifier.py b/src/trainers/train_hf_classifier.py index 577f4c7..29c93c1 100644 --- a/src/trainers/train_hf_classifier.py +++ b/src/trainers/train_hf_classifier.py @@ -14,6 +14,7 @@ from ..utils.arg_utils import AUTO, Default from ..utils.distributed_utils import not_distributed_or_main_process from ..utils.hf_training_utils import ( + ComputeMetricsState, TrainingArguments, _monkey_patch_TrainerState__post_init__, get_logging_callback, @@ -157,7 +158,9 @@ def _train( # type:ignore[override] ) # Prepare compute metrics - def compute_accuracy_metrics(accuracy, f1, eval_pred): + compute_metrics_state = ComputeMetricsState() + + def compute_accuracy_metrics(accuracy, f1, eval_pred, compute_result=None): predictions, labels = eval_pred if isinstance(predictions, tuple): predictions = predictions[0] @@ -179,18 +182,22 @@ def compute_accuracy_metrics(accuracy, f1, eval_pred): f1_metrics = f1.compute( predictions=hard_predictions, references=labels, average="micro" ) - return { - **accuracy_metrics, - **f1_metrics, - "joint_metric": JointMetric( - is_joint_metric=True, - primary=f1_metrics["f1"], - primary_name="f1", - secondary=(-1 * loss), - secondary_name="loss", - secondary_inversed=True, - ), - } + return compute_metrics_state.add_metrics( + batch_size=len(labels), + metrics_dict={ + **accuracy_metrics, + **f1_metrics, + "joint_metric": JointMetric( + is_joint_metric=True, + primary=f1_metrics["f1"], + primary_name="f1", + secondary=(-1 * loss), + secondary_name="loss", + secondary_inversed=True, + ), + }, + compute_result=compute_result, + ) compute_metrics = kwargs.pop("compute_metrics", None) or partial( compute_accuracy_metrics, @@ -250,6 +257,7 @@ def compute_accuracy_metrics(accuracy, f1, eval_pred): num_train_epochs=epochs, per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, + batch_eval_metrics=kwargs.pop("batch_eval_metrics", True), optim=optim, learning_rate=learning_rate, weight_decay=weight_decay, diff --git a/src/trainers/train_hf_finetune.py b/src/trainers/train_hf_finetune.py index 1d61f65..417d9de 100644 --- a/src/trainers/train_hf_finetune.py +++ b/src/trainers/train_hf_finetune.py @@ -6,6 +6,7 @@ from ..datasets import OutputDatasetColumn, OutputIterableDatasetColumn from ..utils.arg_utils import AUTO, Default from ..utils.hf_training_utils import ( + ComputeMetricsState, CustomDataCollatorWithPadding, Seq2SeqTrainingArguments, TrainingArguments, @@ -137,6 +138,7 @@ def _train( # type:ignore[override] ) # Prepare compute metrics + compute_metrics_state = ComputeMetricsState() # This computation can use a fair bit of CPU RAM due to the size of these # tensors (batch_size * sequence_length * vocabulary_size), so we should try @@ -150,13 +152,14 @@ def _train( # type:ignore[override] except RuntimeError: # pragma: no cover compute_perplexity_dtype = torch.float32 - def compute_perplexity_metrics(eval_pred): + def compute_perplexity_metrics(eval_pred, compute_result=None): preds, labels = eval_pred del eval_pred if isinstance(preds, tuple): preds = preds[0] preds = torch.tensor(preds, dtype=compute_perplexity_dtype) labels = torch.tensor(labels) + batch_size = len(labels) if self._is_encoder_decoder: nll = torch.nn.functional.cross_entropy( input=preds.view(-1, preds.size(-1)), target=labels.view(-1) @@ -172,7 +175,11 @@ def compute_perplexity_metrics(eval_pred): input=shift_preds.view(-1, shift_preds.size(-1)), target=shift_labels.view(-1), ) - return {"perplexity": torch.exp(nll)} + return compute_metrics_state.add_metrics( + batch_size=batch_size, + metrics_dict={"perplexity": torch.exp(nll)}, + compute_result=compute_result, + ) compute_metrics = ( kwargs.pop("compute_metrics", None) or compute_perplexity_metrics @@ -230,6 +237,7 @@ def compute_perplexity_metrics(eval_pred): num_train_epochs=epochs, per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, + batch_eval_metrics=kwargs.pop("batch_eval_metrics", True), optim=optim, learning_rate=learning_rate, weight_decay=weight_decay, diff --git a/src/trainers/train_hf_ppo.py b/src/trainers/train_hf_ppo.py index 42c1e60..d57de77 100644 --- a/src/trainers/train_hf_ppo.py +++ b/src/trainers/train_hf_ppo.py @@ -20,6 +20,7 @@ from ..utils.fs_utils import mkdir from ..utils.hf_model_utils import is_peft_model from ..utils.hf_training_utils import ( + ComputeMetricsState, CustomDataCollatorWithPadding, TrainingArguments, get_logging_callback, @@ -105,10 +106,16 @@ class PPOTrainerWrapper(Trainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def compute_metrics(eval_pred): + compute_metrics_state = ComputeMetricsState() + + def compute_metrics(eval_pred, compute_result=None): preds, _ = eval_pred mean_preds = preds.mean(axis=0) - return {"rewards": mean_preds} + return compute_metrics_state.add_metrics( + batch_size=len(preds), + metrics_dict={"rewards": mean_preds}, + compute_result=compute_result, + ) self.compute_metrics = compute_metrics @@ -630,6 +637,7 @@ def _train( # type:ignore[override] # noqa: C901 num_train_epochs=epochs, per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, + batch_eval_metrics=kwargs.pop("batch_eval_metrics", True), optim="adamw_torch", learning_rate=learning_rate, lr_scheduler_type="linear", diff --git a/src/trainers/train_hf_reward_model.py b/src/trainers/train_hf_reward_model.py index fda5f85..ab47581 100644 --- a/src/trainers/train_hf_reward_model.py +++ b/src/trainers/train_hf_reward_model.py @@ -15,6 +15,7 @@ from ..utils.distributed_utils import not_distributed_or_main_process from ..utils.hf_model_utils import get_base_model_from_peft_model from ..utils.hf_training_utils import ( + ComputeMetricsState, CustomDataCollatorWithPadding, TrainingArguments, _monkey_patch_TrainerState__post_init__, @@ -218,7 +219,9 @@ def _train_with_pairs( ) # Prepare compute metrics - def compute_accuracy_metrics(accuracy, eval_pred): + compute_metrics_state = ComputeMetricsState() + + def compute_accuracy_metrics(accuracy, eval_pred, compute_result=None): predictions, labels = eval_pred loss = F.cross_entropy( input=torch.tensor(predictions), @@ -228,17 +231,21 @@ def compute_accuracy_metrics(accuracy, eval_pred): accuracy_metrics = accuracy.compute( predictions=hard_predictions, references=labels ) - return { - **accuracy_metrics, - "joint_metric": JointMetric( - is_joint_metric=True, - primary=accuracy_metrics["accuracy"], - primary_name="f1", - secondary=(-1 * loss), - secondary_name="loss", - secondary_inversed=True, - ), - } + return compute_metrics_state.add_metrics( + batch_size=len(labels), + metrics_dict={ + **accuracy_metrics, + "joint_metric": JointMetric( + is_joint_metric=True, + primary=accuracy_metrics["accuracy"], + primary_name="f1", + secondary=(-1 * loss), + secondary_name="loss", + secondary_inversed=True, + ), + }, + compute_result=compute_result, + ) compute_metrics = kwargs.pop("compute_metrics", None) or partial( compute_accuracy_metrics, evaluate.load("accuracy") @@ -306,6 +313,7 @@ class RewardConfig(_TrainingArgumentDeviceOverrideMixin, _RewardConfig): num_train_epochs=epochs, per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, + batch_eval_metrics=kwargs.pop("batch_eval_metrics", True), optim=optim, learning_rate=learning_rate, weight_decay=weight_decay, @@ -442,7 +450,9 @@ def _train_with_scores( ) # Prepare compute metrics - def compute_mse_metrics(eval_pred): + compute_metrics_state = ComputeMetricsState() + + def compute_mse_metrics(eval_pred, compute_result=None): predictions, labels = eval_pred if isinstance(predictions, tuple): # pragma: no cover predictions = predictions[0] @@ -450,7 +460,11 @@ def compute_mse_metrics(eval_pred): mse_metrics = { "mse": F.mse_loss(torch.tensor(predictions), torch.tensor(labels)) } - return {**mse_metrics} + return compute_metrics_state.add_metrics( + batch_size=len(labels), + metrics_dict={**mse_metrics}, + compute_result=compute_result, + ) compute_metrics = kwargs.pop("compute_metrics", None) or compute_mse_metrics @@ -507,6 +521,7 @@ def compute_mse_metrics(eval_pred): num_train_epochs=epochs, per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, + batch_eval_metrics=kwargs.pop("batch_eval_metrics", True), optim=optim, learning_rate=learning_rate, weight_decay=weight_decay, diff --git a/src/trainers/train_sentence_transformer.py b/src/trainers/train_sentence_transformer.py index 5265a68..12c53a3 100644 --- a/src/trainers/train_sentence_transformer.py +++ b/src/trainers/train_sentence_transformer.py @@ -24,6 +24,7 @@ validate_peft_config, ) from ..utils.hf_training_utils import ( + ComputeMetricsState, CustomDataCollatorWithPadding, TrainingArguments, _monkey_patch_TrainerState__post_init__, @@ -124,6 +125,7 @@ def forward( negative_input_ids: None | torch.Tensor = None, negative_attention_mask: None | torch.Tensor = None, labels: None | torch.Tensor = None, + num_items_in_batch=None, ): _uniq_ids = [] sentence_features = [] @@ -456,7 +458,9 @@ def _train( # type:ignore[override] # noqa: C901 ) # Prepare compute metrics - def compute_accuracy_metrics(accuracy, f1, eval_pred): + compute_metrics_state = ComputeMetricsState() + + def compute_accuracy_metrics(accuracy, f1, eval_pred, compute_result=None): (all_embeddings, loss), labels = eval_pred if isinstance(loss, np.ndarray): # pragma: no cover loss = np.mean(loss) @@ -496,17 +500,24 @@ def compute_accuracy_metrics(accuracy, f1, eval_pred): accuracy_metrics = accuracy.compute( predictions=preds, references=labels ) - return { - **accuracy_metrics, - "joint_metric": JointMetric( - is_joint_metric=True, - primary=accuracy_metrics["accuracy"], - primary_name="f1", - secondary=(-1 * loss), - secondary_name="loss", - secondary_inversed=True, - ), - } + inverse_loss = -1 * loss + if isinstance(inverse_loss, torch.Tensor): # pragma: no cover + inverse_loss = inverse_loss.cpu().item() + return compute_metrics_state.add_metrics( + batch_size=len(labels), + metrics_dict={ + **accuracy_metrics, + "joint_metric": JointMetric( + is_joint_metric=True, + primary=accuracy_metrics["accuracy"], + primary_name="f1", + secondary=inverse_loss, + secondary_name="loss", + secondary_inversed=True, + ), + }, + compute_result=compute_result, + ) else: return {} @@ -598,6 +609,7 @@ def __getattr__(self, name): num_train_epochs=epochs, per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, + batch_eval_metrics=kwargs.pop("batch_eval_metrics", True), optim=optim, learning_rate=learning_rate, weight_decay=weight_decay, diff --git a/src/trainers/trainer.py b/src/trainers/trainer.py index fd22f73..83bd0d9 100644 --- a/src/trainers/trainer.py +++ b/src/trainers/trainer.py @@ -47,6 +47,36 @@ def __lt__(self, other): def __sub__(self, other): return self.secondary - other.secondary + def __add__(self, other): + return JointMetric( + is_joint_metric=True, + primary=self.primary + other.primary, + primary_name=self.primary_name, + secondary=self.secondary + other.secondary, + secondary_name=self.secondary_name, + secondary_inversed=self.secondary_inversed, + ) + + def __mul__(self, other): + return JointMetric( + is_joint_metric=True, + primary=self.primary * other, + primary_name=self.primary_name, + secondary=self.secondary * other, + secondary_name=self.secondary_name, + secondary_inversed=self.secondary_inversed, + ) + + def __truediv__(self, other): + return JointMetric( + is_joint_metric=True, + primary=self.primary / other, + primary_name=self.primary_name, + secondary=self.secondary / other, + secondary_name=self.secondary_name, + secondary_inversed=self.secondary_inversed, + ) + def __repr__(self) -> str: # pragma: no cover secondary = (-1 * self.secondary) if self.secondary_inversed else self.secondary return ( diff --git a/src/utils/hf_model_utils.py b/src/utils/hf_model_utils.py index 67070ee..c1623bd 100644 --- a/src/utils/hf_model_utils.py +++ b/src/utils/hf_model_utils.py @@ -165,7 +165,7 @@ def get_model_max_context_length(model_name: str, config: PretrainedConfig) -> i else: if "bloom" in model_name: # pragma: no cover max_context_length = 2048 - elif config.model_type in ["t5", "mt5"]: + elif config.model_type in ["t5", "mt5", "umt5"]: max_context_length = 512 else: raise RuntimeError( diff --git a/src/utils/hf_structured_decoding_utils.py b/src/utils/hf_structured_decoding_utils.py new file mode 100644 index 0000000..8264798 --- /dev/null +++ b/src/utils/hf_structured_decoding_utils.py @@ -0,0 +1,168 @@ +# type: ignore +# ruff: noqa + +import json +import math +from collections import defaultdict +from functools import lru_cache +from typing import DefaultDict, Optional + +import torch +from outlines_core.fsm.guide import ( + RegexGuide as CoreRegexGuide, + create_states_mapping as uncached_create_states_mapping, +) +from outlines_core.fsm.json_schema import build_regex_from_schema +from transformers import LogitsProcessor, PreTrainedTokenizerBase + + +def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs): + return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs) + + +class RegexGuide(CoreRegexGuide): + """ + Guide to generate text in the language of a regular expression. + CoreRegexGuide with outlines cache + """ + + @classmethod + def from_regex(cls, regex_string: str, tokenizer, **kwargs): + return super().from_regex( + regex_string, + tokenizer, + _create_states_mapping=cached_create_states_mapping, + **kwargs, + ) + + +class _GrammarLogitProcessor(LogitsProcessor): + fsm_state: DefaultDict[int, int] + fsm: RegexGuide + + def __init__(self, tokenizer: Optional[PreTrainedTokenizerBase], grammar: str): + self.tokenizer = _GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) + self.fsm = _GrammarLogitProcessor._cached_compile_fsm(grammar, self.tokenizer) + + def __call__(self, logits: torch.Tensor, fsm_grammar_state: int): + if fsm_grammar_state == -1 or self.fsm is None: + return logits + allowed_tokens = self.fsm.get_next_instruction(fsm_grammar_state).tokens + mask = torch.full_like(logits, -math.inf) + if allowed_tokens is not None: + mask[:, allowed_tokens] = 0 + biased_scores = logits + mask + return biased_scores + + def advance(self, next_token_id, fsm_grammar_state): + return _GrammarLogitProcessor._advance( + next_token_id, fsm_grammar_state, self.fsm + ) + + @staticmethod + def _advance(next_token_id, fsm_grammar_state, fsm): + if fsm_grammar_state == -1: + return fsm_grammar_state + return fsm.get_next_state(fsm_grammar_state, next_token_id) + + # TODO: move grammar compilation into the router + @staticmethod + @lru_cache(maxsize=32, typed=True) + def _cached_compile_fsm(schema: str, tokenizer: Optional[PreTrainedTokenizerBase]): + regex_str = schema + fsm = RegexGuide.from_regex(regex_str, tokenizer) + return fsm + + @staticmethod + @lru_cache(maxsize=32, typed=True) + def _cached_adapt_tokenizer(tokenizer): + """Adapt tokenizer to work with the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. In addition we need to handle the missing spaces to + Llama's tokenizer to be able to compile FSMs for this model. + + """ + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + tokenizer.convert_token_to_string = convert_token_to_string + return tokenizer + + +class GrammarLogitProcessor: + def __init__( + self, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + grammar: str = "", + fsm_grammar_state: Optional[int] = 0, + ): + self.grammar_processor = ( + _GrammarLogitProcessor(tokenizer, grammar) if grammar != "" else None + ) + self.tokenizer = tokenizer + self.grammar = grammar + + # Initialize FSM grammar states for each batch item + self.fsm_grammar_states = defaultdict(lambda: fsm_grammar_state) + + def __call__(self, input_ids, scores): + batch_size = scores.size(0) + + # Warp next scores with the grammar_processor + if self.grammar_processor is not None: + for i in range(batch_size): + scores[i, :] = self.grammar_processor( + scores[i, :].unsqueeze(0), self.fsm_grammar_states[i] + )[0] + + # Compute the log softmax of scores for the entire batch + next_logprob = torch.log_softmax(scores, dim=-1) + + # Get the next ID with the highest score for each batch item + next_ids = scores.argmax(dim=-1).view(batch_size, 1) + + # Advance grammar states for each batch item + for i in range(batch_size): + self.advance_grammar(i, next_ids[i].item()) + + # Create a mask to set everything except next_ids to -inf + mask = torch.full_like(scores, float("-inf")) + mask.scatter_( + dim=-1, index=next_ids, value=0.0 + ) # Set the score for next_id to 0 (log(1) = 0) + + next_logprob = mask # Replace all scores with the mask + + return next_logprob + + def advance_grammar(self, batch_idx: int, next_id: int): + if self.grammar_processor is not None: + self.fsm_grammar_states[batch_idx] = self.grammar_processor.advance( + next_id, self.fsm_grammar_states[batch_idx] + ) + return self + + +class JSONLogitProcessor(GrammarLogitProcessor): + def __init__(self, tokenizer, json_spec, whitespace_pattern=r"[\n ]*"): + if not isinstance(json_spec, str): + json_spec = json.dumps(json_spec) + compiled_grammar = build_regex_from_schema( + json_spec, whitespace_pattern=whitespace_pattern + ) + super().__init__(tokenizer=tokenizer, grammar=compiled_grammar) + + +__all__ = ["GrammarLogitProcessor", "JSONLogitProcessor"] diff --git a/src/utils/hf_training_utils.py b/src/utils/hf_training_utils.py index 7876b49..0f3a3b0 100644 --- a/src/utils/hf_training_utils.py +++ b/src/utils/hf_training_utils.py @@ -10,11 +10,10 @@ import dill import numpy as np import torch +from datasets import Dataset, IterableDataset, Value, concatenate_datasets from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR -from datasets import Dataset, IterableDataset, Value, concatenate_datasets - from .. import DataDreamer from ..datasets import ( OutputDatasetColumn, @@ -997,9 +996,14 @@ def on_log(self_, args, state, control, logs=None, **kwargs): def wrap_compute_metrics(compute_metrics, training_args: "TrainingArguments"): - def _wrapped_compute_metrics(*args, **kwargs): + def _wrapped_compute_metrics(*args, compute_result: None | bool = None, **kwargs): if not_distributed_or_main_process(): - computed_metrics = compute_metrics(*args, **kwargs) + if compute_result is not None: + computed_metrics = compute_metrics( + *args, compute_result=compute_result, **kwargs + ) + else: # pragma: no cover + computed_metrics = compute_metrics(*args, **kwargs) if is_distributed(): # pragma: no cover for _ in range(get_local_world_size() - 1): DataDreamer.ctx.distributed_pipe.put(dill.dumps(computed_metrics)) @@ -1028,3 +1032,53 @@ def _save_memory_in__EvalLoopContainer_add(self, *args, **kwargs): @cache def _monkey_patch_EvalLoopContainer_add(): EvalLoopContainer.add = _save_memory_in__EvalLoopContainer_add + + +class ComputeMetricsState: + def __init__(self): + self.metrics = [] + + def add_metrics(self, batch_size, metrics_dict, compute_result: None | bool = None): + if compute_result is None: # pragma: no cover + return metrics_dict + elif compute_result is False: + self.metrics.append({"weight": batch_size, "metrics": metrics_dict}) + return metrics_dict + elif compute_result is True: + self.metrics.append({"weight": batch_size, "metrics": metrics_dict}) + + # Compute total weight + total_weight = sum([m["weight"] for m in self.metrics]) + + # Initialize a dictionary to store the weighted sums of metrics + weighted_sums = {} + + # Accumulate the weighted sum for each metric + for entry in self.metrics: + weight = entry["weight"] + metrics = entry["metrics"] + for key, value in metrics.items(): + if not ( + isinstance(value, int) + or isinstance(value, float) + or isinstance(value, JointMetric) + or isinstance(value, torch.Tensor) + or isinstance(value, np.ndarray) + or isinstance(value, np.floating) + or isinstance(value, np.integer) + ): # pragma: no cover + value = 0 + if key not in weighted_sums: + weighted_sums[key] = value * weight + else: + weighted_sums[key] += value * weight + + # Compute the weighted average for each metric + averaged_metrics = { + key: weighted_sums[key] / total_weight for key in weighted_sums + } + + # Reset the metrics state + self.metrics.clear() + + return averaged_metrics diff --git a/src/utils/import_utils.py b/src/utils/import_utils.py index 3a25d32..9f580f2 100644 --- a/src/utils/import_utils.py +++ b/src/utils/import_utils.py @@ -80,6 +80,9 @@ def ignore_training_warnings(): category=UserWarning, message="Merge.*may get different generations due to rounding error.*", ) + warnings.filterwarnings( + "ignore", category=UserWarning, message="To copy construct from a tensor.*" + ) warnings.filterwarnings( "ignore", category=FutureWarning, @@ -88,6 +91,15 @@ def ignore_training_warnings(): yield None +@contextlib.contextmanager +def ignore_inference_warnings(): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", category=UserWarning, message="`do_sample` is set to `False`.*" + ) + yield None + + @contextlib.contextmanager def ignore_pydantic_warnings(): with warnings.catch_warnings():