@@ -325,7 +325,8 @@ def run(

# === Expand each conclusion into multiple candidate tasks (rows)
expanded_rows = []
- for idx, (row, output_str, identifier) in enumerate(zip(dataframe.itertuples(index=False), conclusions, identifiers)):
+ # Use to_dict('records') instead of itertuples() to preserve column names with special characters (e.g., "user:contents")
+ for idx, (row, output_str, identifier) in enumerate(zip(dataframe.to_dict('records'), conclusions, identifiers)):
try:
parsed = json.loads(self._clean_json_block(output_str))
parsed = parsed[:self.max_per_task] if isinstance(parsed, list) else parsed
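A minimal sketch, not part of the PR, of why this switch matters: itertuples() builds namedtuples, and pandas renames any column that is not a valid Python identifier (such as "user:contents") to a positional field name, so the original key is lost, while to_dict('records') keeps the keys verbatim.

import pandas as pd

df = pd.DataFrame({"user:contents": ["hi"], "score": [1]})

row_tuple = next(df.itertuples(index=False))
print(row_tuple._fields)    # e.g. ('_0', 'score') -- the special-character column gets a positional name
print(row_tuple._asdict())  # the "user:contents" key is no longer recoverable

row_dict = df.to_dict("records")[0]
print(row_dict)             # {'user:contents': 'hi', 'score': 1} -- key preserved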
@@ -339,7 +340,7 @@ def run(
for item in parsed:
if isinstance(item, dict) and "conclusion" in item and "R" in item:
expanded_rows.append({
- **row._asdict(),
+ **row, # row is already a dict now
"identifier": str(identifier),
"candidate_tasks_str": json.dumps(item, ensure_ascii=False)
})
@@ -359,7 +360,8 @@ def run(
answers = []
valid_rows = []

- for idx, (res, row) in enumerate(zip(question_outputs, dataframe.itertuples(index=False))):
+ # Use to_dict('records') instead of itertuples() to preserve column names with special characters
+ for idx, (res, row) in enumerate(zip(question_outputs, dataframe.to_dict('records'))):
try:
parsed = json.loads(self._clean_json_block(res))
except Exception as e:
@@ -369,11 +371,11 @@ def run(
if isinstance(parsed, dict) and "Q" in parsed:
question = parsed["Q"]
try:
- task = json.loads(self._clean_json_block(row.candidate_tasks_str))
+ task = json.loads(self._clean_json_block(row['candidate_tasks_str']))
answer = task.get("conclusion", "")
except Exception:
answer = ""
- valid_rows.append(row._asdict())
+ valid_rows.append(row) # row is already a dict
questions.append(str(question))
answers.append(str(answer))
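The _clean_json_block helper is not shown in this diff; as an assumption about what such a cleaner typically does (not the project's actual implementation), it strips Markdown code fences before json.loads, and the same defensive cleaning is applied to row['candidate_tasks_str'] above before parsing.

import json
import re

def clean_json_block(text: str) -> str:
    # Hypothetical stand-in: remove ```json ... ``` fences that LLM output often carries.
    text = text.strip()
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```$", "", text)
    return text

raw = '```json\n{"Q": "What does the doc conclude?", "conclusion": "X holds"}\n```'
task = json.loads(clean_json_block(raw))
print(task["conclusion"])  # "X holds"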

@@ -412,6 +414,10 @@ def run(
dataframe["llm_score"] = llm_score
dataframe = dataframe[dataframe["llm_score"] < 1].reset_index(drop=True)

+ if dataframe.empty:
+ self.logger.warning("No data left after LLM score filtering. All questions were answered correctly by LLM.")
+ return
+
self.logger.info("Get golden doc answer...")
sys_prompts, user_prompts = self._reformat_prompt(dataframe, "golden_doc_answer")
llm_answer_results = self.llm_serving.generate_from_input(user_prompts, sys_prompts)
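A tiny illustration of the guard added above: the filter keeps only rows with llm_score < 1, so an empty frame means every question was answered correctly by the LLM and there is nothing left to process, which is why run() now warns and returns early.

import pandas as pd

dataframe = pd.DataFrame({"question": ["q1", "q2"], "llm_score": [1, 1]})
dataframe = dataframe[dataframe["llm_score"] < 1].reset_index(drop=True)
print(dataframe.empty)  # True -> the new guard logs a warning and returns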
66 changes: 59 additions & 7 deletions dataflow/operators/code/eval/python_executor.py
@@ -81,6 +81,42 @@ def base64_to_image(
class PersistentWorker:
"""Persistent worker process."""

+ # Runtime class registry for pickle-safe serialization
+ RUNTIME_REGISTRY = {
+ 'ImageRuntime': None, # Will be set later to avoid circular import
+ 'DateRuntime': None,
+ 'ColorObjectRuntime': None,
+ 'GenericRuntime': None,
+ }
+
+ @classmethod
+ def _get_runtime_class(cls, runtime_identifier):
+ """Get runtime class from identifier (class name or class object)."""
+ if isinstance(runtime_identifier, str):
+ # String identifier - look up in registry
+ if runtime_identifier in cls.RUNTIME_REGISTRY:
+ return cls.RUNTIME_REGISTRY[runtime_identifier]
+ else:
+ # Default to ImageRuntime if not found
+ return cls.RUNTIME_REGISTRY.get('ImageRuntime', ImageRuntime)
+ elif isinstance(runtime_identifier, type):
+ # Class object - get its name and look up
+ class_name = runtime_identifier.__name__
+ return cls.RUNTIME_REGISTRY.get(class_name, runtime_identifier)
+ else:
+ # Default fallback
+ return cls.RUNTIME_REGISTRY.get('ImageRuntime', ImageRuntime)
+
+ @classmethod
+ def _get_runtime_identifier(cls, runtime_class):
+ """Convert runtime class to pickle-safe identifier."""
+ if runtime_class is None:
+ return 'ImageRuntime'
+ elif isinstance(runtime_class, str):
+ return runtime_class
+ else:
+ return runtime_class.__name__
+
def __init__(self):
self.input_queue = multiprocessing.Queue()
self.output_queue = multiprocessing.Queue()
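A simplified, single-process sketch (toy class and registry, not the project's runtimes) of the round trip the two helpers above implement: only the class name crosses the multiprocessing queue, so the runtime class itself never has to be pickled.

import multiprocessing

class GenericRuntime:
    def __init__(self, messages=None):
        self.messages = messages or []

REGISTRY = {"GenericRuntime": GenericRuntime}

def to_identifier(runtime_class):
    # class or string -> pickle-safe string (mirrors _get_runtime_identifier)
    return runtime_class if isinstance(runtime_class, str) else runtime_class.__name__

def to_class(identifier):
    # string -> class, with a default fallback (mirrors _get_runtime_class)
    return REGISTRY.get(identifier, GenericRuntime)

queue = multiprocessing.Queue()
queue.put({"runtime_class": to_identifier(GenericRuntime)})  # only the string is serialized
task = queue.get()
runtime = to_class(task["runtime_class"])(messages=[])
print(type(runtime).__name__)  # GenericRuntime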
@@ -111,7 +147,8 @@ def _worker_loop(self):
if task_type == 'init':
# Initialize runtime
messages = task.get('messages', [])
- runtime_class = task.get('runtime_class', ImageRuntime)
+ runtime_identifier = task.get('runtime_class', 'ImageRuntime')
+ runtime_class = self._get_runtime_class(runtime_identifier)
runtime = runtime_class(messages)
self.output_queue.put({
'status': 'success',
@@ -122,7 +159,8 @@ def _worker_loop(self):
# Execute code
if runtime is None:
messages = task.get('messages', [])
- runtime_class = task.get('runtime_class', ImageRuntime)
+ runtime_identifier = task.get('runtime_class', 'ImageRuntime')
+ runtime_class = self._get_runtime_class(runtime_identifier)
runtime = runtime_class(messages)

code = task.get('code')
@@ -184,7 +222,8 @@ def _worker_loop(self):
elif task_type == 'reset':
# Reset runtime
messages = task.get('messages', [])
- runtime_class = task.get('runtime_class', ImageRuntime)
+ runtime_identifier = task.get('runtime_class', 'ImageRuntime')
+ runtime_class = self._get_runtime_class(runtime_identifier)
runtime = runtime_class(messages)
self.output_queue.put({
'status': 'success',
@@ -201,11 +240,13 @@ def _worker_loop(self):
def execute(self, code: List[str], messages: list = None, runtime_class=None,
get_answer_from_stdout=True, answer_symbol=None, answer_expr=None, timeout: int = 30):
"""Execute code."""
+ # Convert runtime class to pickle-safe identifier
+ runtime_identifier = self._get_runtime_identifier(runtime_class)
self.input_queue.put({
'type': 'execute',
'code': code,
'messages': messages,
- 'runtime_class': runtime_class,
+ 'runtime_class': runtime_identifier,
'get_answer_from_stdout': get_answer_from_stdout,
'answer_symbol': answer_symbol,
'answer_expr': answer_expr
@@ -223,19 +264,23 @@ def execute(self, code: List[str], messages: list = None, runtime_class=None,

def init_runtime(self, messages: list, runtime_class=None):
"""Initialize runtime."""
+ # Convert runtime class to pickle-safe identifier
+ runtime_identifier = self._get_runtime_identifier(runtime_class)
self.input_queue.put({
'type': 'init',
'messages': messages,
- 'runtime_class': runtime_class
+ 'runtime_class': runtime_identifier
})
return self.output_queue.get()

def reset_runtime(self, messages: list = None, runtime_class=None):
"""Reset runtime."""
+ # Convert runtime class to pickle-safe identifier
+ runtime_identifier = self._get_runtime_identifier(runtime_class)
self.input_queue.put({
'type': 'reset',
'messages': messages,
- 'runtime_class': runtime_class
+ 'runtime_class': runtime_identifier
})
return self.output_queue.get()
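Hypothetical call sites (the method and runtime names come from this file; whether __init__ also starts the worker process is outside this excerpt), showing that callers may keep passing a class object or switch to its name — both are normalized to a string before being queued.

worker = PersistentWorker()

# Both forms end up as the string 'DateRuntime' on the input queue.
worker.init_runtime(messages=[], runtime_class=DateRuntime)
worker.reset_runtime(messages=[], runtime_class='DateRuntime')

result = worker.execute(
    code=["x = 1 + 1", "print(x)"],
    messages=[],
    runtime_class='GenericRuntime',
    timeout=10,
)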

@@ -548,4 +593,11 @@ def reset(self, messages=None):
def __del__(self):
"""Clean up resources."""
if self.persistent_worker:
- self.persistent_worker.terminate()
+ self.persistent_worker.terminate()
+
+
+ # Initialize runtime registry after all classes are defined
+ PersistentWorker.RUNTIME_REGISTRY['ImageRuntime'] = ImageRuntime
+ PersistentWorker.RUNTIME_REGISTRY['DateRuntime'] = DateRuntime
+ PersistentWorker.RUNTIME_REGISTRY['ColorObjectRuntime'] = ColorObjectRuntime
+ PersistentWorker.RUNTIME_REGISTRY['GenericRuntime'] = GenericRuntime
@@ -10,6 +10,7 @@

import pandas as pd
import random
+ import re
from typing import Union

@prompt_restrict(
24 changes: 20 additions & 4 deletions dataflow/utils/storage.py
@@ -360,6 +360,7 @@ def __init__(
pipeline_id: str = None,
input_task_id: str = None,
output_task_id: str = None,
+ parent_pipeline_id: str = None,
page_size: int = 10000,
page_num: int = 0
):
@@ -375,6 +376,7 @@ def __init__(
pipeline_id: str, identifier of the current pipeline (optional, default None)
input_task_id: str, identifier of the input task (optional, default None)
output_task_id: str, identifier of the output task (optional, default None)
+ parent_pipeline_id: str, identifier of the parent pipeline (optional, default None)
page_size: int, number of records per page when paginating (default 10000)
page_num: int, current page number (default 0)
"""
@@ -384,6 +386,7 @@ def __init__(
self.pipeline_id: str = pipeline_id
self.input_task_id: str = input_task_id
self.output_task_id: str = output_task_id
+ self.parent_pipeline_id: str = parent_pipeline_id
self.page_size: int = page_size
self.page_num: int = page_num
self.validate_required_params()
@@ -402,6 +405,9 @@ def read(self, output_type: Literal["dataframe", "dict"]) -> Any:
if self.input_task_id:
where_clauses.append("task_id = %(task_id)s")
params['task_id'] = self.input_task_id
+ if hasattr(self, 'parent_pipeline_id') and self.parent_pipeline_id:
+ where_clauses.append("parent_pipeline_id = %(parent_pipeline_id)s")
+ params['parent_pipeline_id'] = self.parent_pipeline_id
where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
limit_offset = f"LIMIT {self.page_size} OFFSET {(self.page_num-1)*self.page_size}" if self.page_size else ""
sql = f"SELECT * FROM {self.table} {where_sql} {limit_offset}"
@@ -450,20 +456,27 @@ def write(self, data: Any) -> Any:
# Normalize the data column
df['data'] = df['data'].apply(lambda x: x if isinstance(x, dict) else (json.loads(x) if isinstance(x, str) else {}))
# Merge all non-system fields into the data field and drop the original columns
- system_cols = {'pipeline_id', 'task_id', 'raw_data_id', 'min_hashes', 'data'}
+ system_cols = {'pipeline_id', 'task_id', 'raw_data_id', 'min_hashes', 'file_id', 'filename', 'parent_pipeline_id', 'data'}
for col in df.columns:
if col not in system_cols:
df['data'] = df.apply(lambda row: safe_merge(row, col), axis=1)
df = df.drop(columns=[col])
- # Auto-fill pipeline_id, task_id, raw_data_id, min_hashes
+ # Auto-fill pipeline_id, task_id, raw_data_id, min_hashes, file_id, filename, parent_pipeline_id
df['pipeline_id'] = self.pipeline_id
df['task_id'] = self.output_task_id
df['raw_data_id'] = df['data'].apply(lambda d: d.get(SYS_FIELD_PREFIX + 'raw_data_id', 0) if isinstance(d, dict) else 0)
df['min_hashes'] = df['data'].apply(lambda d: _default_min_hashes(d) if isinstance(d, dict) else [0])
+ # Extract the file_id, filename, and parent_pipeline_id fields from data
+ df['file_id'] = df['data'].apply(lambda d: d.get(SYS_FIELD_PREFIX + 'file_id', '') if isinstance(d, dict) else '')
+ df['filename'] = df['data'].apply(lambda d: d.get(SYS_FIELD_PREFIX + 'filename', '') if isinstance(d, dict) else '')
+ df['parent_pipeline_id'] = df['data'].apply(lambda d: d.get(SYS_FIELD_PREFIX + 'parent_pipeline_id', '') if isinstance(d, dict) else '')
+ # If data does not provide parent_pipeline_id, backfill it from the instance attribute
+ if hasattr(self, 'parent_pipeline_id') and self.parent_pipeline_id:
+ df['parent_pipeline_id'] = df['parent_pipeline_id'].apply(lambda v: v if v else self.parent_pipeline_id)
# Convert the data field to a JSON string
df['data'] = df['data'].apply(lambda x: json.dumps(x, ensure_ascii=False) if not isinstance(x, str) else x)
# Keep only the required columns
- required_cols = ['pipeline_id', 'task_id', 'raw_data_id', 'min_hashes', 'data']
+ required_cols = ['pipeline_id', 'task_id', 'raw_data_id', 'min_hashes', 'file_id', 'filename', 'parent_pipeline_id', 'data']
df = df[required_cols]
records = df.to_dict(orient="records")
values = [
@@ -472,11 +485,14 @@ def write(self, data: Any) -> Any:
rec['task_id'],
int(rec['raw_data_id']),
rec['min_hashes'],
+ rec['file_id'],
+ rec['filename'],
+ rec['parent_pipeline_id'],
rec['data']
) for rec in records
]
insert_sql = f"""
- INSERT INTO {self.table} (pipeline_id, task_id, raw_data_id, min_hashes, data)
+ INSERT INTO {self.table} (pipeline_id, task_id, raw_data_id, min_hashes, file_id, filename, parent_pipeline_id, data)
VALUES
"""
self.logger.info(f"Inserting {len(values)} rows into {self.table}")