Merged

Commits (changes shown from 24 of 28):
8c4380d  context provider (michaelchia, Sep 12, 2024)
6e79d8b  split base and base command context providers + replacing prompt (michaelchia, Sep 12, 2024)
e9c394b  comment (michaelchia, Sep 12, 2024)
67cfcf4  only replace prompt if context variable in template (michaelchia, Sep 12, 2024)
4cde53a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 12, 2024)
9746f66  Run mypy on CI, fix or ignore typing issues (#987) (krassowski, Sep 12, 2024)
fb97764  context provider (michaelchia, Sep 12, 2024)
b54d844  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 12, 2024)
c382279  mypy (michaelchia, Sep 12, 2024)
2ee8456  black (michaelchia, Sep 12, 2024)
2514beb  modify backtick logic (michaelchia, Sep 13, 2024)
29987e7  allow for spaces in filepath (michaelchia, Sep 14, 2024)
bf5060a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 14, 2024)
b7e7142  refactor (michaelchia, Sep 14, 2024)
86150a8  fixes (michaelchia, Sep 14, 2024)
c0ae4d9  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 14, 2024)
d93da8f  fix test (michaelchia, Sep 14, 2024)
139e37a  refactor autocomplete to remove hardcoded '/' and '@' prefix (michaelchia, Sep 17, 2024)
83197a7  modify context prompt template (michaelchia, Sep 19, 2024)
61bab49  refactor (michaelchia, Sep 19, 2024)
f184dd7  docstrings + refactor (michaelchia, Sep 21, 2024)
b4e00c8  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 21, 2024)
f5a20cf  mypy (michaelchia, Sep 21, 2024)
758c187  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 21, 2024)
691cec8  add context providers to help (michaelchia, Sep 22, 2024)
2b29429  remove _examples.py and remove @learned from defaults (michaelchia, Sep 25, 2024)
eb5eeb4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 25, 2024)
ea10442  make find_commands unoverridable (michaelchia, Sep 25, 2024)
28 changes: 23 additions & 5 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -67,11 +67,25 @@
The following is a friendly conversation between you and a human.
""".strip()

CHAT_DEFAULT_TEMPLATE = """Current conversation:
{history}
Human: {input}
CHAT_DEFAULT_TEMPLATE = """
{% if context %}
Context:
{{context}}

{% endif %}
Current conversation:
{{history}}
Human: {{input}}
AI:"""

+HUMAN_MESSAGE_TEMPLATE = """
+{% if context %}
+Context:
+{{context}}
+
+{% endif %}
+{{input}}
+"""

COMPLETION_SYSTEM_PROMPT = """
You are an application built to provide helpful code completion suggestions.
@@ -400,17 +414,21 @@ def get_chat_prompt_template(self) -> PromptTemplate:
                        CHAT_SYSTEM_PROMPT
                    ).format(provider_name=name, local_model_id=self.model_id),
                    MessagesPlaceholder(variable_name="history"),
-                   HumanMessagePromptTemplate.from_template("{input}"),
+                   HumanMessagePromptTemplate.from_template(
+                       HUMAN_MESSAGE_TEMPLATE,
+                       template_format="jinja2",
+                   ),
                ]
            )
        else:
            return PromptTemplate(
-               input_variables=["history", "input"],
+               input_variables=["history", "input", "context"],
                template=CHAT_SYSTEM_PROMPT.format(
                    provider_name=name, local_model_id=self.model_id
                )
                + "\n\n"
                + CHAT_DEFAULT_TEMPLATE,
+               template_format="jinja2",
            )

    def get_completion_prompt_template(self) -> PromptTemplate:
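Note on the templating change above: switching the prompt templates from f-string to jinja2 is what makes the context block optional; {% if context %} drops the "Context:" section entirely when no provider contributes anything. A minimal standalone sketch of the rendering behavior (not part of the diff; it assumes langchain-core and jinja2 are installed):

from langchain_core.prompts import PromptTemplate

template = PromptTemplate(
    input_variables=["history", "input", "context"],
    template=(
        "{% if context %}Context:\n{{context}}\n\n{% endif %}"
        "Current conversation:\n{{history}}\nHuman: {{input}}\nAI:"
    ),
    template_format="jinja2",
)

# With context provided, the "Context:" block is rendered.
print(template.format(history="", input="Summarize foo.py", context="<contents of foo.py>"))

# With empty context, the whole block is omitted.
print(template.format(history="", input="Hello!", context=""))
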
7 changes: 7 additions & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -33,6 +33,7 @@
from langchain.pydantic_v1 import BaseModel

if TYPE_CHECKING:
+   from jupyter_ai.context_providers import BaseCommandContextProvider
    from jupyter_ai.handlers import RootChatHandler
    from jupyter_ai.history import BoundedChatHistory
    from langchain_core.chat_history import BaseChatMessageHistory
@@ -121,6 +122,10 @@ class BaseChatHandler:
    chat handlers, which is necessary for some use-cases like printing the help
    message."""

+   context_providers: Dict[str, "BaseCommandContextProvider"]
+   """Dictionary of context providers. Allows chat handlers to reference
+   context providers, which can be used to provide context to the LLM."""
+
    def __init__(
        self,
        log: Logger,
@@ -134,6 +139,7 @@ def __init__(
        dask_client_future: Awaitable[DaskClient],
        help_message_template: str,
        chat_handlers: Dict[str, "BaseChatHandler"],
+       context_providers: Dict[str, "BaseCommandContextProvider"],
    ):
        self.log = log
        self.config_manager = config_manager
@@ -154,6 +160,7 @@ def __init__(
        self.dask_client_future = dask_client_future
        self.help_message_template = help_message_template
        self.chat_handlers = chat_handlers
+       self.context_providers = context_providers

        self.llm: Optional[BaseProvider] = None
        self.llm_params: Optional[dict] = None
34 changes: 33 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,3 +1,4 @@
+import asyncio
import time
from typing import Dict, Type
from uuid import uuid4
@@ -12,6 +13,7 @@
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_core.runnables.history import RunnableWithMessageHistory

+from ..context_providers import ContextProviderException
from ..models import HumanChatMessage
from .base import BaseChatHandler, SlashCommandRoutingType

@@ -27,6 +29,7 @@ class DefaultChatHandler(BaseChatHandler):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
+       self.prompt_template = None

    def create_llm_chain(
        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
@@ -40,6 +43,7 @@ def create_llm_chain(

        prompt_template = llm.get_chat_prompt_template()
        self.llm = llm
+       self.prompt_template = prompt_template

        runnable = prompt_template | llm  # type:ignore
        if not llm.manages_history:
@@ -101,14 +105,25 @@ async def process_message(self, message: HumanChatMessage):
        self.get_llm_chain()
        received_first_chunk = False

+       inputs = {"input": message.body}
+       if "context" in self.prompt_template.input_variables:
+           # include context from context providers.
+           try:
+               context_prompt = await self.make_context_prompt(message)
+           except ContextProviderException as e:
+               self.reply(str(e), message)
+               return
+           inputs["context"] = context_prompt
+           inputs["input"] = self.replace_prompt(inputs["input"])
+
        # start with a pending message
        with self.pending("Generating response", message) as pending_message:
            # stream response in chunks. this works even if a provider does not
            # implement streaming, as `astream()` defaults to yielding `_call()`
            # when `_stream()` is not implemented on the LLM class.
            assert self.llm_chain
            async for chunk in self.llm_chain.astream(
-               {"input": message.body},
+               inputs,
                config={"configurable": {"last_human_msg": message}},
            ):
                if not received_first_chunk:
@@ -128,3 +143,20 @@

        # complete stream after all chunks have been streamed
        self._send_stream_chunk(stream_id, "", complete=True)
+
+   async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
+       return "\n\n".join(
+           await asyncio.gather(
+               *[
+                   provider.make_context_prompt(human_msg)
+                   for provider in self.context_providers.values()
+               ]
+           )
+       )
+
+   def replace_prompt(self, prompt: str) -> str:
+       # modifies prompt by the context providers.
+       # some providers may modify or remove their '@' commands from the prompt.
+       for provider in self.context_providers.values():
+           prompt = provider.replace_prompt(prompt)
+       return prompt
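To illustrate the replace_prompt pass above: each provider may rewrite or strip its own '@' commands before the prompt reaches the LLM. The snippet below is a hedged sketch of one plausible rewrite; the regex and the quoting behavior are invented for illustration and are not the actual FileContextProvider logic:

import re

def replace_file_commands(prompt: str) -> str:
    # Hypothetical rewrite: replace "@file:<path>" with a quoted path,
    # since the file contents are injected separately via the context block.
    return re.sub(r"@file:(\S+)", r"'\1'", prompt)

print(replace_file_commands("What does @file:src/app.py do?"))
# -> "What does 'src/app.py' do?"
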
3 changes: 3 additions & 0 deletions packages/jupyter-ai/jupyter_ai/context_providers/__init__.py
@@ -0,0 +1,3 @@
from .base import BaseCommandContextProvider, ContextCommand, ContextProviderException
from .file import FileContextProvider
from .learned import LearnedContextProvider
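Because BaseCommandContextProvider and ContextCommand are exported here, a third-party extension can supply its own '@' command. A minimal sketch, modeled on the _make_context_prompt hook used in _examples.py below; the provider id and payload are invented for illustration:

from datetime import datetime
from typing import List

from jupyter_ai.context_providers import BaseCommandContextProvider, ContextCommand
from jupyter_ai.models import HumanChatMessage


class DatetimeContextProvider(BaseCommandContextProvider):
    """Hypothetical provider invoked via '@datetime' in the chat."""

    id = "datetime"
    description = "Include the current date and time"

    async def _make_context_prompt(
        self, message: HumanChatMessage, commands: List[ContextCommand]
    ) -> str:
        return f"Current datetime: {datetime.now().isoformat()}"
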
133 changes: 133 additions & 0 deletions packages/jupyter-ai/jupyter_ai/context_providers/_examples.py
@@ -0,0 +1,133 @@
# This file is for illustrative purposes
# It is to be deleted before merging
from typing import List

from jupyter_ai.models import HumanChatMessage
from langchain_community.retrievers import ArxivRetriever, WikipediaRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

from .base import BaseCommandContextProvider, ContextCommand, _BaseContextProvider

# Examples of the ease of implementing retriever based context providers
ARXIV_TEMPLATE = """
Title: {title}
Publish Date: {publish_date}
'''
{content}
'''
""".strip()


class ArxivContextProvider(BaseCommandContextProvider):
    id = "arxiv"
    description = "Include papers from Arxiv"
    remove_from_prompt = True
    header = "Following are snippets of research papers:"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.retriever = ArxivRetriever()

    async def _make_context_prompt(
        self, message: HumanChatMessage, commands: List[ContextCommand]
    ) -> str:
        query = self._clean_prompt(message.body)
        docs = await self.retriever.ainvoke(query)
        context = "\n\n".join(
            [
                ARXIV_TEMPLATE.format(
                    content=d.page_content,
                    title=d.metadata["Title"],
                    publish_date=d.metadata["Published"],
                )
                for d in docs
            ]
        )
        return self.header + "\n" + context


# Another retriever based context provider, with a rewrite step using an LLM
WIKI_TEMPLATE = """
Title: {title}
'''
{content}
'''
""".strip()

REWRITE_TEMPLATE = """Provide a better search query for \
web search engine to answer the given question, end \
the queries with ’**’. Question: \
{x} Answer:"""


class WikiContextProvider(BaseCommandContextProvider):
    id = "wiki"
    description = "Include knowledge from Wikipedia"
    remove_from_prompt = True
    header = "Following is information from Wikipedia:"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.retriever = WikipediaRetriever()

    async def _make_context_prompt(
        self, message: HumanChatMessage, commands: List[ContextCommand]
    ) -> str:
        prompt = self._clean_prompt(message.body)
        search_query = await self._rewrite_prompt(prompt)
        docs = await self.retriever.ainvoke(search_query)
        context = "\n\n".join(
            [
                WIKI_TEMPLATE.format(
                    content=d.page_content,
                    title=d.metadata["title"],
                )
                for d in docs
            ]
        )
        return self.header + "\n" + context

    async def _rewrite_prompt(self, prompt: str) -> str:
        return await self.get_llm_chain().ainvoke(prompt)

    def get_llm_chain(self):
        # from https://github.com/langchain-ai/langchain/blob/master/cookbook/rewrite.ipynb
        llm = self.get_llm()
        rewrite_prompt = ChatPromptTemplate.from_template(REWRITE_TEMPLATE)

        def _parse(text):
            return text.strip('"').strip("**")

        return rewrite_prompt | llm | StrOutputParser() | _parse


# Partial example of a non-command context provider for errors.
# Assuming there is an option in the UI to add cell errors to messages,
# the default chat will automatically invoke this context provider to add
# solutions for the error, retrieved from a custom error database or from a
# stackoverflow / google retriever pipeline.
class ErrorContextProvider(_BaseContextProvider):
    id = "error"
    description = "Include custom error context"
    remove_from_prompt = True
    header = "Following are potential solutions for the error:"
    is_command = False  # will not show up in autocomplete

    async def make_context_prompt(self, message: HumanChatMessage) -> str:
        # Runs for every message with a cell error, since it does not
        # use _find_instances to check for the presence of a command in
        # the message.
        if not (message.selection and message.selection.type == "cell-with-error"):
            return ""
        docs = await self.solution_retriever.ainvoke(message.selection)
        if not docs:
            return ""
        context = "\n\n".join([d.page_content for d in docs])
        return self.header + "\n" + context

    @property
    def solution_retriever(self):
        # Retriever that takes an error and returns solutions from a database
        # of error messages.
        raise NotImplementedError("Error retriever not implemented")