diff --git a/mem0/__init__.py b/mem0/__init__.py
index 318347ecba..5bfb11fe58 100644
--- a/mem0/__init__.py
+++ b/mem0/__init__.py
@@ -1,6 +1,9 @@
 import importlib.metadata
 
-__version__ = importlib.metadata.version("mem0ai")
+try:
+    __version__ = importlib.metadata.version("mem0ai")
+except importlib.metadata.PackageNotFoundError:
+    __version__ = "dev"
 
 from mem0.client.main import AsyncMemoryClient, MemoryClient  # noqa
 from mem0.memory.main import AsyncMemory, Memory  # noqa
diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py
index ff9fd995fc..a81e5d44cf 100644
--- a/mem0/vector_stores/configs.py
+++ b/mem0/vector_stores/configs.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 
 from pydantic import BaseModel, Field, model_validator
 
@@ -8,7 +8,7 @@ class VectorStoreConfig(BaseModel):
         description="Provider of the vector store (e.g., 'qdrant', 'chroma', 'upstash_vector')",
         default="qdrant",
     )
-    config: Optional[Dict] = Field(description="Configuration for the specific vector store", default=None)
+    config: Optional[Dict[str, Any]] = Field(description="Configuration for the specific vector store", default=None)
 
     _provider_configs: Dict[str, str] = {
         "qdrant": "QdrantConfig",
@@ -56,9 +56,128 @@ def validate_and_create_config(self) -> "VectorStoreConfig":
                 raise ValueError(f"Invalid config type for provider {provider}")
             return self
 
-        # also check if path in allowed kays for pydantic model, and whether config extra fields are allowed
+        # also check if path in allowed keys for pydantic model, and whether config extra fields are allowed
         if "path" not in config and "path" in config_class.__annotations__:
             config["path"] = f"/tmp/{provider}"
 
         self.config = config_class(**config)
         return self
+
+    @property
+    def collection_name(self) -> str:
+        """Get the collection/index name from the provider config."""
+        if self.config and hasattr(self.config, "collection_name"):
+            return self.config.collection_name
+        return "mem0"
+
+    @property
+    def embedding_model_dims(self) -> int:
+        """Get the embedding dimensions."""
+        if self.config and hasattr(self.config, "embedding_model_dims"):
+            return self.config.embedding_model_dims
+        return 1536
+
+    @property
+    def api_key(self) -> Optional[str]:
+        """Get the API key if the provider requires one."""
+        if self.config and hasattr(self.config, "api_key"):
+            return self.config.api_key
+        return None
+
+    @property
+    def host(self) -> Optional[str]:
+        """Get the host for self-hosted vector stores."""
+        if self.config and hasattr(self.config, "host"):
+            return self.config.host
+        return None
+
+    @property
+    def port(self) -> Optional[int]:
+        """Get the port for self-hosted vector stores."""
+        if self.config and hasattr(self.config, "port"):
+            return self.config.port
+        return None
+
+    @property
+    def url(self) -> Optional[str]:
+        """Get the URL for cloud-based vector stores."""
+        if self.config and hasattr(self.config, "url"):
+            return self.config.url
+        return None
+
+    @property
+    def path(self) -> Optional[str]:
+        """Get the path for local vector stores."""
+        if self.config and hasattr(self.config, "path"):
+            return self.config.path
+        return None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert config to dictionary for serialization.
+
+        The result always holds ``provider``, ``collection_name`` and
+        ``embedding_model_dims``. Connection fields (``api_key``, ``host``,
+        ``port``, ``url``, ``path``) are added when they are not ``None``, and
+        the full provider config is nested under ``config`` when one is set.
+
+        Note: ``api_key`` is emitted in plain text (and again inside ``config``
+        when the provider config defines it); redact it before logging.
+        """
+        result = {
+            "provider": self.provider,
+            "collection_name": self.collection_name,
+            "embedding_model_dims": self.embedding_model_dims,
+        }
+
+        # Compare against None rather than relying on truthiness so that falsy
+        # but meaningful values (e.g. port 0) are not silently dropped; this
+        # matches the filter used by _get_connection_params().
+        for key in ("api_key", "host", "port", "url", "path"):
+            value = getattr(self, key)
+            if value is not None:
+                result[key] = value
+
+        # Include full config if needed
+        if self.config:
+            result["config"] = self.config.model_dump() if hasattr(self.config, "model_dump") else vars(self.config)
+
+        return result
+
+    def get_migration_config(self) -> Dict[str, Any]:
+        """Get configuration for migration operations."""
+        return {
+            "provider": self.provider,
+            "source_collection": self.collection_name,
+            "embedding_dims": self.embedding_model_dims,
+            "connection_params": self._get_connection_params(),
+        }
+
+    def _get_connection_params(self) -> Dict[str, Any]:
+        """Extract connection parameters for migration."""
+        params = {}
+        if self.config:
+            for key in ["host", "port", "url", "api_key", "path"]:
+                if hasattr(self.config, key):
+                    value = getattr(self.config, key)
+                    if value is not None:
+                        params[key] = value
+        return params
+
+    def rebuild_config(self, new_collection_name: Optional[str] = None) -> "VectorStoreConfig":
+        """Create a new config for rebuilding the vector database.
+
+        Args:
+            new_collection_name: Collection name for the rebuilt config. When
+                omitted (or empty) the current collection name is kept.
+
+        Returns:
+            A freshly validated VectorStoreConfig; this instance is not modified.
+        """
+        config_dict = self.model_dump()
+        if new_collection_name:
+            # model_dump() always emits the "config" key; "or {}" only covers a
+            # missing provider config so the override can still be applied.
+            provider_config = dict(config_dict.get("config") or {})
+            provider_config["collection_name"] = new_collection_name
+            config_dict["config"] = provider_config
+        return VectorStoreConfig(**config_dict)
diff --git a/tests/configs/test_vector_store_config.py b/tests/configs/test_vector_store_config.py
new file mode 100644
index 0000000000..1406926955
--- /dev/null
+++ b/tests/configs/test_vector_store_config.py
@@ -0,0 +1,105 @@
+import pytest
+
+from mem0.vector_stores.configs import VectorStoreConfig
+
+
+class TestVectorStoreConfig:
+    def test_default_config(self):
+        """Test VectorStoreConfig with default values."""
+        config = VectorStoreConfig()
+        assert config.provider == "qdrant"
+        assert config.collection_name == "mem0"
+        assert config.embedding_model_dims == 1536
+        assert config.api_key is None
+        assert config.host is None
+        assert config.port is None
+        assert config.url is None
+
+    def test_qdrant_config_properties(self):
+        """Test VectorStoreConfig properties with Qdrant provider."""
+        config = VectorStoreConfig(
+            provider="qdrant",
+            config={
+                "collection_name": "test_collection",
+                "embedding_model_dims": 768,
+                "host": "localhost",
+                "port": 6333,
+                "path": "/custom/path",
+            },
+        )
+        assert config.provider == "qdrant"
+        assert config.collection_name == "test_collection"
+        assert config.embedding_model_dims == 768
+        assert config.host == "localhost"
+        assert config.port == 6333
+        assert config.path == "/custom/path"
+
+    def test_chroma_config_properties(self):
+        """Test VectorStoreConfig properties with Chroma provider - skip if chromadb not installed."""
+        try:
+            config = VectorStoreConfig(
+                provider="chroma", config={"collection_name": "chroma_test", "path": "/tmp/chroma_test"}
+            )
+            assert config.provider == "chroma"
+            assert config.collection_name == "chroma_test"
+            assert config.path == "/tmp/chroma_test"
+        except ImportError:
+            pytest.skip("chromadb not installed")
+
+    def test_to_dict_method(self):
+        """Test to_dict method for serialization."""
+        config = VectorStoreConfig(
+            provider="qdrant",
+            config={"collection_name": "test", "embedding_model_dims": 512, "host": "localhost", "port": 6333},
+        )
+        result = config.to_dict()
+        assert result["provider"] == "qdrant"
+        assert result["collection_name"] == "test"
+        assert result["embedding_model_dims"] == 512
+        assert result["host"] == "localhost"
+        assert result["port"] == 6333
+        assert "config" in result
+
+    def test_get_migration_config(self):
+        """Test get_migration_config method."""
+        config = VectorStoreConfig(
+            provider="qdrant",
+            config={"collection_name": "migrate_test", "embedding_model_dims": 1024, "host": "localhost", "port": 6333},
+        )
+        migration_config = config.get_migration_config()
+        assert migration_config["provider"] == "qdrant"
+        assert migration_config["source_collection"] == "migrate_test"
+        assert migration_config["embedding_dims"] == 1024
+        assert migration_config["connection_params"]["host"] == "localhost"
+        assert migration_config["connection_params"]["port"] == 6333
+
+    def test_rebuild_config(self):
+        """Test rebuild_config method."""
+        original_config = VectorStoreConfig(
+            provider="qdrant", config={"collection_name": "original", "embedding_model_dims": 768}
+        )
+        new_config = original_config.rebuild_config(new_collection_name="rebuilt")
+        assert new_config.provider == "qdrant"
+        assert new_config.collection_name == "rebuilt"
+
+    def test_invalid_provider(self):
+        """Test VectorStoreConfig with invalid provider."""
+        with pytest.raises(ValueError, match="Unsupported vector store provider"):
+            VectorStoreConfig(provider="invalid_provider")
+
+    def test_pinecone_config_with_api_key(self):
+        """Test VectorStoreConfig with Pinecone provider and API key."""
+        config = VectorStoreConfig(
+            provider="pinecone", config={"api_key": "test-api-key", "collection_name": "test-index"}
+        )
+        assert config.provider == "pinecone"
+        assert config.api_key == "test-api-key"
+        assert config.collection_name == "test-index"
+
+    def test_config_with_none_values(self):
+        """Test that None values are handled correctly."""
+        config = VectorStoreConfig(provider="qdrant", config={})
+        assert config.collection_name == "mem0"  # Should return default
+        assert config.api_key is None
+        assert config.host is None
+        assert config.port is None