51 | 51 | import numpy as np |
52 | 52 | from peewee import DoesNotExist |
53 | 53 | from common.constants import LLMType, ParserType, PipelineTaskType |
54 | | -from api.db.services.document_service import DocumentService |
| 54 | +from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks |
55 | 55 | from api.db.services.llm_service import LLMBundle |
56 | 56 | from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID |
57 | 57 | from api.db.services.file2document_service import File2DocumentService |
|
68 | 68 | from common.exceptions import TaskCanceledException |
69 | 69 | from common import settings |
70 | 70 | from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME |
| 71 | +from croniter import croniter |
71 | 72 |
|
72 | 73 | BATCH_SIZE = 64 |
73 | 74 |
|
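Since croniter becomes a new dependency here: the patch only relies on `croniter.is_valid()` and `get_prev()`, roughly as in this small sketch (the cron expression and clock value are made-up examples, not taken from the change):

```python
from datetime import datetime
from croniter import croniter

expr = "*/30 * * * *"                           # hypothetical schedule: every 30 minutes
print(croniter.is_valid(expr))                  # True -- malformed expressions are rejected up front in _due()

base = datetime(2024, 5, 2, 9, 40).replace(second=0, microsecond=0)
print(croniter(expr, base).get_prev(datetime))  # 2024-05-02 09:30:00 -> the most recent scheduled slot
```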
@@ -641,6 +642,8 @@ def dict_update(meta): |
641 | 642 | logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost)) |
642 | 643 | PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline)) |
643 | 644 |
|
| 645 | + trigger_update_after(task_dataset_id, doc_id) |
| 646 | + |
644 | 647 |
|
645 | 648 | @timeout(3600) |
646 | 649 | async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]): |
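For orientation: this call site covers the pipeline-based ingestion path. The new `trigger_update_after()` helper, defined in the next hunk, decides per dataset whether a GraphRAG or RAPTOR follow-up task should be queued immediately for the document that just finished parsing (the "update_after" strategy); a sketch of the configuration fields it reads follows that hunk.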
@@ -747,6 +750,27 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c |
747 | 750 | return True |
748 | 751 |
|
749 | 752 |
|
| 753 | +def trigger_update_after(kb_id: str, doc_id: str): |
| 754 | +    try: |
| 755 | +        ok, kb = KnowledgebaseService.get_by_id(kb_id) |
| 756 | +        if not ok: |
| 757 | +            return |
| 758 | +        conf = kb.parser_config or {} |
| 759 | +        gconf = conf.get("graphrag") or {} |
| 760 | +        rconf = conf.get("raptor") or {} |
| 761 | +        if gconf.get("use_graphrag") and gconf.get("strategy") == "update_after":  # re-run GraphRAG right after this document finishes parsing |
| 762 | +            docs, _ = DocumentService.get_by_kb_id(kb_id=kb.id, page_number=0, items_per_page=0, orderby="create_time", desc=False, keywords="", run_status=[], types=[], suffix=[]) |
| 763 | +            sample_document = docs[0] if docs else {"id": doc_id}  # fall back to the just-parsed doc if the listing is empty |
| 764 | +            tid = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=[doc_id]) |
| 765 | +            KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": tid}) |
| 766 | +        if rconf.get("use_raptor") and rconf.get("strategy") == "update_after": |
| 767 | +            docs, _ = DocumentService.get_by_kb_id(kb_id=kb.id, page_number=0, items_per_page=0, orderby="create_time", desc=False, keywords="", run_status=[], types=[], suffix=[]) |
| 768 | +            sample_document = docs[0] if docs else {"id": doc_id} |
| 769 | +            tid = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=[doc_id]) |
| 770 | +            KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": tid}) |
| 771 | +    except Exception: |
| 772 | +        logging.exception("trigger_update_after failed for kb %s, doc %s", kb_id, doc_id) |
| 773 | + |
750 | 774 | @timeout(60*60*3, 1) |
751 | 775 | async def do_handle_task(task): |
752 | 776 | task_type = task.get("task_type", "") |
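To make the branches above concrete, here is a sketch of the `parser_config` shape they read. The key names (`graphrag`, `raptor`, `use_graphrag`, `use_raptor`, `strategy`, `cron`) come straight from the code; the surrounding structure and the example values are assumptions for illustration, not taken from this change:

```python
# Hypothetical kb.parser_config excerpt (illustrative values only).
parser_config = {
    "graphrag": {
        "use_graphrag": True,
        "strategy": "update_after",   # re-run GraphRAG right after each document is parsed
    },
    "raptor": {
        "use_raptor": True,
        "strategy": "timed",          # handled by the scheduler() loop added below
        "cron": "0 2 * * *",          # only meaningful for the "timed" strategy
    },
}
```

`"update_after"` is handled synchronously by `trigger_update_after()`, while `"timed"` plus a `cron` expression is picked up by the `scheduler()` sweep added further down.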
@@ -948,6 +972,7 @@ async def do_handle_task(task): |
948 | 972 | "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, |
949 | 973 | task_to_page, len(chunks), |
950 | 974 | token_count, task_time_cost)) |
| 975 | + trigger_update_after(task_dataset_id, task_doc_id) |
951 | 976 |
|
952 | 977 |
|
953 | 978 | async def handle_task(): |
@@ -1062,6 +1087,90 @@ async def task_manager(): |
1062 | 1087 | task_limiter.release() |
1063 | 1088 |
|
1064 | 1089 |
|
| 1090 | +async def _due(cron: str, last_finish: datetime): |
| 1091 | +    try: |
| 1092 | +        if not cron: |
| 1093 | +            return False |
| 1094 | +        if not croniter.is_valid(cron): |
| 1095 | +            return False |
| 1096 | +        slot = datetime.now().replace(second=0, microsecond=0) |
| 1097 | +        prev_time = croniter(cron, slot).get_prev(datetime)  # most recent scheduled slot before now |
| 1098 | +        if last_finish and last_finish >= prev_time: |
| 1099 | +            return False  # the last run already covers this slot |
| 1100 | +        return True |
| 1101 | +    except Exception: |
| 1102 | +        return False  # treat unparsable input as "not due" rather than crashing the sweep |
| 1103 | + |
| 1104 | + |
| 1105 | +async def scheduler(): |
| 1106 | +    while not stop_event.is_set(): |
| 1107 | +        try: |
| 1108 | +            def _doc_finish_ts_ms(doc): |
| 1109 | +                pb = doc.get("process_begin_at") |
| 1110 | +                dur = doc.get("process_duration") or 0 |
| 1111 | +                if not pb: |
| 1112 | +                    return None |
| 1113 | +                try: |
| 1114 | +                    pb_ts_ms = int(pb.timestamp() * 1000) |
| 1115 | +                except Exception: |
| 1116 | +                    return None |
| 1117 | +                return pb_ts_ms + int(dur * 1000)  # parse finish time in ms: begin + duration (seconds) |
| 1118 | + |
| 1119 | +            def _schedule_if_needed(kb, changed_docs, ty): |
| 1120 | +                if not changed_docs: |
| 1121 | +                    return |
| 1122 | +                if ty == "graphrag": |
| 1123 | +                    task_id = kb.graphrag_task_id |
| 1124 | +                else: |
| 1125 | +                    task_id = kb.raptor_task_id |
| 1126 | +                skip = False |
| 1127 | +                if task_id: |
| 1128 | +                    ok, t = TaskService.get_by_id(task_id) |
| 1129 | +                    skip = bool(ok and t and t.progress not in [-1, 1])  # a previous run is still in flight |
| 1130 | +                if skip: |
| 1131 | +                    return |
| 1132 | +                sample_document = changed_docs[0] |
| 1133 | +                document_ids = [d["id"] for d in changed_docs] |
| 1134 | +                tid = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty=ty, priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=document_ids) |
| 1135 | +                if ty == "graphrag": |
| 1136 | +                    KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": tid}) |
| 1137 | +                else: |
| 1138 | +                    KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": tid}) |
| 1139 | + |
| 1140 | +            ids = KnowledgebaseService.get_all_ids() |
| 1141 | +            for kb_id in ids: |
| 1142 | +                ok, kb = KnowledgebaseService.get_by_id(kb_id) |
| 1143 | +                if not ok: |
| 1144 | +                    continue |
| 1145 | +                conf = kb.parser_config or {} |
| 1146 | +                gconf = conf.get("graphrag") or {} |
| 1147 | +                rconf = conf.get("raptor") or {} |
| 1148 | +                if gconf.get("use_graphrag") and gconf.get("strategy") == "timed" and gconf.get("cron"): |
| 1149 | +                    if await _due(gconf.get("cron"), kb.graphrag_task_finish_at): |
| 1150 | +                        documents, _ = DocumentService.get_by_kb_id(kb_id=kb.id, page_number=0, items_per_page=0, orderby="create_time", desc=False, keywords="", run_status=[], types=[], suffix=[]) |
| 1151 | +                        if documents: |
| 1152 | +                            finish_dt = kb.graphrag_task_finish_at |
| 1153 | +                            changed_docs = documents |
| 1154 | +                            if finish_dt: |
| 1155 | +                                finish_ts_ms = int(finish_dt.timestamp() * 1000) |
| 1156 | +                                changed_docs = [d for d in documents if (_doc_finish_ts_ms(d) or 0) > finish_ts_ms]  # only docs parsed after the last timed run |
| 1157 | +                            _schedule_if_needed(kb, changed_docs, "graphrag") |
| 1158 | +                if rconf.get("use_raptor") and rconf.get("strategy") == "timed" and rconf.get("cron"): |
| 1159 | +                    if await _due(rconf.get("cron"), kb.raptor_task_finish_at): |
| 1160 | +                        documents, _ = DocumentService.get_by_kb_id(kb_id=kb.id, page_number=0, items_per_page=0, orderby="create_time", desc=False, keywords="", run_status=[], types=[], suffix=[]) |
| 1161 | +                        if documents: |
| 1162 | +                            finish_dt = kb.raptor_task_finish_at |
| 1163 | +                            changed_docs = documents |
| 1164 | +                            if finish_dt: |
| 1165 | +                                finish_ts_ms = int(finish_dt.timestamp() * 1000) |
| 1166 | +                                changed_docs = [d for d in documents if (_doc_finish_ts_ms(d) or 0) > finish_ts_ms] |
| 1167 | +                            _schedule_if_needed(kb, changed_docs, "raptor") |
| 1168 | +        except Exception as e: |
| 1169 | +            logging.exception(e) |
| 1170 | +            # keep the sweep loop alive; one failing KB must not stop the scheduler |
| 1171 | +        await trio.sleep(60)  # Special tasks take a long time to run, so the start time of scheduled tasks does not need to be very precise |
| 1172 | + |
| 1173 | + |
1065 | 1174 | async def main(): |
1066 | 1175 | logging.info(r""" |
1067 | 1176 | ____ __ _ |
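Stepping back from the diff for a moment: a minimal, self-contained walkthrough of the two comparisons the scheduler relies on. The cron slot, timestamps, and the sample document dict below are invented for illustration; only the arithmetic mirrors `_due()` and `_doc_finish_ts_ms()` above.

```python
from datetime import datetime

# Suppose the most recent cron slot computed by _due() was today at 02:00 (illustrative values).
prev_slot = datetime(2024, 5, 2, 2, 0)

# _due(): a new timed run is queued only if the last finished run predates that slot.
last_finish = datetime(2024, 5, 1, 23, 0)
print(last_finish < prev_slot)          # True -> due

# scheduler(): a document counts as "changed" when its parse finished after last_finish,
# where the finish time is process_begin_at + process_duration (seconds), in milliseconds.
doc = {"process_begin_at": datetime(2024, 5, 2, 8, 0), "process_duration": 90.0}  # invented sample
finish_ts_ms = int(last_finish.timestamp() * 1000)
doc_ts_ms = int(doc["process_begin_at"].timestamp() * 1000) + int(doc["process_duration"] * 1000)
print(doc_ts_ms > finish_ts_ms)         # True -> included in doc_ids for the follow-up task
```

If either comparison comes out False, the sweep leaves that knowledge base untouched and re-checks it on the next 60-second pass.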
@@ -1089,6 +1198,7 @@ async def main(): |
1089 | 1198 |
|
1090 | 1199 | async with trio.open_nursery() as nursery: |
1091 | 1200 | nursery.start_soon(report_status) |
| 1201 | + nursery.start_soon(scheduler) |
1092 | 1202 | while not stop_event.is_set(): |
1093 | 1203 | await task_limiter.acquire() |
1094 | 1204 | nursery.start_soon(task_manager) |
|