4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "DataDreamer"
version = "0.39.0"
version = "0.40.0"
description = "Prompt. Generate Synthetic Data. Train & Align Models."
license = "MIT"
authors= [
@@ -73,7 +73,7 @@ ignore_missing_imports = true
exclude = [".*/"]

[tool.coverage.run]
omit = ["src/__main__.py", "src/project/*", "src/tests/*"]
omit = ["src/__main__.py", "src/project/*", "src/tests/*", "src/utils/hf_structured_decoding_utils.py"]

[tool.coverage.report]
exclude_lines = [
2 changes: 1 addition & 1 deletion src/llms/hf_api_endpoint.py
@@ -134,7 +134,7 @@ def get_generated_texts(self, kwargs, prompt) -> list[str]:
func=self.client.text_generation,
model=self.endpoint,
prompt=prompt,
-do_sample=True,
+do_sample=kwargs.pop("do_sample", True),
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=False,
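This one-line change stops hardcoding `do_sample=True` and instead forwards a caller-supplied value, keeping `True` as the default for backward compatibility. A minimal sketch of what it enables, assuming the public import paths from DataDreamer's README (the `HFTransformers` class used in the tests further down gains the same pass-through in `hf_transformers.py`):

```python
from datadreamer import DataDreamer
from datadreamer.llms import HFTransformers

with DataDreamer("./output"):
    llm = HFTransformers("gpt2")
    # do_sample is now forwarded instead of being overridden to True,
    # so deterministic greedy decoding can be requested explicitly:
    texts = llm.run(
        ["Question: What is 2 + 2?\nAnswer:"],
        max_new_tokens=10,
        batch_size=1,
        do_sample=False,
    )
```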
77 changes: 44 additions & 33 deletions src/llms/hf_transformers.py
@@ -31,7 +31,7 @@
is_encoder_decoder,
validate_quantization_config,
)
-from ..utils.import_utils import ignore_transformers_warnings
+from ..utils.import_utils import ignore_inference_warnings, ignore_transformers_warnings
from .llm import (
DEFAULT_BATCH_SIZE,
LLM,
@@ -412,12 +412,22 @@ def _run_batch( # noqa: C901
stop=stop, prompts=prompts, tokenizer=cached_tokenizer
)
stopping_criteria_list.append(sequence_stopping_criteria)
-logits_processor = LogitsProcessorList(logits_processor_list)
+logits_processor = LogitsProcessorList(
+    kwargs.pop(
+        "logits_processor_list",
+        (
+            kwargs.pop("pre_logits_processor", [])
+            + logits_processor_list
+            + kwargs.pop("logits_processor", [])
+            + kwargs.pop("post_logits_processor", [])
+        ),
+    )
+)
stopping_criteria = StoppingCriteriaList(stopping_criteria_list)
generation_kwargs = dict(
max_new_tokens=max_new_tokens,
pad_token_id=cached_tokenizer.eos_token_id,
-do_sample=True,
+do_sample=kwargs.pop("do_sample", True),
temperature=temperature,
top_p=top_p,
logits_processor=logits_processor,
@@ -562,36 +572,37 @@ def max_length_func(
"cached_tokenizer": cached_tokenizer,
}

-results_generator = self._run_over_batches(
-    run_batch=self._run_batch,
-    get_max_input_length_function=partial(
-        get_max_length_function, self.tokenizer
-    ),
-    max_model_length=self.get_max_context_length,
-    inputs=prompts,
-    max_new_tokens=max_new_tokens,
-    temperature=temperature,
-    top_p=top_p,
-    n=n,
-    stop=stop,
-    repetition_penalty=repetition_penalty,
-    logit_bias=logit_bias,
-    batch_size=batch_size,
-    batch_scheduler_buffer_size=batch_scheduler_buffer_size,
-    adaptive_batch_size=adaptive_batch_size,
-    seed=seed,
-    progress_interval=progress_interval,
-    force=force,
-    cache_only=cache_only,
-    verbose=verbose,
-    log_level=log_level,
-    total_num_inputs=total_num_prompts,
-    **kwargs,
-)
-if not return_generator:
-    return list(results_generator)
-else:
-    return results_generator
+with ignore_inference_warnings():
+    results_generator = self._run_over_batches(
+        run_batch=self._run_batch,
+        get_max_input_length_function=partial(
+            get_max_length_function, self.tokenizer
+        ),
+        max_model_length=self.get_max_context_length,
+        inputs=prompts,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        n=n,
+        stop=stop,
+        repetition_penalty=repetition_penalty,
+        logit_bias=logit_bias,
+        batch_size=batch_size,
+        batch_scheduler_buffer_size=batch_scheduler_buffer_size,
+        adaptive_batch_size=adaptive_batch_size,
+        seed=seed,
+        progress_interval=progress_interval,
+        force=force,
+        cache_only=cache_only,
+        verbose=verbose,
+        log_level=log_level,
+        total_num_inputs=total_num_prompts,
+        **kwargs,
+    )
+    if not return_generator:
+        return list(results_generator)
+    else:
+        return results_generator

@cached_property
def model_card(self) -> None | str:
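Two behavioral changes here. First, callers can now inject their own logits processors: per the composed default above, `pre_logits_processor` entries run before DataDreamer's built-in processors, `logits_processor` and `post_logits_processor` entries run after them, and a `logits_processor_list` kwarg replaces the composed list wholesale. Second, batch generation is now wrapped in the new `ignore_inference_warnings()` context manager. A short sketch of the main hook point, using a stock processor from `transformers` purely as an illustration:

```python
from datadreamer import DataDreamer
from datadreamer.llms import HFTransformers
from transformers import MinLengthLogitsProcessor

with DataDreamer("./output"):
    llm = HFTransformers("gpt2")
    texts = llm.run(
        ["Once upon a time,"],
        max_new_tokens=50,
        # Appended after DataDreamer's built-in processors (logit-bias handling, etc.):
        logits_processor=[
            MinLengthLogitsProcessor(5, eos_token_id=llm.tokenizer.eos_token_id)
        ],
    )
```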
7 changes: 4 additions & 3 deletions src/requirements.txt
@@ -13,8 +13,8 @@ ring>=0.10.1,<1.0.0
psutil>=6.1.1
faiss-cpu>=1.9.0.post1,<2.0.0
evaluate>=0.4.3,<1.0.0
-tiktoken>=0.8.0,<1.0.0
-sentence-transformers>=3.3.1,<4.0.0
+tiktoken>=0.7.0,<1.0.0
+sentence-transformers>=3.4.0,<4.0.0
setfit>=1.1.1,<2.0.0
openai>=1.59.6,<2.0.0
datasets>=3.2.0,<4.0.0
@@ -23,8 +23,9 @@ bitsandbytes>=0.45.0,<1.0.0
huggingface-hub>=0.27.1,<1.0.0
optimum>=1.21.2,<2.0.0
accelerate>=1.3.0,<2.0.0
-transformers>=4.48.0,<4.50.0
+transformers>=4.48.1,<4.50.0
ctransformers>=0.2.27,<1.0.0
+outlines-core==0.1.27
Pyro5>=5.15
litellm==1.57.8
trl==0.9.6
20 changes: 16 additions & 4 deletions src/steps/step_operations.py
@@ -217,6 +217,7 @@ def __create_step_operation_step( # noqa: C901
run: Callable,
no_save: bool = False,
setup: None | Callable = None,
steps: "list[Step] | None" = None,
**kwargs,
) -> "Step":
from .step import LazyRows
@@ -287,12 +288,22 @@ def run(self):
wait(step)

# Create the op step
-step_op_step = _StepOpStep(
-    name=final_name,
-    inputs={
+if steps is not None and len(steps) > 0:
+    prev_step_outputs = {}
+    for s_idx, s in enumerate(steps):
+        each_prev_step_outputs = {
+            f"{s_idx}_{column_name}": s.output[column_name]
+            for column_name in s.output.column_names
+        }
+        prev_step_outputs.update(each_prev_step_outputs)
+else:
+    prev_step_outputs = {
        column_name: step.output[column_name]
        for column_name in step.output.column_names
-    },
+    }
+step_op_step = _StepOpStep(
+    name=final_name,
+    inputs=prev_step_outputs,
verbose=step.verbose,
log_level=step.log_level,
**kwargs,
@@ -375,6 +386,7 @@ def run(self):
)

kwargs["step"] = steps[0]
kwargs["steps"] = steps
kwargs["no_save"] = lazy
kwargs["args"] = {"fingerprint": [[step.fingerprint for step in steps], axis]}
kwargs["run"] = run
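With the new `steps` parameter (now also passed via `kwargs["steps"] = steps` above), an op step built from multiple source steps namespaces each source step's output columns by the step's index, so identically named columns from different steps can no longer collide. A self-contained sketch of just the naming logic, with plain dicts standing in for real `Step` outputs (an assumption for brevity):

```python
step_outputs = [
    {"text": ["a", "b"]},  # steps[0].output
    {"text": ["x", "y"]},  # steps[1].output
]
prev_step_outputs = {}
for s_idx, output in enumerate(step_outputs):
    # Prefix every column with the index of the step it came from:
    prev_step_outputs.update(
        {f"{s_idx}_{column}": values for column, values in output.items()}
    )
print(prev_step_outputs)  # {'0_text': ['a', 'b'], '1_text': ['x', 'y']}
```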
49 changes: 49 additions & 0 deletions src/tests/llms/test_llms.py
@@ -56,6 +56,9 @@
_chat_prompt_template_and_system_prompt_from_tokenizer,
)
from ...utils.hf_model_utils import get_model_prompt_template, get_orig_model
+from ...utils.hf_structured_decoding_utils import (  # type:ignore[attr-defined]
+    JSONLogitProcessor,
+)
from ...utils.import_utils import (
ignore_litellm_warnings,
ignore_transformers_warnings,
@@ -1749,6 +1752,52 @@ def _run_batch_mocked(*args, **kwargs):
]
)

+def test_structured_json_decoding(self, create_datadreamer):
+    from typing import Literal
+
+    from pydantic import BaseModel, Field
+
+    class AnswerOutputSchema(BaseModel):
+        answer_to_question: Literal["blue", "green"]
+        rgb_red_value: float = Field(..., ge=0.0, le=1.0)
+        rgb_green_value: float = Field(..., ge=0.0, le=1.0)
+        rgb_blue_value: float = Field(..., ge=0.0, le=1.0)
+        item_being_asked_about: str
+
+    with create_datadreamer():
+        llm = HFTransformers("gpt2")
+        generated_texts = llm.run(
+            [
+                "Question: What is the color of a UPS truck?\nAnswer in JSON Format:",
+                "Question: What is the color of the sky?\nAnswer in JSON Format:",
+            ],
+            max_new_tokens=100,
+            do_sample=False,
+            batch_size=2,
+            logits_processor=[
+                JSONLogitProcessor(
+                    tokenizer=llm.tokenizer,
+                    json_spec=AnswerOutputSchema.model_json_schema(),
+                    whitespace_pattern=r"[\n ]{0,1}",
+                )
+            ],
+        )
+        objs = [json.loads(t) for t in generated_texts]  # type:ignore[arg-type]
+        assert set(objs[0].keys()) == {
+            "answer_to_question",
+            "rgb_red_value",
+            "rgb_green_value",
+            "rgb_blue_value",
+            "item_being_asked_about",
+        }
+        assert set(objs[1].keys()) == {
+            "answer_to_question",
+            "rgb_red_value",
+            "rgb_green_value",
+            "rgb_blue_value",
+            "item_being_asked_about",
+        }


class TestCTransformers:
def test_init(self, create_datadreamer):
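This test exercises the structured JSON decoding path that ties the rest of the release together: the new `outlines-core` pin in `requirements.txt`, the coverage exclusion for `src/utils/hf_structured_decoding_utils.py`, and the `logits_processor` hook added to `hf_transformers.py`. `JSONLogitProcessor` constrains token selection so that every completion parses as JSON matching the Pydantic schema; it guarantees structure, not sensible values (gpt2 is not instruction-tuned), which is why the assertions check only the keys. A hypothetical completion, with invented values, might look like:

```python
# Only the shape is enforced by JSONLogitProcessor; the values below are
# made up for illustration:
example = {
    "answer_to_question": "blue",
    "rgb_red_value": 0.25,
    "rgb_green_value": 0.5,
    "rgb_blue_value": 1.0,
    "item_being_asked_about": "the sky",
}
```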
21 changes: 20 additions & 1 deletion src/tests/test_datadreamer.py
@@ -10,7 +10,7 @@
from .. import DataDreamer, __version__
from ..datasets import OutputDataset
from ..errors import StepOutputError
-from ..steps import DataCardType, LazyRows, Step
+from ..steps import DataCardType, LazyRows, Step, zipped
from ..utils.background_utils import check_if_fault_handler_is_setup
from ..utils.fs_utils import dir_size

@@ -461,6 +461,10 @@ def setup_2(self):
self.register_data_card(DataCardType.CITATION, "citation2")
self.register_data_card(DataCardType.URL, "http://example2.com")

+def setup_4(self):
+    self.register_data_card(DataCardType.CITATION, "citation4")
+    self.register_data_card(DataCardType.URL, "http://example4.com")

with create_datadreamer():
step_1 = create_test_step(
name="my-step-1", inputs=None, output_names=["out1"], setup=setup_1
@@ -506,6 +510,21 @@ def setup_2(self):
assert captured.out.startswith(
"""{\n "my-step-1": {\n "Date & Time":"""
)
+step_4 = create_test_step(
+    name="my-step-4", inputs=None, output_names=["out4"], setup=setup_4
+)
+step_4._set_output({"out4": ["a4", "b4", "c4"]})
+step_34 = zipped(step_4, step_3)
+capsys.readouterr()
+step_34.data_card()
+captured = capsys.readouterr()
+assert captured.out.startswith(
+    """{\n "my-step-4": {\n "Date & Time":"""
+)
+assert "my-step-1" in captured.out
+assert "my-step-2" in captured.out
+assert "my-step-2 (shuffle)" in captured.out
+assert "zipped(my-step-4, my-step-2 (shuffle))" in captured.out

def test_num_shards(
self, create_datadreamer, create_test_step: Callable[..., Step]
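The added assertions pin down the data-card lineage behavior enabled by the `steps` plumbing in `step_operations.py`: a `zipped` step's card leads with its first source step but also aggregates the cards of every ancestor. A sketch of the expected shape (key names beyond those asserted in the test are assumptions):

```python
step_34 = zipped(step_4, step_3)
step_34.data_card()
# Prints JSON whose top-level keys cover the full lineage, roughly:
# {
#     "my-step-4": {"Date & Time": "...", ...},
#     "my-step-1": {...},
#     "my-step-2": {...},
#     "my-step-2 (shuffle)": {...},
#     "zipped(my-step-4, my-step-2 (shuffle))": {...}
# }
```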
34 changes: 21 additions & 13 deletions src/trainers/train_hf_classifier.py
@@ -14,6 +14,7 @@
from ..utils.arg_utils import AUTO, Default
from ..utils.distributed_utils import not_distributed_or_main_process
from ..utils.hf_training_utils import (
+ComputeMetricsState,
TrainingArguments,
_monkey_patch_TrainerState__post_init__,
get_logging_callback,
@@ -157,7 +158,9 @@ def _train( # type:ignore[override]
)

# Prepare compute metrics
-def compute_accuracy_metrics(accuracy, f1, eval_pred):
+compute_metrics_state = ComputeMetricsState()
+
+def compute_accuracy_metrics(accuracy, f1, eval_pred, compute_result=None):
predictions, labels = eval_pred
if isinstance(predictions, tuple):
predictions = predictions[0]
@@ -179,18 +182,22 @@ def compute_accuracy_metrics(accuracy, f1, eval_pred):
f1_metrics = f1.compute(
predictions=hard_predictions, references=labels, average="micro"
)
-return {
-    **accuracy_metrics,
-    **f1_metrics,
-    "joint_metric": JointMetric(
-        is_joint_metric=True,
-        primary=f1_metrics["f1"],
-        primary_name="f1",
-        secondary=(-1 * loss),
-        secondary_name="loss",
-        secondary_inversed=True,
-    ),
-}
+return compute_metrics_state.add_metrics(
+    batch_size=len(labels),
+    metrics_dict={
+        **accuracy_metrics,
+        **f1_metrics,
+        "joint_metric": JointMetric(
+            is_joint_metric=True,
+            primary=f1_metrics["f1"],
+            primary_name="f1",
+            secondary=(-1 * loss),
+            secondary_name="loss",
+            secondary_inversed=True,
+        ),
+    },
+    compute_result=compute_result,
+)

compute_metrics = kwargs.pop("compute_metrics", None) or partial(
compute_accuracy_metrics,
@@ -250,6 +257,7 @@ def compute_accuracy_metrics(accuracy, f1, eval_pred):
num_train_epochs=epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
+batch_eval_metrics=kwargs.pop("batch_eval_metrics", True),
optim=optim,
learning_rate=learning_rate,
weight_decay=weight_decay,
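`batch_eval_metrics` is a `transformers` `TrainingArguments` flag that makes the `Trainer` call `compute_metrics` once per evaluation batch rather than once over all accumulated predictions, passing `compute_result=True` on the final batch; this keeps evaluation memory bounded. The `ComputeMetricsState` imported above evidently carries the running aggregate between those calls. A simplified sketch of the accumulator pattern it implies, handling scalar metrics only (an assumption; the real class in `..utils.hf_training_utils` must also handle `JointMetric` values):

```python
class ComputeMetricsStateSketch:
    """Size-weighted running average of per-batch eval metrics (illustrative only)."""

    def __init__(self) -> None:
        self.totals: dict[str, float] = {}
        self.num_examples = 0

    def add_metrics(self, batch_size, metrics_dict, compute_result=None):
        # Accumulate each metric weighted by how many examples were in the batch:
        for name, value in metrics_dict.items():
            self.totals[name] = self.totals.get(name, 0.0) + float(value) * batch_size
        self.num_examples += batch_size
        if compute_result is None or compute_result:
            # Final batch (or legacy one-shot mode): report the aggregate and reset.
            result = {k: v / self.num_examples for k, v in self.totals.items()}
            self.totals, self.num_examples = {}, 0
            return result
        return {}  # intermediate batch: nothing to report yet
```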