Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mem0/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import importlib.metadata

__version__ = importlib.metadata.version("mem0ai")
try:
__version__ = importlib.metadata.version("mem0ai")
except importlib.metadata.PackageNotFoundError:
__version__ = "dev"

from mem0.client.main import AsyncMemoryClient, MemoryClient # noqa
from mem0.memory.main import AsyncMemory, Memory # noqa
110 changes: 107 additions & 3 deletions mem0/vector_stores/configs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, model_validator

Expand All @@ -8,7 +8,7 @@ class VectorStoreConfig(BaseModel):
description="Provider of the vector store (e.g., 'qdrant', 'chroma', 'upstash_vector')",
default="qdrant",
)
config: Optional[Dict] = Field(description="Configuration for the specific vector store", default=None)
config: Optional[Dict[str, Any]] = Field(description="Configuration for the specific vector store", default=None)

_provider_configs: Dict[str, str] = {
"qdrant": "QdrantConfig",
Expand Down Expand Up @@ -56,9 +56,113 @@ def validate_and_create_config(self) -> "VectorStoreConfig":
raise ValueError(f"Invalid config type for provider {provider}")
return self

# also check if path in allowed kays for pydantic model, and whether config extra fields are allowed
# also check if path in allowed keys for pydantic model, and whether config extra fields are allowed
if "path" not in config and "path" in config_class.__annotations__:
config["path"] = f"/tmp/{provider}"

self.config = config_class(**config)
return self

@property
def collection_name(self) -> str:
"""Get the collection/index name from the provider config."""
if self.config and hasattr(self.config, "collection_name"):
return self.config.collection_name
return "mem0"

@property
def embedding_model_dims(self) -> int:
"""Get the embedding dimensions."""
if self.config and hasattr(self.config, "embedding_model_dims"):
return self.config.embedding_model_dims
return 1536

@property
def api_key(self) -> Optional[str]:
"""Get the API key if the provider requires one."""
if self.config and hasattr(self.config, "api_key"):
return self.config.api_key
return None

@property
def host(self) -> Optional[str]:
"""Get the host for self-hosted vector stores."""
if self.config and hasattr(self.config, "host"):
return self.config.host
return None

@property
def port(self) -> Optional[int]:
"""Get the port for self-hosted vector stores."""
if self.config and hasattr(self.config, "port"):
return self.config.port
return None

@property
def url(self) -> Optional[str]:
"""Get the URL for cloud-based vector stores."""
if self.config and hasattr(self.config, "url"):
return self.config.url
return None

@property
def path(self) -> Optional[str]:
"""Get the path for local vector stores."""
if self.config and hasattr(self.config, "path"):
return self.config.path
return None

def to_dict(self) -> Dict[str, Any]:
"""Convert config to dictionary for serialization."""
result = {
"provider": self.provider,
"collection_name": self.collection_name,
"embedding_model_dims": self.embedding_model_dims,
}

# Add optional fields if they exist
if self.api_key:
result["api_key"] = self.api_key
if self.host:
result["host"] = self.host
if self.port:
result["port"] = self.port
if self.url:
result["url"] = self.url
if self.path:
result["path"] = self.path

# Include full config if needed
if self.config:
result["config"] = self.config.model_dump() if hasattr(self.config, "model_dump") else vars(self.config)

return result

def get_migration_config(self) -> Dict[str, Any]:
"""Get configuration for migration operations."""
return {
"provider": self.provider,
"source_collection": self.collection_name,
"embedding_dims": self.embedding_model_dims,
"connection_params": self._get_connection_params(),
}

def _get_connection_params(self) -> Dict[str, Any]:
"""Extract connection parameters for migration."""
params = {}
if self.config:
for key in ["host", "port", "url", "api_key", "path"]:
if hasattr(self.config, key):
value = getattr(self.config, key)
if value is not None:
params[key] = value
return params

def rebuild_config(self, new_collection_name: Optional[str] = None) -> "VectorStoreConfig":
"""Create a new config for rebuilding the vector database."""
config_dict = self.model_dump()
if new_collection_name and self.config:
if "config" not in config_dict:
config_dict["config"] = {}
config_dict["config"]["collection_name"] = new_collection_name
return VectorStoreConfig(**config_dict)
105 changes: 105 additions & 0 deletions tests/configs/test_vector_store_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import pytest

from mem0.vector_stores.configs import VectorStoreConfig


class TestVectorStoreConfig:
def test_default_config(self):
"""Test VectorStoreConfig with default values."""
config = VectorStoreConfig()
assert config.provider == "qdrant"
assert config.collection_name == "mem0"
assert config.embedding_model_dims == 1536
assert config.api_key is None
assert config.host is None
assert config.port is None
assert config.url is None

def test_qdrant_config_properties(self):
"""Test VectorStoreConfig properties with Qdrant provider."""
config = VectorStoreConfig(
provider="qdrant",
config={
"collection_name": "test_collection",
"embedding_model_dims": 768,
"host": "localhost",
"port": 6333,
"path": "/custom/path",
},
)
assert config.provider == "qdrant"
assert config.collection_name == "test_collection"
assert config.embedding_model_dims == 768
assert config.host == "localhost"
assert config.port == 6333
assert config.path == "/custom/path"

def test_chroma_config_properties(self):
"""Test VectorStoreConfig properties with Chroma provider - skip if chromadb not installed."""
try:
config = VectorStoreConfig(
provider="chroma", config={"collection_name": "chroma_test", "path": "/tmp/chroma_test"}
)
assert config.provider == "chroma"
assert config.collection_name == "chroma_test"
assert config.path == "/tmp/chroma_test"
except ImportError:
pytest.skip("chromadb not installed")

def test_to_dict_method(self):
"""Test to_dict method for serialization."""
config = VectorStoreConfig(
provider="qdrant",
config={"collection_name": "test", "embedding_model_dims": 512, "host": "localhost", "port": 6333},
)
result = config.to_dict()
assert result["provider"] == "qdrant"
assert result["collection_name"] == "test"
assert result["embedding_model_dims"] == 512
assert result["host"] == "localhost"
assert result["port"] == 6333
assert "config" in result

def test_get_migration_config(self):
"""Test get_migration_config method."""
config = VectorStoreConfig(
provider="qdrant",
config={"collection_name": "migrate_test", "embedding_model_dims": 1024, "host": "localhost", "port": 6333},
)
migration_config = config.get_migration_config()
assert migration_config["provider"] == "qdrant"
assert migration_config["source_collection"] == "migrate_test"
assert migration_config["embedding_dims"] == 1024
assert migration_config["connection_params"]["host"] == "localhost"
assert migration_config["connection_params"]["port"] == 6333

def test_rebuild_config(self):
"""Test rebuild_config method."""
original_config = VectorStoreConfig(
provider="qdrant", config={"collection_name": "original", "embedding_model_dims": 768}
)
new_config = original_config.rebuild_config(new_collection_name="rebuilt")
assert new_config.provider == "qdrant"
assert new_config.collection_name == "rebuilt"

def test_invalid_provider(self):
"""Test VectorStoreConfig with invalid provider."""
with pytest.raises(ValueError, match="Unsupported vector store provider"):
VectorStoreConfig(provider="invalid_provider")

def test_pinecone_config_with_api_key(self):
"""Test VectorStoreConfig with Pinecone provider and API key."""
config = VectorStoreConfig(
provider="pinecone", config={"api_key": "test-api-key", "collection_name": "test-index"}
)
assert config.provider == "pinecone"
assert config.api_key == "test-api-key"
assert config.collection_name == "test-index"

def test_config_with_none_values(self):
"""Test that None values are handled correctly."""
config = VectorStoreConfig(provider="qdrant", config={})
assert config.collection_name == "mem0" # Should return default
assert config.api_key is None
assert config.host is None
assert config.port is None