Skip to content

Commit 66b7d40

Browse files
authored
Query transform (#24)
* Working on query transform retriever * Use different prompts for retrieval vs LLM prompting * Remove unused delimiter * Update doc title formatting * Update deepeval version * Update for prompt separation
1 parent d8bcdc2 commit 66b7d40

File tree

8 files changed

+219
-332
lines changed

8 files changed

+219
-332
lines changed

bcorag/bcorag.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
get_response_synthesizer,
1010
Response,
1111
)
12-
from llama_index.core.callbacks import CallbackManager, TokenCountingHandler
12+
from llama_index.core.callbacks import (
13+
CallbackManager,
14+
TokenCountingHandler,
15+
)
16+
from llama_index.core.prompts import PromptTemplate
17+
from llama_index.core.schema import QueryBundle
1318
from llama_index.core.retrievers import VectorIndexRetriever
1419
from llama_index.core.query_engine import RetrieverQueryEngine
1520
from llama_index.llms.openai import OpenAI # type: ignore
@@ -47,7 +52,13 @@
4752
default_output_tracker_file,
4853
)
4954
import bcorag.misc_functions as misc_fns
50-
from .prompts import DOMAIN_MAP, QUERY_PROMPT, SUPPLEMENT_PROMPT
55+
from .prompts import (
56+
PROMPT_DOMAIN_MAP,
57+
RETRIEVAL_PROMPT,
58+
LLM_PROMPT,
59+
SUPPLEMENT_PROMPT,
60+
LLM_PROMPT_TEMPLATE,
61+
)
5162

