Skip to content

Commit c67e75d

Browse files
authored
Merge pull request #700 from Kiln-AI/leonard/kil-123-fix-tool-interface-changed-broke-ragtool
fix: kiln tool interface run method change
2 parents 938d4da + eba9422 commit c67e75d

File tree

2 files changed

+98
-10
lines changed

2 files changed

+98
-10
lines changed

libs/core/kiln_ai/tools/rag_tools.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import cached_property
2-
from typing import Any, Dict, List
2+
from typing import Any, Dict, List, TypedDict
33

44
from pydantic import BaseModel
55

@@ -18,7 +18,7 @@
1818
from kiln_ai.datamodel.rag import RagConfig
1919
from kiln_ai.datamodel.tool_id import ToolId
2020
from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
21-
from kiln_ai.tools.base_tool import KilnToolInterface
21+
from kiln_ai.tools.base_tool import KilnToolInterface, ToolCallContext
2222
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
2323

2424

@@ -46,6 +46,10 @@ def format_search_results(search_results: List[SearchResult]) -> str:
4646
return "\n=========\n".join([result.serialize() for result in results])
4747

4848

49+
class RagParams(TypedDict):
50+
query: str
51+
52+
4953
class RagTool(KilnToolInterface):
5054
"""
5155
A tool that searches the vector store and returns the most relevant chunks.
@@ -126,7 +130,10 @@ async def toolcall_definition(self) -> Dict[str, Any]:
126130
},
127131
}
128132

129-
async def run(self, query: str) -> str:
133+
async def run(self, context: ToolCallContext | None = None, **kwargs) -> str:
134+
kwargs = RagParams(**kwargs)
135+
query = kwargs["query"]
136+
130137
_, embedding_adapter = self.embedding
131138

132139
vector_store_adapter = await self.vector_store()
@@ -152,6 +159,6 @@ async def run(self, query: str) -> str:
152159
store_query.query_embedding = query_embedding_result.embeddings[0].vector
153160

154161
search_results = await vector_store_adapter.search(store_query)
155-
context = format_search_results(search_results)
162+
search_results_as_text = format_search_results(search_results)
156163

157-
return context
164+
return search_results_as_text

libs/core/kiln_ai/tools/test_rag_tools.py

Lines changed: 86 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from kiln_ai.datamodel.project import Project
99
from kiln_ai.datamodel.rag import RagConfig
1010
from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
11+
from kiln_ai.tools.base_tool import ToolCallContext
1112
from kiln_ai.tools.rag_tools import ChunkContext, RagTool, format_search_results
1213

1314

@@ -420,7 +421,7 @@ async def test_rag_tool_run_vector_store_type(self, mock_rag_config, mock_projec
420421
tool = RagTool("tool_123", mock_rag_config)
421422

422423
# Run the tool
423-
result = await tool.run("test query")
424+
result = await tool.run(context=None, query="test query")
424425

425426
# Verify the result format
426427
expected_result = (
@@ -500,7 +501,7 @@ async def test_rag_tool_run_hybrid_store_type(self, mock_rag_config, mock_projec
500501
tool = RagTool("tool_123", mock_rag_config)
501502

502503
# Run the tool
503-
result = await tool.run("hybrid query")
504+
result = await tool.run(context=None, query="hybrid query")
504505

505506
# Verify embedding generation was called
506507
mock_embedding_adapter.generate_embeddings.assert_called_once_with(
@@ -566,7 +567,7 @@ async def test_rag_tool_run_fts_store_type(self, mock_rag_config, mock_project):
566567
tool = RagTool("tool_123", mock_rag_config)
567568

568569
# Run the tool
569-
result = await tool.run("fts query")
570+
result = await tool.run(context=None, query="fts query")
570571

571572
# Verify the result format
572573
expected_result = (
@@ -629,7 +630,7 @@ async def test_rag_tool_run_no_embeddings_generated(
629630

630631
# Run the tool and expect an error
631632
with pytest.raises(ValueError, match="No embeddings generated"):
632-
await tool.run("query with no embeddings")
633+
await tool.run(context=None, query="query with no embeddings")
633634

634635
async def test_rag_tool_run_empty_search_results(
635636
self, mock_rag_config, mock_project
@@ -675,11 +676,91 @@ async def test_rag_tool_run_empty_search_results(
675676
tool = RagTool("tool_123", mock_rag_config)
676677

677678
# Run the tool
678-
result = await tool.run("query with no results")
679+
result = await tool.run(context=None, query="query with no results")
679680

680681
# Should return empty string for no results
681682
assert result == ""
682683

684+
async def test_rag_tool_run_with_context_is_accepted(
685+
self, mock_rag_config, mock_project
686+
):
687+
"""Ensure RagTool.run accepts and works when a ToolCallContext is provided."""
688+
mock_rag_config.parent_project.return_value = mock_project
689+
690+
# Mock search results
691+
search_results = [
692+
SearchResult(
693+
document_id="doc_ctx",
694+
chunk_idx=3,
695+
chunk_text="Context ok",
696+
similarity=0.77,
697+
)
698+
]
699+
700+
with (
701+
patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as mock_vs_config_class,
702+
patch("kiln_ai.tools.rag_tools.EmbeddingConfig") as mock_embed_config_class,
703+
patch(
704+
"kiln_ai.tools.rag_tools.embedding_adapter_from_type"
705+
) as mock_adapter_factory,
706+
patch(
707+
"kiln_ai.tools.rag_tools.vector_store_adapter_for_config",
708+
new_callable=AsyncMock,
709+
) as mock_vs_adapter_factory,
710+
):
711+
# VECTOR type → embedding path taken
712+
mock_vector_store_config = Mock()
713+
mock_vector_store_config.store_type = VectorStoreType.LANCE_DB_VECTOR
714+
mock_vs_config_class.from_id_and_parent_path.return_value = (
715+
mock_vector_store_config
716+
)
717+
718+
mock_embedding_config = Mock()
719+
mock_embed_config_class.from_id_and_parent_path.return_value = (
720+
mock_embedding_config
721+
)
722+
723+
mock_embedding_adapter = AsyncMock()
724+
mock_embedding_result = Mock()
725+
mock_embedding_result.embeddings = [Mock(vector=[1.0])]
726+
mock_embedding_adapter.generate_embeddings.return_value = (
727+
mock_embedding_result
728+
)
729+
mock_adapter_factory.return_value = mock_embedding_adapter
730+
731+
mock_vector_store_adapter = AsyncMock()
732+
mock_vector_store_adapter.search.return_value = search_results
733+
mock_vs_adapter_factory.return_value = mock_vector_store_adapter
734+
735+
tool = RagTool("tool_ctx", mock_rag_config)
736+
737+
ctx = ToolCallContext(allow_saving=False)
738+
result = await tool.run(context=ctx, query="with context")
739+
740+
# Works and returns formatted text
741+
assert result == "[document_id: doc_ctx, chunk_idx: 3]\nContext ok\n\n"
742+
743+
# Normal behavior still occurs
744+
mock_embedding_adapter.generate_embeddings.assert_called_once_with(
745+
["with context"]
746+
)
747+
mock_vector_store_adapter.search.assert_called_once()
748+
749+
async def test_rag_tool_run_missing_query_raises(
750+
self, mock_rag_config, mock_project
751+
):
752+
"""Ensure RagTool.run enforces the 'if not query' guard."""
753+
mock_rag_config.parent_project.return_value = mock_project
754+
755+
with (
756+
patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as mock_vs_config_class,
757+
):
758+
mock_vs_config_class.from_id_and_parent_path.return_value = Mock()
759+
tool = RagTool("tool_err", mock_rag_config)
760+
761+
with pytest.raises(KeyError, match="query"):
762+
await tool.run(context=None)
763+
683764

684765
class TestRagToolNameAndDescription:
685766
"""Test RagTool name and description functionality with tool_name and tool_description fields."""

0 commit comments

Comments
 (0)