Merged
27 changes: 27 additions & 0 deletions .github/workflows/vector_db_tests.yml
@@ -101,3 +101,30 @@
           EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
           EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
         run: uv run python ./cognee/tests/test_pgvector.py
+
+  run-lancedb-tests:
+    name: LanceDB Tests
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Check out
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: ${{ inputs.python-version }}
+
+      - name: Run LanceDB Tests
+        env:
+          ENV: 'dev'
+          LLM_MODEL: ${{ secrets.LLM_MODEL }}
+          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+        run: uv run python ./cognee/tests/test_lancedb.py
Comment on lines +106 to +130

Check warning (Code scanning / CodeQL): Workflow does not contain permissions (Medium severity)

Actions job or workflow does not limit the permissions of the GITHUB_TOKEN. Consider setting an explicit permissions block, using the following as a minimal starting point: {contents: read}

Copilot Autofix (AI, 2 months ago)

To fix this issue, set least-privilege permissions for the GITHUB_TOKEN by adding a permissions block at the root of the workflow. This applies the minimum necessary permissions (contents: read) to every job in the workflow unless a job overrides them. It is both the simplest fix and the one most likely to preserve existing functionality: reading repository contents is required for operations such as checking out code, and none of the existing jobs needs write access.

The single change needed is to add the following block after the workflow's name: line and before its on: block:

permissions:
  contents: read

No additional code changes, imports, or definitions are needed beyond this single block.
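If a job later does need broader access, the read-only root default can stay in place and that one job can widen only its own scope. A hypothetical sketch of a job-level override (this job is not part of the PR, and pull-requests: write is only an example scope):

name: Reusable Vector DB Tests
permissions:
  contents: read            # workflow-wide default: read-only GITHUB_TOKEN

on:
  workflow_call:

jobs:
  report-results:           # hypothetical job, for illustration only
    runs-on: ubuntu-22.04
    permissions:
      contents: read
      pull-requests: write  # widen scope only where a specific job needs it
    steps:
      - run: echo "jobs without their own permissions block keep the read-only default"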


Suggested changeset 1
.github/workflows/vector_db_tests.yml

Autofix patch

Run the following command in your local git repository to apply this patch:
cat << 'EOF' | git apply
diff --git a/.github/workflows/vector_db_tests.yml b/.github/workflows/vector_db_tests.yml
--- a/.github/workflows/vector_db_tests.yml
+++ b/.github/workflows/vector_db_tests.yml
@@ -1,4 +1,6 @@
 name: Reusable Vector DB Tests
+permissions:
+  contents: read
 
 on:
   workflow_call:
EOF
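One more note on the workflow file itself: the new job's ${{ inputs.python-version }} expression only resolves because this is a reusable workflow triggered via workflow_call. A hypothetical caller workflow (not part of this PR) would supply that input like so:

name: CI
on: [push]

permissions:
  contents: read

jobs:
  vector-db-tests:
    uses: ./.github/workflows/vector_db_tests.yml
    with:
      python-version: "3.11"
    secrets: inherit   # forwards the LLM_* and EMBEDDING_* secrets the jobs read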
@@ -234,7 +234,7 @@ async def search(
         collection_name: str,
         query_text: Optional[str] = None,
         query_vector: Optional[List[float]] = None,
-        limit: int = None,
+        limit: Optional[int] = None,
         with_vector: bool = False,
     ):
         """
@@ -265,10 +265,10 @@ async def search(
                 "Use this option only when vector data is required."
             )
 
-        # In the case of excessive limit, or zero / negative value, limit will be set to 10.
+        # In the case of excessive limit, or None / zero / negative value, limit will be set to 10.
         if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
             logger.warning(
-                "Provided limit (%s) is invalid (zero, negative, or exceeds maximum). "
+                "Provided limit (%s) is invalid (None, zero, negative, or exceeds maximum). "
                 "Defaulting to limit=10.",
                 limit,
             )
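For reviewers, a minimal standalone sketch of the clamping rule this hunk converges on. The bound names mirror the adapter's _TOPK_LOWER_BOUND / _TOPK_UPPER_BOUND attributes; the concrete values and the helper name clamp_limit are assumptions for illustration:

import logging

logger = logging.getLogger(__name__)

_TOPK_LOWER_BOUND = 0        # assumed value for illustration
_TOPK_UPPER_BOUND = 16384    # assumed value for illustration
_DEFAULT_LIMIT = 10

def clamp_limit(limit):
    """Return a safe top-k: None, zero, negative, or excessive inputs fall back to 10."""
    # `not limit` is truthy for both None and 0, so a single branch covers every invalid case.
    if not limit or limit <= _TOPK_LOWER_BOUND or limit > _TOPK_UPPER_BOUND:
        logger.warning(
            "Provided limit (%s) is invalid (None, zero, negative, or exceeds maximum). "
            "Defaulting to limit=%s.",
            limit,
            _DEFAULT_LIMIT,
        )
        return _DEFAULT_LIMIT
    return limit

assert clamp_limit(None) == 10   # the new Optional[int] default
assert clamp_limit(0) == 10
assert clamp_limit(-3) == 10
assert clamp_limit(10**6) == 10  # above the upper bound
assert clamp_limit(5) == 5       # valid values pass through unchanged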
@@ -352,7 +352,7 @@ async def search(
         collection_name: str,
         query_text: str = None,
         query_vector: List[float] = None,
-        limit: int = 15,
+        limit: Optional[int] = 15,
         with_vector: bool = False,
         normalized: bool = True,
     ):
@@ -386,9 +386,13 @@
         try:
             collection = await self.get_collection(collection_name)
 
-            if limit == 0:
+            if limit is None:
                 limit = await collection.count()
 
+            # If limit is still 0, no need to do the search, just return empty results
+            if limit <= 0:
+                return []
+
             results = await collection.query(
                 query_embeddings=[query_vector],
                 include=["metadatas", "distances", "embeddings"]
@@ -428,7 +432,7 @@ async def search(
                 for row in vector_list
             ]
         except Exception as e:
-            logger.error(f"Error in search: {str(e)}")
+            logger.warning(f"Error in search: {str(e)}")
            return []
 
     async def batch_search(
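The hunks above and below all follow the same new convention: limit=None means "return every match", and a limit that resolves to zero or below short-circuits to an empty result instead of hitting the backend. A minimal sketch of that shared flow; count_all and run_query stand in for the backend-specific calls (collection.count() in ChromaDB, count_rows() in LanceDB, a SELECT count(*) in pgvector) and are assumptions, not cognee APIs:

from typing import Awaitable, Callable, List, Optional

async def search_with_optional_limit(
    query_vector: List[float],
    limit: Optional[int],
    count_all: Callable[[], Awaitable[int]],
    run_query: Callable[[List[float], int], Awaitable[list]],
) -> list:
    if limit is None:
        # "No limit" now means the whole collection, not a magic default.
        limit = await count_all()
    if limit <= 0:
        # Empty collection (or explicit non-positive limit): searching would
        # error in LanceDB and waste a round trip elsewhere, so bail out early.
        return []
    return await run_query(query_vector, limit)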
@@ -223,7 +223,7 @@ async def search(
         collection_name: str,
         query_text: str = None,
         query_vector: List[float] = None,
-        limit: int = 15,
+        limit: Optional[int] = 15,
         with_vector: bool = False,
         normalized: bool = True,
     ):
@@ -235,11 +235,11 @@ async def search(
 
         collection = await self.get_collection(collection_name)
 
-        if limit == 0:
+        if limit is None:
             limit = await collection.count_rows()
 
         # LanceDB search will break if limit is 0 so we must return
-        if limit == 0:
+        if limit <= 0:
             return []
 
         results = await collection.vector_search(query_vector).limit(limit).to_pandas()
@@ -264,7 +264,7 @@ async def batch_search(
         self,
         collection_name: str,
         query_texts: List[str],
-        limit: int = None,
+        limit: Optional[int] = None,
         with_vectors: bool = False,
     ):
         query_vectors = await self.embedding_engine.embed_text(query_texts)
14 changes: 12 additions & 2 deletions cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
@@ -3,7 +3,7 @@
 from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy.dialects.postgresql import insert
-from sqlalchemy import JSON, Column, Table, select, delete, MetaData
+from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from sqlalchemy.exc import ProgrammingError
 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@@ -299,7 +299,7 @@ async def search(
         collection_name: str,
         query_text: Optional[str] = None,
         query_vector: Optional[List[float]] = None,
-        limit: int = 15,
+        limit: Optional[int] = 15,
         with_vector: bool = False,
     ) -> List[ScoredResult]:
         if query_text is None and query_vector is None:
@@ -311,6 +311,16 @@
 
         # Get PGVectorDataPoint Table from database
         PGVectorDataPoint = await self.get_table(collection_name)
 
+        if limit is None:
+            async with self.get_async_session() as session:
+                query = select(func.count()).select_from(PGVectorDataPoint)
+                result = await session.execute(query)
+                limit = result.scalar_one()
+
+        # If limit is still 0, no need to do the search, just return empty results
+        if limit <= 0:
+            return []
+
         # NOTE: This needs to be initialized in case search doesn't return a value
         closest_items = []
 
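The only new import, func, powers the row count when limit is None. Below is a self-contained illustration of the query shape the added block executes, run against an in-memory SQLite database instead of Postgres (requires the aiosqlite driver; the table name data_points is hypothetical):

import asyncio

from sqlalchemy import Column, Integer, MetaData, Table, func, select
from sqlalchemy.ext.asyncio import create_async_engine

async def main():
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    metadata = MetaData()
    points = Table("data_points", metadata, Column("id", Integer, primary_key=True))

    async with engine.begin() as conn:
        await conn.run_sync(metadata.create_all)
        await conn.execute(points.insert(), [{"id": i} for i in range(3)])
        # Same shape as the adapter's query: SELECT count(*) FROM data_points
        result = await conn.execute(select(func.count()).select_from(points))
        print(result.scalar_one())  # -> 3

    await engine.dispose()

asyncio.run(main())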
12 changes: 8 additions & 4 deletions cognee/infrastructure/databases/vector/vector_db_interface.py
@@ -83,7 +83,7 @@ async def search(
         collection_name: str,
         query_text: Optional[str],
         query_vector: Optional[List[float]],
-        limit: int,
+        limit: Optional[int],
         with_vector: bool = False,
     ):
         """
@@ -98,15 +98,19 @@ async def search(
           collection.
         - query_vector (Optional[List[float]]): An optional vector representation for
           searching the collection.
-        - limit (int): The maximum number of results to return from the search.
+        - limit (Optional[int]): The maximum number of results to return from the search.
        - with_vector (bool): Whether to return the vector representations with search
           results. (default False)
         """
         raise NotImplementedError
 
     @abstractmethod
     async def batch_search(
-        self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
+        self,
+        collection_name: str,
+        query_texts: List[str],
+        limit: Optional[int],
+        with_vectors: bool = False,
     ):
         """
         Perform a batch search using multiple text queries against a collection.
@@ -116,7 +120,7 @@ async def batch_search(
 
         - collection_name (str): The name of the collection to conduct the batch search in.
         - query_texts (List[str]): A list of text queries to use for the search.
-        - limit (int): The maximum number of results to return for each query.
+        - limit (Optional[int]): The maximum number of results to return for each query.
         - with_vectors (bool): Whether to include vector representations with search
           results. (default False)
         """
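For callers of the interface, the widened Optional[int] replaces the old limit=0 sentinel with an explicit "no cap". A sketch of the call-site difference (vector_engine and query_vector assumed to be set up as elsewhere in this PR):

# Before: limit=0 was a magic value meaning "return every match".
# After: limit=None says it explicitly, and each adapter resolves it
# to the collection's row count before querying.
results = await vector_engine.search(
    collection_name="Entity_name",
    query_vector=query_vector,
    limit=None,
)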
2 changes: 1 addition & 1 deletion cognee/modules/graph/cognee_graph/CogneeGraph.py
@@ -161,7 +161,7 @@ async def map_vector_distances_to_graph_edges(
         edge_distances = await vector_engine.search(
             collection_name="EdgeType_relationship_name",
             query_vector=query_vector,
-            limit=0,
+            limit=None,
         )
         projection_time = time.time() - start_time
         logger.info(
2 changes: 1 addition & 1 deletion cognee/modules/retrieval/insights_retriever.py
@@ -25,7 +25,7 @@ class InsightsRetriever(BaseGraphRetriever):
     - top_k
     """
 
-    def __init__(self, exploration_levels: int = 1, top_k: int = 5):
+    def __init__(self, exploration_levels: int = 1, top_k: Optional[int] = 5):
         """Initialize retriever with exploration levels and search parameters."""
         self.exploration_levels = exploration_levels
         self.top_k = top_k
2 changes: 1 addition & 1 deletion cognee/modules/retrieval/temporal_retriever.py
@@ -129,7 +129,7 @@ async def get_context(self, query: str) -> Any:
         query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
 
         vector_search_results = await vector_engine.search(
-            collection_name="Event_name", query_vector=query_vector, limit=0
+            collection_name="Event_name", query_vector=query_vector, limit=None
         )
 
         top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
@@ -144,7 +144,7 @@ async def brute_force_triplet_search(
     async def search_in_collection(collection_name: str):
         try:
             return await vector_engine.search(
-                collection_name=collection_name, query_vector=query_vector, limit=0
+                collection_name=collection_name, query_vector=query_vector, limit=None
             )
         except CollectionNotFoundError:
             return []
35 changes: 35 additions & 0 deletions cognee/tests/test_chromadb.py
@@ -67,6 +67,39 @@ async def test_getting_of_documents(dataset_name_1):
     )
 
 
+async def test_vector_engine_search_none_limit():
+    file_path = os.path.join(
+        pathlib.Path(__file__).resolve().parent.parent.parent,
+        "examples",
+        "data",
+        "alice_in_wonderland.txt",
+    )
+
+    await cognee.prune.prune_data()
+    await cognee.prune.prune_system(metadata=True)
+
+    await cognee.add(file_path)
+
+    await cognee.cognify()
+
+    query_text = "List me all the important characters in Alice in Wonderland."
+
+    from cognee.infrastructure.databases.vector import get_vector_engine
+
+    vector_engine = get_vector_engine()
+
+    collection_name = "Entity_name"
+
+    query_vector = (await vector_engine.embedding_engine.embed_text([query_text]))[0]
+
+    result = await vector_engine.search(
+        collection_name=collection_name, query_vector=query_vector, limit=None
+    )
+
+    # Check that we did not accidentally use any default value for limit in vector search along the way (like 5, 10, or 15)
+    assert len(result) > 15
+
+
 async def main():
     cognee.config.set_vector_db_config(
         {
@@ -165,6 +198,8 @@ async def main():
     tables_in_database = await vector_engine.get_collection_names()
     assert len(tables_in_database) == 0, "ChromaDB database is not empty"
 
+    await test_vector_engine_search_none_limit()
+
 
 if __name__ == "__main__":
     import asyncio