Merged

28 commits
8c4380d
context provider
michaelchia Sep 12, 2024
6e79d8b
split base and base command context providers + replacing prompt
michaelchia Sep 12, 2024
e9c394b
comment
michaelchia Sep 12, 2024
67cfcf4
only replace prompt if context variable in template
michaelchia Sep 12, 2024
4cde53a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
9746f66
Run mypy on CI, fix or ignore typing issues (#987)
krassowski Sep 12, 2024
fb97764
context provider
michaelchia Sep 12, 2024
b54d844
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
c382279
mypy
michaelchia Sep 12, 2024
2ee8456
black
michaelchia Sep 12, 2024
2514beb
modify backtick logic
michaelchia Sep 13, 2024
29987e7
allow for spaces in filepath
michaelchia Sep 14, 2024
bf5060a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
b7e7142
refactor
michaelchia Sep 14, 2024
86150a8
fixes
michaelchia Sep 14, 2024
c0ae4d9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
d93da8f
fix test
michaelchia Sep 14, 2024
139e37a
refactor autocomplete to remove hardcoded '/' and '@' prefix
michaelchia Sep 17, 2024
83197a7
modify context prompt template
michaelchia Sep 19, 2024
61bab49
refactor
michaelchia Sep 19, 2024
f184dd7
docstrings + refactor
michaelchia Sep 21, 2024
b4e00c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2024
f5a20cf
mypy
michaelchia Sep 21, 2024
758c187
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2024
691cec8
add context providers to help
michaelchia Sep 22, 2024
2b29429
remove _examples.py and remove @learned from defaults
michaelchia Sep 25, 2024
eb5eeb4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 25, 2024
ea10442
make find_commands unoverridable
michaelchia Sep 25, 2024
28 changes: 23 additions & 5 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -67,11 +67,25 @@
The following is a friendly conversation between you and a human.
""".strip()

CHAT_DEFAULT_TEMPLATE = """Current conversation:
{history}
Human: {input}
CHAT_DEFAULT_TEMPLATE = """
{% if context %}
Context:
{{context}}

{% endif %}
Current conversation:
{{history}}
Human: {{input}}
AI:"""

HUMAN_MESSAGE_TEMPLATE = """
{% if context %}
Context:
{{context}}

{% endif %}
{{input}}
"""

COMPLETION_SYSTEM_PROMPT = """
You are an application built to provide helpful code completion suggestions.
@@ -400,17 +414,21 @@ def get_chat_prompt_template(self) -> PromptTemplate:
CHAT_SYSTEM_PROMPT
).format(provider_name=name, local_model_id=self.model_id),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template("{input}"),
HumanMessagePromptTemplate.from_template(
HUMAN_MESSAGE_TEMPLATE,
template_format="jinja2",
),
]
)
else:
return PromptTemplate(
input_variables=["history", "input"],
input_variables=["history", "input", "context"],
template=CHAT_SYSTEM_PROMPT.format(
provider_name=name, local_model_id=self.model_id
)
+ "\n\n"
+ CHAT_DEFAULT_TEMPLATE,
template_format="jinja2",
)

def get_completion_prompt_template(self) -> PromptTemplate:
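To make the template change above concrete, here is a minimal, editor-added sketch (not part of the PR) of how the `{% if context %}` block behaves once `template_format="jinja2"` is in effect. It renders the new CHAT_DEFAULT_TEMPLATE with plain jinja2 rather than through LangChain, an assumption made purely for illustration.

```python
# Illustration only; assumes the jinja2 package is installed.
from jinja2 import Template

CHAT_DEFAULT_TEMPLATE = """
{% if context %}
Context:
{{context}}

{% endif %}
Current conversation:
{{history}}
Human: {{input}}
AI:"""

template = Template(CHAT_DEFAULT_TEMPLATE)

# When a context provider supplies text, the "Context:" block is rendered.
print(template.render(context="Contents of foo.py: ...", history="", input="What does foo.py do?"))

# When no context is supplied, the whole block is skipped.
print(template.render(context="", history="Human: hi\nAI: hello", input="Thanks!"))
```

This is also why `input_variables` now lists `context` alongside `history` and `input`: the variable must exist even when empty, and the Jinja2 conditional decides whether the Context section appears in the final prompt.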
18 changes: 17 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -33,6 +33,7 @@
from langchain.pydantic_v1 import BaseModel

if TYPE_CHECKING:
from jupyter_ai.context_providers import BaseCommandContextProvider
from jupyter_ai.handlers import RootChatHandler
from jupyter_ai.history import BoundedChatHistory
from langchain_core.chat_history import BaseChatMessageHistory
@@ -121,6 +122,10 @@ class BaseChatHandler:
chat handlers, which is necessary for some use-cases like printing the help
message."""

context_providers: Dict[str, "BaseCommandContextProvider"]
"""Dictionary of context providers. Allows chat handlers to reference
context providers, which can be used to provide context to the LLM."""

def __init__(
self,
log: Logger,
@@ -134,6 +139,7 @@ def __init__(
dask_client_future: Awaitable[DaskClient],
help_message_template: str,
chat_handlers: Dict[str, "BaseChatHandler"],
context_providers: Dict[str, "BaseCommandContextProvider"],
):
self.log = log
self.config_manager = config_manager
@@ -154,6 +160,7 @@
self.dask_client_future = dask_client_future
self.help_message_template = help_message_template
self.chat_handlers = chat_handlers
self.context_providers = context_providers

self.llm: Optional[BaseProvider] = None
self.llm_params: Optional[dict] = None
@@ -430,8 +437,17 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non
]
)

context_commands_list = "\n".join(
[
f"* `{cp.command_id}` — {cp.help}"
for cp in self.context_providers.values()
]
)

help_message_body = self.help_message_template.format(
persona_name=self.persona.name, slash_commands_list=slash_commands_list
persona_name=self.persona.name,
slash_commands_list=slash_commands_list,
context_commands_list=context_commands_list,
)
help_message = AgentChatMessage(
id=uuid4().hex,
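For orientation, the snippet below is an editor-added sketch of how the new `context_commands_list` flows into `help_message_template.format()`. The template text, the `@file` description, and the persona name are invented for the example; only the list-building expression is copied from the diff above.

```python
# Sketch only; SimpleNamespace stands in for real BaseCommandContextProvider objects.
from types import SimpleNamespace

help_message_template = """Hi, I am {persona_name}.

Slash commands:
{slash_commands_list}

Context commands:
{context_commands_list}"""

context_providers = {
    "file": SimpleNamespace(command_id="@file", help="Include the contents of a file"),
}

context_commands_list = "\n".join(
    f"* `{cp.command_id}` — {cp.help}" for cp in context_providers.values()
)

print(
    help_message_template.format(
        persona_name="Jupyternaut",
        slash_commands_list="* `/learn` — Teach the assistant about files",
        context_commands_list=context_commands_list,
    )
)
```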
35 changes: 34 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,3 +1,4 @@
import asyncio
import time
from typing import Dict, Type
from uuid import uuid4
@@ -12,6 +13,7 @@
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_core.runnables.history import RunnableWithMessageHistory

from ..context_providers import ContextProviderException, find_commands
from ..models import HumanChatMessage
from .base import BaseChatHandler, SlashCommandRoutingType

@@ -27,6 +29,7 @@ class DefaultChatHandler(BaseChatHandler):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt_template = None

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
@@ -40,6 +43,7 @@ def create_llm_chain(

prompt_template = llm.get_chat_prompt_template()
self.llm = llm
self.prompt_template = prompt_template

runnable = prompt_template | llm # type:ignore
if not llm.manages_history:
@@ -101,14 +105,25 @@ async def process_message(self, message: HumanChatMessage):
self.get_llm_chain()
received_first_chunk = False

inputs = {"input": message.body}
if "context" in self.prompt_template.input_variables:
# include context from context providers.
try:
context_prompt = await self.make_context_prompt(message)
except ContextProviderException as e:
self.reply(str(e), message)
return
inputs["context"] = context_prompt
inputs["input"] = self.replace_prompt(inputs["input"])

# start with a pending message
with self.pending("Generating response", message) as pending_message:
# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
assert self.llm_chain
async for chunk in self.llm_chain.astream(
{"input": message.body},
inputs,
config={"configurable": {"last_human_msg": message}},
):
if not received_first_chunk:
@@ -128,3 +143,21 @@ async def process_message(self, message: HumanChatMessage):

# complete stream after all chunks have been streamed
self._send_stream_chunk(stream_id, "", complete=True)

async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
return "\n\n".join(
await asyncio.gather(
*[
provider.make_context_prompt(human_msg)
for provider in self.context_providers.values()
if find_commands(provider, human_msg.prompt)
]
)
)

def replace_prompt(self, prompt: str) -> str:
# modifies prompt by the context providers.
# some providers may modify or remove their '@' commands from the prompt.
for provider in self.context_providers.values():
prompt = provider.replace_prompt(prompt)
return prompt
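To show how these pieces fit together, here is an editor-added, hypothetical provider based only on the attributes and the `_make_context_prompt` signature visible in this diff (see `_learned.py` below). The actual `BaseCommandContextProvider` base class lives in `jupyter_ai/context_providers/base.py`, which is not included in this excerpt, so the exact API should be treated as an assumption.

```python
# Hypothetical example, not part of this PR; names and behavior are illustrative.
import sys
from typing import List

from jupyter_ai.context_providers import BaseCommandContextProvider, ContextCommand
from jupyter_ai.models import HumanChatMessage


class EnvContextProvider(BaseCommandContextProvider):
    """Would let users type `@env` to include basic interpreter details."""

    id = "env"
    help = "Include basic Python environment information"

    async def _make_context_prompt(
        self, message: HumanChatMessage, commands: List[ContextCommand]
    ) -> str:
        # A real provider would tailor its output to the commands found in the prompt.
        return f"Python version: {sys.version.split()[0]}"
```

Assuming such a provider were registered with the extension, a message containing `@env` would cause `make_context_prompt` above to gather its output and pass it to the LLM through the `context` variable of the chat prompt template.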
7 changes: 7 additions & 0 deletions packages/jupyter-ai/jupyter_ai/context_providers/__init__.py
@@ -0,0 +1,7 @@
from .base import (
BaseCommandContextProvider,
ContextCommand,
ContextProviderException,
find_commands,
)
from .file import FileContextProvider
53 changes: 53 additions & 0 deletions packages/jupyter-ai/jupyter_ai/context_providers/_learned.py
@@ -0,0 +1,53 @@
# Currently unused as it is duplicating the functionality of the /ask command.
# TODO: Rename "learned" to something better.
from typing import List

from jupyter_ai.chat_handlers.learn import Retriever
from jupyter_ai.models import HumanChatMessage

from .base import BaseCommandContextProvider, ContextCommand
from .file import FileContextProvider

FILE_CHUNK_TEMPLATE = """
Snippet from file: {filepath}
```
{content}
```
""".strip()


class LearnedContextProvider(BaseCommandContextProvider):
id = "learned"
help = "Include content indexed from `/learn`"
remove_from_prompt = True
header = "Following are snippets from potentially relevant files:"

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.retriever = Retriever(learn_chat_handler=self.chat_handlers["/learn"])

async def _make_context_prompt(
self, message: HumanChatMessage, commands: List[ContextCommand]
) -> str:
if not self.retriever:
return ""
query = self._clean_prompt(message.body)
docs = await self.retriever.ainvoke(query)
excluded = self._get_repeated_files(message)
context = "\n\n".join(
[
FILE_CHUNK_TEMPLATE.format(
filepath=d.metadata["path"], content=d.page_content
)
for d in docs
if d.metadata["path"] not in excluded and d.page_content
]
)
return self.header + "\n" + context

def _get_repeated_files(self, message: HumanChatMessage) -> List[str]:
# don't include files that are already provided by the file context provider
file_context_provider = self.context_providers.get("file")
if isinstance(file_context_provider, FileContextProvider):
return file_context_provider.get_filepaths(message)
return []