Skip to content

Commit b6994da

Browse files
authored
[Refactor] Rename core to agent (#55)
* rename core to agent * add templates module and templates attribute to TaskHandler * changelog
1 parent 3d52224 commit b6994da

File tree

8 files changed

+87
-54
lines changed

8 files changed

+87
-54
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
1010

1111
### Changed
1212

13+
- Rename `llm_agents_from_scratch.core` to `llm_agents_from_scratch.agent` (#55)
1314
- Revised `TaskHandler.get_next_step()` to return `TaskStep | TaskResult` (#54)
1415
- Fixed bug in `OllamaLLM.chat()` where chat history was coming after user message (#51)
1516
- Fixed bug in `TaskHandler.run_step()` where tool names were passed to `llm.chat()` (#46)
1617

1718
### Added
1819

20+
- Add `~agent.templates` module and add `TaskHandler.templates` attribute (#55)
1921
- Add `enable_console_logging` and `disable_console_logging` to not stream logs as a library by default (#54)
2022
- Add first working cookbook for a simple `LLMAgent` and task (#54)
2123
- Add `data_structures.task_handler.GetNextStep` (#54)

src/llm_agents_from_scratch/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
# Disable the F403 warning for wildcard imports
66
# ruff: noqa: F403, F401
7-
from .core import *
8-
from .core import __all__ as _core_all
7+
from .agent import *
8+
from .agent import __all__ as _agent_all
99
from .tools import *
1010
from .tools import __all__ as _tool_all
1111

1212
__version__ = VERSION
1313

1414

15-
__all__ = sorted(_core_all + _tool_all) # noqa: PLE0605
15+
__all__ = sorted(_agent_all + _tool_all) # noqa: PLE0605
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .agent import LLMAgent
1+
from .core import LLMAgent
22
from .task_handler import TaskHandler
33

44
__all__ = ["LLMAgent", "TaskHandler"]

src/llm_agents_from_scratch/core/task_handler.py renamed to src/llm_agents_from_scratch/agent/task_handler.py

Lines changed: 10 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,43 +18,7 @@
1818
from llm_agents_from_scratch.errors import TaskHandlerError
1919
from llm_agents_from_scratch.logger import get_logger
2020

21-
DEFAULT_GET_NEXT_INSTRUCTION_PROMPT = """You are overseeing an assistant's
22-
progress in accomplishing a user instruction. Provided below is the assistant's
23-
current response to the original task instruction. Also provided, is an
24-
internal 'thinking' process of the assistant that the user has not seen.
25-
26-
Determine if the current the response is sufficient to answer the original task
27-
instruction. In the case that it is not, provide a new instruction to the
28-
assistant in order to help them improve upon their current response.
29-
30-
<user-instruction>
31-
{instruction}
32-
</user-instruction>
33-
34-
<current-response>
35-
{current_response}
36-
</current-response>
37-
38-
<thinking-process>
39-
{current_rollout}
40-
</thinking-process>
41-
"""
42-
43-
DEFAULT_USER_MESSAGE = "{instruction}"
44-
45-
DEFAULT_ROLLOUT_BLOCK_FROM_CHAT_MESSAGE = "{role}: {content}"
46-
47-
DEFAULT_SYSTEM_MESSAGE_WITHOUT_ROLLOUT = """You are a helpful assistant."""
48-
49-
DEFAULT_SYSTEM_MESSAGE = """You are a helpful assistant.
50-
51-
Here is some past dialogue and context, where another assistant was working
52-
towards completing the task.
53-
54-
<history>
55-
{current_rollout}
56-
</history>
57-
"""
21+
from .templates import TaskHandlerTemplates, default_task_handler_templates
5822

5923

6024
class TaskHandler(asyncio.Future):
@@ -64,6 +28,7 @@ class TaskHandler(asyncio.Future):
6428
task: The task to execute.
6529
llm: The backbone LLM.
6630
tools_registry: The tools the LLM agent can use represented as a dict.
31+
templates: Associated prompt templates.
6732
rollout: The execution log of the task.
6833
logger: TaskHandler logger.
6934
"""
@@ -73,6 +38,7 @@ def __init__(
7338
task: Task,
7439
llm: BaseLLM,
7540
tools: list[BaseTool | AsyncBaseTool],
41+
templates: TaskHandlerTemplates = default_task_handler_templates,
7642
*args: Any,
7743
**kwargs: Any,
7844
) -> None:
@@ -82,7 +48,7 @@ def __init__(
8248
task (Task): The task to process.
8349
llm (BaseLLM): The backbone LLM.
8450
tools (list[BaseTool]): The tools the LLM can use.
85-
logger
51+
templates (TaskHandlerTemplates): Associated prompt templates.
8652
*args: Additional positional arguments.
8753
**kwargs: Additional keyword arguments.
8854
"""
@@ -91,6 +57,7 @@ def __init__(
9157
self.llm = llm
9258
self.tools_registry = {t.name: t for t in tools}
9359
self.rollout = ""
60+
self.templates = templates
9461
self._background_task: asyncio.Task | None = None
9562
self._lock: asyncio.Lock = asyncio.Lock()
9663
self.logger = get_logger(self.__class__.__name__)
@@ -135,7 +102,7 @@ def _rollout_contribution_from_single_run_step(
135102
)
136103

137104
rollout_contributions.append(
138-
DEFAULT_ROLLOUT_BLOCK_FROM_CHAT_MESSAGE.format(
105+
self.templates["rollout_block_from_chat_message"].format(
139106
role=role.value,
140107
content=content,
141108
),
@@ -161,7 +128,7 @@ async def get_next_step(
161128
instruction=self.task.instruction,
162129
last_step=False,
163130
)
164-
prompt = DEFAULT_GET_NEXT_INSTRUCTION_PROMPT.format(
131+
prompt = self.templates["get_next_step"].format(
165132
instruction=self.task.instruction,
166133
current_rollout=rollout,
167134
current_response=previous_step_result.content,
@@ -217,17 +184,17 @@ async def run_step(self, step: TaskStep) -> TaskStepResult:
217184
# include rollout as context in the system message
218185
system_message = ChatMessage(
219186
role=ChatRole.SYSTEM,
220-
content=DEFAULT_SYSTEM_MESSAGE.format(
187+
content=self.templates["system_message"].format(
221188
original_instruction=self.task.instruction,
222189
current_rollout=rollout,
223190
)
224191
if rollout
225-
else DEFAULT_SYSTEM_MESSAGE_WITHOUT_ROLLOUT,
192+
else self.templates["system_message_without_rollout"],
226193
)
227194
self.logger.debug(f"💬 SYSTEM: {system_message.content}")
228195
user_message = ChatMessage(
229196
role=ChatRole.USER,
230-
content=DEFAULT_USER_MESSAGE.format(
197+
content=self.templates["user_message"].format(
231198
instruction=step.instruction,
232199
),
233200
)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Prompt templates for LLMAgent (TaskHandler)."""
2+
3+
from typing import TypedDict
4+
5+
DEFAULT_GET_NEXT_INSTRUCTION_PROMPT = """You are overseeing an assistant's
6+
progress in accomplishing a user instruction. Provided below is the assistant's
7+
current response to the original task instruction. Also provided, is an
8+
internal 'thinking' process of the assistant that the user has not seen.
9+
10+
Determine if the current the response is sufficient to answer the original task
11+
instruction. In the case that it is not, provide a new instruction to the
12+
assistant in order to help them improve upon their current response.
13+
14+
<user-instruction>
15+
{instruction}
16+
</user-instruction>
17+
18+
<current-response>
19+
{current_response}
20+
</current-response>
21+
22+
<thinking-process>
23+
{current_rollout}
24+
</thinking-process>
25+
"""
26+
27+
DEFAULT_USER_MESSAGE = "{instruction}"
28+
29+
DEFAULT_ROLLOUT_BLOCK_FROM_CHAT_MESSAGE = "{role}: {content}"
30+
31+
DEFAULT_SYSTEM_MESSAGE_WITHOUT_ROLLOUT = """You are a helpful assistant."""
32+
33+
DEFAULT_SYSTEM_MESSAGE = """You are a helpful assistant.
34+
35+
Here is some past dialogue and context, where another assistant was working
36+
towards completing the task.
37+
38+
<history>
39+
{current_rollout}
40+
</history>
41+
"""
42+
43+
44+
class TaskHandlerTemplates(TypedDict):
45+
"""Prompt templates dict for TaskHandler."""
46+
47+
get_next_step: str
48+
rollout_block_from_chat_message: str
49+
system_message_without_rollout: str
50+
system_message: str
51+
user_message: str
52+
53+
54+
default_task_handler_templates = TaskHandlerTemplates(
55+
get_next_step=DEFAULT_GET_NEXT_INSTRUCTION_PROMPT,
56+
rollout_block_from_chat_message=DEFAULT_ROLLOUT_BLOCK_FROM_CHAT_MESSAGE,
57+
system_message_without_rollout=DEFAULT_SYSTEM_MESSAGE_WITHOUT_ROLLOUT,
58+
system_message=DEFAULT_SYSTEM_MESSAGE,
59+
user_message=DEFAULT_USER_MESSAGE,
60+
)

tests/test_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
import pytest
66

7+
from llm_agents_from_scratch.agent import LLMAgent, TaskHandler
78
from llm_agents_from_scratch.base.llm import BaseLLM
8-
from llm_agents_from_scratch.core import LLMAgent, TaskHandler
99
from llm_agents_from_scratch.data_structures.agent import (
1010
Task,
1111
TaskResult,

tests/test_task_handler.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55
import pytest
66

7-
from llm_agents_from_scratch.base.llm import BaseLLM
8-
from llm_agents_from_scratch.core import TaskHandler
9-
from llm_agents_from_scratch.core.task_handler import (
10-
DEFAULT_SYSTEM_MESSAGE_WITHOUT_ROLLOUT,
7+
from llm_agents_from_scratch.agent import TaskHandler
8+
from llm_agents_from_scratch.agent.templates import (
9+
default_task_handler_templates,
1110
)
11+
from llm_agents_from_scratch.base.llm import BaseLLM
1212
from llm_agents_from_scratch.data_structures import (
1313
ChatMessage,
1414
ChatRole,
@@ -367,7 +367,9 @@ async def plus_two(arg1: int) -> int:
367367
chat_messages=[
368368
ChatMessage(
369369
role=ChatRole.SYSTEM,
370-
content=DEFAULT_SYSTEM_MESSAGE_WITHOUT_ROLLOUT.format(
370+
content=default_task_handler_templates[
371+
"system_message_without_rollout"
372+
].format(
371373
original_instruction="mock instruction",
372374
current_rollout="",
373375
),
@@ -410,7 +412,9 @@ async def test_run_step_without_tool_calls() -> None:
410412
chat_messages=[
411413
ChatMessage(
412414
role=ChatRole.SYSTEM,
413-
content=DEFAULT_SYSTEM_MESSAGE_WITHOUT_ROLLOUT.format(
415+
content=default_task_handler_templates[
416+
"system_message_without_rollout"
417+
].format(
414418
original_instruction="mock instruction",
415419
current_rollout="",
416420
),

0 commit comments

Comments
 (0)