|
9 | 9 | get_response_synthesizer, |
10 | 10 | Response, |
11 | 11 | ) |
12 | | -from llama_index.core.callbacks import CallbackManager, TokenCountingHandler |
| 12 | +from llama_index.core.callbacks import ( |
| 13 | + CallbackManager, |
| 14 | + TokenCountingHandler, |
| 15 | +) |
| 16 | +from llama_index.core.prompts import PromptTemplate |
| 17 | +from llama_index.core.schema import QueryBundle |
13 | 18 | from llama_index.core.retrievers import VectorIndexRetriever |
14 | 19 | from llama_index.core.query_engine import RetrieverQueryEngine |
15 | 20 | from llama_index.llms.openai import OpenAI # type: ignore |
|
47 | 52 | default_output_tracker_file, |
48 | 53 | ) |
49 | 54 | import bcorag.misc_functions as misc_fns |
50 | | -from .prompts import DOMAIN_MAP, QUERY_PROMPT, SUPPLEMENT_PROMPT |
| 55 | +from .prompts import ( |
| 56 | + PROMPT_DOMAIN_MAP, |
| 57 | + RETRIEVAL_PROMPT, |
| 58 | + LLM_PROMPT, |
| 59 | + SUPPLEMENT_PROMPT, |
| 60 | + LLM_PROMPT_TEMPLATE, |
| 61 | +) |
51 | 62 |
|
52 | 63 | # import llama_index.core |
53 | 64 | # llama_index.core.set_global_handler("simple") |
@@ -133,7 +144,7 @@ def __init__( |
133 | 144 | load_dotenv() |
134 | 145 |
|
135 | 146 | self._parameter_set_hash = self._user_selection_hash(user_selections) |
136 | | - self._domain_map = DOMAIN_MAP |
| 147 | + self._domain_map = PROMPT_DOMAIN_MAP |
137 | 148 | self._file_name = user_selections["filename"] |
138 | 149 | self._file_path = user_selections["filepath"] |
139 | 150 | self._output_path_root = os.path.join( |
@@ -285,16 +296,24 @@ def __init__( |
285 | 296 | ) |
286 | 297 | self._index = VectorStoreIndex(nodes=nodes) |
287 | 298 |
|
288 | | - retriever = VectorIndexRetriever( |
289 | | - index=self._index, similarity_top_k=self._similarity_top_k * 3 |
| 299 | + base_retriever = VectorIndexRetriever( |
| 300 | + index=self._index, |
| 301 | + similarity_top_k=self._similarity_top_k * 3, |
| 302 | + ) |
| 303 | + # transform_retriever = TransformRetriever( |
| 304 | + # retriever=base_retriever, |
| 305 | + # query_transform=CustomQueryTransform(delimiter=DELIMITER), |
| 306 | + # ) |
| 307 | + llm_prompt_template = PromptTemplate(template=LLM_PROMPT_TEMPLATE) |
| 308 | + response_synthesizer = get_response_synthesizer( |
| 309 | + text_qa_template=llm_prompt_template |
290 | 310 | ) |
291 | | - response_synthesizer = get_response_synthesizer() |
292 | 311 | rerank_postprocessor = SentenceTransformerRerank( |
293 | 312 | top_n=self._similarity_top_k, |
294 | 313 | keep_retrieval_score=True, |
295 | 314 | ) |
296 | 315 | self._query_engine = RetrieverQueryEngine( |
297 | | - retriever=retriever, |
| 316 | + retriever=base_retriever, |
298 | 317 | response_synthesizer=response_synthesizer, |
299 | 318 | node_postprocessors=[rerank_postprocessor], |
300 | 319 | ) |
@@ -322,16 +341,28 @@ def perform_query(self, domain: DomainKey) -> str: |
322 | 341 | The generated domain. |
323 | 342 | """ |
324 | 343 | query_start_time = time.time() |
325 | | - domain_prompt = self._domain_map[domain]["prompt"] |
| 344 | + domain_retrieval_prompt = self._domain_map[domain]["retrieval_prompt"] |
| 345 | + domain_llm_prompt = self._domain_map[domain]["llm_prompt"] |
| 346 | + |
326 | 347 | for dependency in self._domain_map[domain]["dependencies"]: |
327 | 348 | if self.domain_content[dependency] is not None: |
328 | 349 | dependency_prompt = f"The {domain} domain is dependent on the {dependency} domain. Here is the {dependency} domain: {self.domain_content[dependency]}." |
329 | | - domain_prompt += dependency_prompt |
330 | | - query_prompt = QUERY_PROMPT.format(domain, domain_prompt) |
| 350 | + domain_llm_prompt += dependency_prompt |
| 351 | + |
| 352 | + # full_prompt = f"{RETRIEVAL_PROMPT.format(domain, domain_retrieval_prompt)} {DELIMITER} {LLM_PROMPT.format(domain, domain_llm_prompt)}" |
| 353 | + llm_prompt = f"{LLM_PROMPT.format(domain, domain_llm_prompt)}" |
331 | 354 | if self._domain_map[domain]["top_level"]: |
332 | | - query_prompt += f"\n{SUPPLEMENT_PROMPT}" |
| 355 | + llm_prompt += f"\n{SUPPLEMENT_PROMPT}" |
| 356 | + query_bundle = QueryBundle( |
| 357 | + query_str=llm_prompt, |
| 358 | + custom_embedding_strs=[ |
| 359 | + f"{RETRIEVAL_PROMPT.format(domain, domain_retrieval_prompt)}" |
| 360 | + ], |
| 361 | + embedding=None, |
| 362 | + ) |
| 363 | + |
| 364 | + response_object = self._query_engine.query(query_bundle) |
333 | 365 |
|
334 | | - response_object = self._query_engine.query(query_prompt) |
335 | 366 | if isinstance(response_object, Response): |
336 | 367 | response_object = Response( |
337 | 368 | response=response_object.response, |
@@ -369,7 +400,14 @@ def perform_query(self, domain: DomainKey) -> str: |
369 | 400 | source_str += "\n" |
370 | 401 |
|
371 | 402 | if self._debug: |
372 | | - self._display_info(query_prompt, f"QUERY PROMPT for the {domain} domain:") |
| 403 | + self._display_info( |
| 404 | + query_bundle.query_str, f"LLM PROMPT for the {domain} domain:" |
| 405 | + ) |
| 406 | + if query_bundle.custom_embedding_strs is not None: |
| 407 | + self._display_info( |
| 408 | + query_bundle.custom_embedding_strs[0], |
| 409 | + f"RETRIEVAL PROMPT for the {domain} domain:", |
| 410 | + ) |
373 | 411 | self._token_counts["input"] += self._token_counter.prompt_llm_token_count # type: ignore |
374 | 412 | self._token_counts["output"] += self._token_counter.completion_llm_token_count # type: ignore |
375 | 413 | self._token_counts["total"] += self._token_counter.total_llm_token_count # type: ignore |
|
0 commit comments