|
8 | 8 | from kiln_ai.datamodel.project import Project |
9 | 9 | from kiln_ai.datamodel.rag import RagConfig |
10 | 10 | from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType |
| 11 | +from kiln_ai.tools.base_tool import ToolCallContext |
11 | 12 | from kiln_ai.tools.rag_tools import ChunkContext, RagTool, format_search_results |
12 | 13 |
|
13 | 14 |
|
@@ -420,7 +421,7 @@ async def test_rag_tool_run_vector_store_type(self, mock_rag_config, mock_projec |
420 | 421 | tool = RagTool("tool_123", mock_rag_config) |
421 | 422 |
|
422 | 423 | # Run the tool |
423 | | - result = await tool.run("test query") |
| 424 | + result = await tool.run(context=None, query="test query") |
424 | 425 |
|
425 | 426 | # Verify the result format |
426 | 427 | expected_result = ( |
@@ -500,7 +501,7 @@ async def test_rag_tool_run_hybrid_store_type(self, mock_rag_config, mock_projec |
500 | 501 | tool = RagTool("tool_123", mock_rag_config) |
501 | 502 |
|
502 | 503 | # Run the tool |
503 | | - result = await tool.run("hybrid query") |
| 504 | + result = await tool.run(context=None, query="hybrid query") |
504 | 505 |
|
505 | 506 | # Verify embedding generation was called |
506 | 507 | mock_embedding_adapter.generate_embeddings.assert_called_once_with( |
@@ -566,7 +567,7 @@ async def test_rag_tool_run_fts_store_type(self, mock_rag_config, mock_project): |
566 | 567 | tool = RagTool("tool_123", mock_rag_config) |
567 | 568 |
|
568 | 569 | # Run the tool |
569 | | - result = await tool.run("fts query") |
| 570 | + result = await tool.run(context=None, query="fts query") |
570 | 571 |
|
571 | 572 | # Verify the result format |
572 | 573 | expected_result = ( |
@@ -629,7 +630,7 @@ async def test_rag_tool_run_no_embeddings_generated( |
629 | 630 |
|
630 | 631 | # Run the tool and expect an error |
631 | 632 | with pytest.raises(ValueError, match="No embeddings generated"): |
632 | | - await tool.run("query with no embeddings") |
| 633 | + await tool.run(context=None, query="query with no embeddings") |
633 | 634 |
|
634 | 635 | async def test_rag_tool_run_empty_search_results( |
635 | 636 | self, mock_rag_config, mock_project |
@@ -675,11 +676,91 @@ async def test_rag_tool_run_empty_search_results( |
675 | 676 | tool = RagTool("tool_123", mock_rag_config) |
676 | 677 |
|
677 | 678 | # Run the tool |
678 | | - result = await tool.run("query with no results") |
| 679 | + result = await tool.run(context=None, query="query with no results") |
679 | 680 |
|
680 | 681 | # Should return empty string for no results |
681 | 682 | assert result == "" |
682 | 683 |
|
| 684 | + async def test_rag_tool_run_with_context_is_accepted( |
| 685 | + self, mock_rag_config, mock_project |
| 686 | + ): |
| 687 | + """Ensure RagTool.run accepts and works when a ToolCallContext is provided.""" |
| 688 | + mock_rag_config.parent_project.return_value = mock_project |
| 689 | + |
| 690 | + # Mock search results |
| 691 | + search_results = [ |
| 692 | + SearchResult( |
| 693 | + document_id="doc_ctx", |
| 694 | + chunk_idx=3, |
| 695 | + chunk_text="Context ok", |
| 696 | + similarity=0.77, |
| 697 | + ) |
| 698 | + ] |
| 699 | + |
| 700 | + with ( |
| 701 | + patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as mock_vs_config_class, |
| 702 | + patch("kiln_ai.tools.rag_tools.EmbeddingConfig") as mock_embed_config_class, |
| 703 | + patch( |
| 704 | + "kiln_ai.tools.rag_tools.embedding_adapter_from_type" |
| 705 | + ) as mock_adapter_factory, |
| 706 | + patch( |
| 707 | + "kiln_ai.tools.rag_tools.vector_store_adapter_for_config", |
| 708 | + new_callable=AsyncMock, |
| 709 | + ) as mock_vs_adapter_factory, |
| 710 | + ): |
| 711 | + # VECTOR type → embedding path taken |
| 712 | + mock_vector_store_config = Mock() |
| 713 | + mock_vector_store_config.store_type = VectorStoreType.LANCE_DB_VECTOR |
| 714 | + mock_vs_config_class.from_id_and_parent_path.return_value = ( |
| 715 | + mock_vector_store_config |
| 716 | + ) |
| 717 | + |
| 718 | + mock_embedding_config = Mock() |
| 719 | + mock_embed_config_class.from_id_and_parent_path.return_value = ( |
| 720 | + mock_embedding_config |
| 721 | + ) |
| 722 | + |
| 723 | + mock_embedding_adapter = AsyncMock() |
| 724 | + mock_embedding_result = Mock() |
| 725 | + mock_embedding_result.embeddings = [Mock(vector=[1.0])] |
| 726 | + mock_embedding_adapter.generate_embeddings.return_value = ( |
| 727 | + mock_embedding_result |
| 728 | + ) |
| 729 | + mock_adapter_factory.return_value = mock_embedding_adapter |
| 730 | + |
| 731 | + mock_vector_store_adapter = AsyncMock() |
| 732 | + mock_vector_store_adapter.search.return_value = search_results |
| 733 | + mock_vs_adapter_factory.return_value = mock_vector_store_adapter |
| 734 | + |
| 735 | + tool = RagTool("tool_ctx", mock_rag_config) |
| 736 | + |
| 737 | + ctx = ToolCallContext(allow_saving=False) |
| 738 | + result = await tool.run(context=ctx, query="with context") |
| 739 | + |
| 740 | + # Works and returns formatted text |
| 741 | + assert result == "[document_id: doc_ctx, chunk_idx: 3]\nContext ok\n\n" |
| 742 | + |
| 743 | + # Normal behavior still occurs |
| 744 | + mock_embedding_adapter.generate_embeddings.assert_called_once_with( |
| 745 | + ["with context"] |
| 746 | + ) |
| 747 | + mock_vector_store_adapter.search.assert_called_once() |
| 748 | + |
| 749 | + async def test_rag_tool_run_missing_query_raises( |
| 750 | + self, mock_rag_config, mock_project |
| 751 | + ): |
| 752 | + """Ensure RagTool.run enforces the 'if not query' guard.""" |
| 753 | + mock_rag_config.parent_project.return_value = mock_project |
| 754 | + |
| 755 | + with ( |
| 756 | + patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as mock_vs_config_class, |
| 757 | + ): |
| 758 | + mock_vs_config_class.from_id_and_parent_path.return_value = Mock() |
| 759 | + tool = RagTool("tool_err", mock_rag_config) |
| 760 | + |
| 761 | + with pytest.raises(KeyError, match="query"): |
| 762 | + await tool.run(context=None) |
| 763 | + |
683 | 764 |
|
684 | 765 | class TestRagToolNameAndDescription: |
685 | 766 | """Test RagTool name and description functionality with tool_name and tool_description fields.""" |
|
0 commit comments