5263
# import llama_index.core
5364
# llama_index.core.set_global_handler("simple")
@@ -133,7 +144,7 @@ def __init__(
133144
load_dotenv()
134145

135146
self._parameter_set_hash = self._user_selection_hash(user_selections)
136-
self._domain_map = DOMAIN_MAP
147+
self._domain_map = PROMPT_DOMAIN_MAP
137148
self._file_name = user_selections["filename"]
138149
self._file_path = user_selections["filepath"]
139150
self._output_path_root = os.path.join(
@@ -285,16 +296,24 @@ def __init__(
285296
)
286297
self._index = VectorStoreIndex(nodes=nodes)
287298

288-
retriever = VectorIndexRetriever(
289-
index=self._index, similarity_top_k=self._similarity_top_k * 3
299+
base_retriever = VectorIndexRetriever(
300+
index=self._index,
301+
similarity_top_k=self._similarity_top_k * 3,
302+
)
303+
# transform_retriever = TransformRetriever(
304+
# retriever=base_retriever,
305+
# query_transform=CustomQueryTransform(delimiter=DELIMITER),
306+
# )
307+
llm_prompt_template = PromptTemplate(template=LLM_PROMPT_TEMPLATE)
308+
response_synthesizer = get_response_synthesizer(
309+
text_qa_template=llm_prompt_template
290310
)
291-
response_synthesizer = get_response_synthesizer()
292311
rerank_postprocessor = SentenceTransformerRerank(
293312
top_n=self._similarity_top_k,
294313
keep_retrieval_score=True,
295314
)
296315
self._query_engine = RetrieverQueryEngine(
297-
retriever=retriever,
316+
retriever=base_retriever,
298317
response_synthesizer=response_synthesizer,
299318
node_postprocessors=[rerank_postprocessor],
300319
)
@@ -322,16 +341,28 @@ def perform_query(self, domain: DomainKey) -> str:
322341
The generated domain.
323342
"""
324343
query_start_time = time.time()
325-
domain_prompt = self._domain_map[domain]["prompt"]
344+
domain_retrieval_prompt = self._domain_map[domain]["retrieval_prompt"]
345+
domain_llm_prompt = self._domain_map[domain]["llm_prompt"]
346+
326347
for dependency in self._domain_map[domain]["dependencies"]:
327348
if self.domain_content[dependency] is not None:
328349
dependency_prompt = f"The {domain} domain is dependent on the {dependency} domain. Here is the {dependency} domain: {self.domain_content[dependency]}."
329-
domain_prompt += dependency_prompt
330-
query_prompt = QUERY_PROMPT.format(domain, domain_prompt)
350+
domain_llm_prompt += dependency_prompt
351+
352+
# full_prompt = f"{RETRIEVAL_PROMPT.format(domain, domain_retrieval_prompt)} {DELIMITER} {LLM_PROMPT.format(domain, domain_llm_prompt)}"
353+
llm_prompt = f"{LLM_PROMPT.format(domain, domain_llm_prompt)}"
331354
if self._domain_map[domain]["top_level"]:
332-
query_prompt += f"\n{SUPPLEMENT_PROMPT}"
355+
llm_prompt += f"\n{SUPPLEMENT_PROMPT}"
356+
query_bundle = QueryBundle(
357+
query_str=llm_prompt,
358+
custom_embedding_strs=[
359+
f"{RETRIEVAL_PROMPT.format(domain, domain_retrieval_prompt)}"
360+
],
361+
embedding=None,
362+
)
363+
364+
response_object = self._query_engine.query(query_bundle)
333365

334-
response_object = self._query_engine.query(query_prompt)
335366
if isinstance(response_object, Response):
336367
response_object = Response(
337368
response=response_object.response,
@@ -369,7 +400,14 @@ def perform_query(self, domain: DomainKey) -> str:
369400
source_str += "\n"
370401

371402
if self._debug:
372-
self._display_info(query_prompt, f"QUERY PROMPT for the {domain} domain:")
403+
self._display_info(
404+
query_bundle.query_str, f"LLM PROMPT for the {domain} domain:"
405+
)
406+
if query_bundle.custom_embedding_strs is not None:
407+
self._display_info(
408+
query_bundle.custom_embedding_strs[0],
409+
f"RETRIEVAL PROMPT for the {domain} domain:",
410+
)
373411
self._token_counts["input"] += self._token_counter.prompt_llm_token_count # type: ignore
374412
self._token_counts["output"] += self._token_counter.completion_llm_token_count # type: ignore
375413
self._token_counts["total"] += self._token_counter.total_llm_token_count # type: ignore

bcorag/custom_types/core_types.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,8 +350,10 @@ class IndividualDomainMapEntry(TypedDict):
350350
351351
Attributes
352352
----------
353-
prompt : str
354-
The prompt to use for querying the RAG pipeline for a specific domain generation.
353+
retrieval_prompt : str
354+
The prompt to use for the RAG pipeline retrieval process.
355+
llm_prompt : str
356+
The prompt to use for the LLM.
355357
top_level : bool
356358
Whether the specified domain includes objects defined in the top-level JSON schema.
357359
user_prompt : str
@@ -362,7 +364,8 @@ class IndividualDomainMapEntry(TypedDict):
362364
The domain dependencies.
363365
"""
364366

365-
prompt: str
367+
retrieval_prompt: str
368+
llm_prompt: str
366369
top_level: bool
367370
user_prompt: str
368371
code: str

bcorag/prompts/__init__.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from bcorag.custom_types.core_types import DomainMap
2+
from .retrieval import (
3+
RETRIEVAL_PROMPT,
4+
USABILITY_DOMAIN_RETRIEVAL,
5+
IO_DOMAIN_RETRIEVAL,
6+
DESCRIPTION_DOMAIN_RETRIEVAL,
7+
EXECUTION_DOMAIN_RETRIEVAL,
8+
PARAMETRIC_DOMAIN_RETRIEVAL,
9+
ERROR_DOMAIN_RETRIEVAL,
10+
)
11+
from .llm_prompts import (
12+
LLM_PROMPT,
13+
USABILITY_DOMAIN_LLM,
14+
IO_DOMAIN_LLM,
15+
DESCRIPTION_DOMAIN_LLM,
16+
EXECUTION_DOMAIN_LLM,
17+
PARAMETRIC_DOMAIN_LLM,
18+
ERROR_DOMAIN_LLM,
19+
SUPPLEMENT_PROMPT,
20+
)
21+
22+
LLM_PROMPT_TEMPLATE = """
23+
Below is some excerpts from a bioinformatics project. The information is from the project's publication and could also contain some information from the project's code repository.
24+
25+
{context_str}
26+
27+
---------\n
28+
29+
{query_str}
30+
"""
31+
32+
# Maps each BCO domain to its configuration:
#   retrieval_prompt - prompt used for the RAG retrieval step
#   llm_prompt       - prompt sent to the LLM for generation
#   top_level        - whether the domain includes objects defined in the
#                      top-level JSON schema (triggers the supplement prompt)
#   user_prompt      - menu text shown to the user
#   code             - shortcut code the user types to select the domain
#   dependencies     - other domains whose generated content this one needs
PROMPT_DOMAIN_MAP: DomainMap = {
    "usability": {
        "retrieval_prompt": USABILITY_DOMAIN_RETRIEVAL,
        "llm_prompt": USABILITY_DOMAIN_LLM,
        "top_level": False,
        "user_prompt": "[u]sability",
        "code": "u",
        "dependencies": [],
    },
    "io": {
        "retrieval_prompt": IO_DOMAIN_RETRIEVAL,
        "llm_prompt": IO_DOMAIN_LLM,
        "top_level": True,
        "user_prompt": "[i]o",
        "code": "i",
        "dependencies": [],
    },
    "description": {
        "retrieval_prompt": DESCRIPTION_DOMAIN_RETRIEVAL,
        "llm_prompt": DESCRIPTION_DOMAIN_LLM,
        "top_level": True,
        "user_prompt": "[d]escription",
        "code": "d",
        "dependencies": [],
    },
    "execution": {
        "retrieval_prompt": EXECUTION_DOMAIN_RETRIEVAL,
        "llm_prompt": EXECUTION_DOMAIN_LLM,
        "top_level": True,
        "user_prompt": "[e]xecution",
        "code": "e",
        "dependencies": [],
    },
    # parametric depends on the description domain's generated content.
    "parametric": {
        "retrieval_prompt": PARAMETRIC_DOMAIN_RETRIEVAL,
        "llm_prompt": PARAMETRIC_DOMAIN_LLM,
        "top_level": False,
        "user_prompt": "[p]arametric",
        "code": "p",
        "dependencies": ["description"],
    },
    "error": {
        "retrieval_prompt": ERROR_DOMAIN_RETRIEVAL,
        "llm_prompt": ERROR_DOMAIN_LLM,
        "top_level": False,
        "user_prompt": "[err]or",
        "code": "err",
        "dependencies": [],
    },
}

0 commit comments

Comments
 (0)