Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion graphrag/vector_stores/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
"""A package containing a factory and supported vector store types."""

from enum import Enum
from typing import ClassVar
from typing import Any, ClassVar

from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.cosmosdb import CosmosDBVectoreStore
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.local_vector_store import LocalVectorStore


class VectorStoreType(str, Enum):
Expand All @@ -18,6 +19,7 @@ class VectorStoreType(str, Enum):
LanceDB = "lancedb"
AzureAISearch = "azure_ai_search"
CosmosDB = "cosmosdb"
Local = "local"


class VectorStoreFactory:
Expand Down Expand Up @@ -45,8 +47,32 @@ def create_vector_store(
return AzureAISearchVectorStore(**kwargs)
case VectorStoreType.CosmosDB:
return CosmosDBVectoreStore(**kwargs)
case VectorStoreType.Local:
return LocalVectorStore(**kwargs)
case _:
if vector_store_type in cls.vector_store_types:
return cls.vector_store_types[vector_store_type](**kwargs)
msg = f"Unknown vector store type: {vector_store_type}"
raise ValueError(msg)

def get_vector_store(
store_type: VectorStoreType,
collection_name: str,
**kwargs: Any,
) -> BaseVectorStore:
"""Get a vector store instance based on the store type."""
store_map: ClassVar[dict[VectorStoreType, type[BaseVectorStore]]] = {
VectorStoreType.LanceDB: LanceDBVectorStore,
VectorStoreType.AzureAISearch: AzureAISearchVectorStore,
VectorStoreType.CosmosDB: CosmosDBVectoreStore,
VectorStoreType.Local: LocalVectorStore,
}

store_class = store_map.get(store_type)
if store_class is None:
msg = f"Unsupported vector store type: {store_type}"
raise ValueError(msg)

store = store_class(collection_name=collection_name)
store.connect(**kwargs)
return store
Loading