diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 469fefd2..21795f55 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -138,6 +138,7 @@ target_link_libraries(index_schema PUBLIC type_conversions) target_link_libraries(index_schema PUBLIC utils) target_link_libraries(index_schema PUBLIC status_macros) target_link_libraries(index_schema PUBLIC valkey_module) +target_link_libraries(index_schema PUBLIC memory_tracker) set(SRCS_ATTRIBUTE_DATA_TYPE ${CMAKE_CURRENT_LIST_DIR}/attribute_data_type.cc ${CMAKE_CURRENT_LIST_DIR}/attribute_data_type.h) diff --git a/src/index_schema.cc b/src/index_schema.cc index 4d0cf99c..15cec75d 100644 --- a/src/index_schema.cc +++ b/src/index_schema.cc @@ -48,6 +48,7 @@ #include "vmsdk/src/blocked_client.h" #include "vmsdk/src/log.h" #include "vmsdk/src/managed_pointers.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/status/status_macros.h" #include "vmsdk/src/thread_pool.h" #include "vmsdk/src/time_sliced_mrmw_mutex.h" @@ -77,10 +78,12 @@ absl::StatusOr> IndexFactory( const auto &index = attribute.index(); switch (index.index_type_case()) { case data_model::Index::IndexTypeCase::kTagIndex: { - return std::make_shared(index.tag_index()); + return std::make_shared(index.tag_index(), + index_schema->GetMemoryPool()); } case data_model::Index::IndexTypeCase::kNumericIndex: { - return std::make_shared(index.numeric_index()); + return std::make_shared(index.numeric_index(), + index_schema->GetMemoryPool()); } case data_model::Index::IndexTypeCase::kVectorIndex: { switch (index.vector_index().algorithm_case()) { @@ -93,10 +96,11 @@ absl::StatusOr> IndexFactory( ? indexes::VectorHNSW::LoadFromRDB( ctx, &index_schema->GetAttributeDataType(), index.vector_index(), attribute.identifier(), - std::move(*iter)) + std::move(*iter), index_schema->GetMemoryPool()) : indexes::VectorHNSW::Create( index.vector_index(), attribute.identifier(), - index_schema->GetAttributeDataType().ToProto())); + index_schema->GetAttributeDataType().ToProto(), + index_schema->GetMemoryPool())); index_schema->SubscribeToVectorExternalizer( attribute.identifier(), index.get()); return index; @@ -118,10 +122,11 @@ absl::StatusOr> IndexFactory( ? indexes::VectorFlat::LoadFromRDB( ctx, &index_schema->GetAttributeDataType(), index.vector_index(), attribute.identifier(), - std::move(*iter)) + std::move(*iter), index_schema->GetMemoryPool()) : indexes::VectorFlat::Create( index.vector_index(), attribute.identifier(), - index_schema->GetAttributeDataType().ToProto())); + index_schema->GetAttributeDataType().ToProto(), + index_schema->GetMemoryPool())); index_schema->SubscribeToVectorExternalizer( attribute.identifier(), index.get()); return index; @@ -1047,7 +1052,7 @@ void IndexSchema::OnLoadingEnded(ValkeyModuleCtx *ctx) { << " stale entries for {Index: " << name_ << "}"; for (auto &[key, attributes] : deletion_attributes) { - auto interned_key = std::make_shared(key); + auto interned_key = StringInternStore::Intern(key); ProcessMutation(ctx, attributes, interned_key, true); } VMSDK_LOG(NOTICE, ctx) << "Scanned index schema " << name_ diff --git a/src/index_schema.h b/src/index_schema.h index 8b8fa204..71318263 100644 --- a/src/index_schema.h +++ b/src/index_schema.h @@ -36,6 +36,7 @@ #include "src/utils/string_interning.h" #include "vmsdk/src/blocked_client.h" #include "vmsdk/src/managed_pointers.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/thread_pool.h" #include "vmsdk/src/time_sliced_mrmw_mutex.h" #include "vmsdk/src/utils.h" @@ -172,6 +173,8 @@ class IndexSchema : public KeyspaceEventSubscription, uint64_t GetBackfillDbSize() const; InfoIndexPartitionData GetInfoIndexPartitionData() const; + MemoryPool &GetMemoryPool() { return memory_pool_; } + protected: IndexSchema(ValkeyModuleCtx *ctx, const data_model::IndexSchema &index_schema_proto, @@ -264,6 +267,8 @@ class IndexSchema : public KeyspaceEventSubscription, vmsdk::MainThreadAccessGuard multi_mutations_; vmsdk::MainThreadAccessGuard schedule_multi_exec_processing_{false}; + MemoryPool memory_pool_{0}; + FRIEND_TEST(IndexSchemaRDBTest, SaveAndLoad); FRIEND_TEST(IndexSchemaRDBTest, ComprehensiveSkipLoadTest); FRIEND_TEST(IndexSchemaFriendTest, ConsistencyTest); diff --git a/src/indexes/CMakeLists.txt b/src/indexes/CMakeLists.txt index 80b9982c..33aaaff8 100644 --- a/src/indexes/CMakeLists.txt +++ b/src/indexes/CMakeLists.txt @@ -8,6 +8,7 @@ target_link_libraries(index_base INTERFACE rdb_serialization) target_link_libraries(index_base INTERFACE string_interning) target_link_libraries(index_base INTERFACE managed_pointers) target_link_libraries(index_base INTERFACE valkey_module) +target_link_libraries(index_base INTERFACE memory_tracker) set(SRCS_VECTOR_BASE ${CMAKE_CURRENT_LIST_DIR}/vector_base.cc ${CMAKE_CURRENT_LIST_DIR}/vector_base.h) @@ -32,6 +33,7 @@ target_link_libraries(vector_base PUBLIC managed_pointers) target_link_libraries(vector_base PUBLIC type_conversions) target_link_libraries(vector_base PUBLIC status_macros) target_link_libraries(vector_base PUBLIC valkey_module) +target_link_libraries(vector_base PUBLIC memory_tracker) set(SRCS_VECTOR_HNSW ${CMAKE_CURRENT_LIST_DIR}/vector_hnsw.cc ${CMAKE_CURRENT_LIST_DIR}/vector_hnsw.h) @@ -50,6 +52,7 @@ target_link_libraries(vector_hnsw PUBLIC memory_allocation_overrides) target_link_libraries(vector_hnsw PUBLIC utils) target_link_libraries(vector_hnsw PUBLIC status_macros) target_link_libraries(vector_hnsw PUBLIC valkey_module) +target_link_libraries(vector_hnsw PUBLIC memory_tracker) set(SRCS_NUMERIC ${CMAKE_CURRENT_LIST_DIR}/numeric.cc ${CMAKE_CURRENT_LIST_DIR}/numeric.h) @@ -62,6 +65,7 @@ target_link_libraries(numeric PUBLIC predicate_header) target_link_libraries(numeric PUBLIC segment_tree) target_link_libraries(numeric PUBLIC string_interning) target_link_libraries(numeric PUBLIC valkey_module) +target_link_libraries(numeric PUBLIC memory_tracker) set(SRCS_TAG ${CMAKE_CURRENT_LIST_DIR}/tag.cc ${CMAKE_CURRENT_LIST_DIR}/tag.h) @@ -76,6 +80,7 @@ target_link_libraries(tag PUBLIC managed_pointers) target_link_libraries(tag PUBLIC type_conversions) target_link_libraries(tag PUBLIC valkey_module) target_link_libraries(tag PUBLIC ${INDEX_SCHEMA_PROTO_LIB}) +target_link_libraries(tag PUBLIC memory_tracker) set(SRCS_VECTOR_FLAT ${CMAKE_CURRENT_LIST_DIR}/vector_flat.cc ${CMAKE_CURRENT_LIST_DIR}/vector_flat.h) @@ -93,3 +98,4 @@ target_link_libraries(vector_flat PUBLIC log) target_link_libraries(vector_flat PUBLIC memory_allocation_overrides) target_link_libraries(vector_flat PUBLIC status_macros) target_link_libraries(vector_flat PUBLIC valkey_module) +target_link_libraries(vector_flat PUBLIC memory_tracker) diff --git a/src/indexes/index_base.h b/src/indexes/index_base.h index 6db739c3..f8047167 100644 --- a/src/indexes/index_base.h +++ b/src/indexes/index_base.h @@ -22,6 +22,7 @@ #include "src/rdb_serialization.h" #include "src/utils/string_interning.h" #include "vmsdk/src/managed_pointers.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" namespace valkey_search::indexes { @@ -40,7 +41,8 @@ const absl::NoDestructor> class IndexBase { public: - explicit IndexBase(IndexerType indexer_type) : indexer_type_(indexer_type) {} + explicit IndexBase(IndexerType indexer_type, MemoryPool& memory_pool) + : indexer_type_(indexer_type), memory_pool_(memory_pool) {} virtual ~IndexBase() = default; // Add/Remove/Modify will return true if the operation was successful, false @@ -67,6 +69,11 @@ class IndexBase { } virtual uint64_t GetRecordCount() const = 0; + MemoryPool& GetMemoryPool() { return memory_pool_; } + + protected: + MemoryPool& memory_pool_; + private: IndexerType indexer_type_{IndexerType::kNone}; }; diff --git a/src/indexes/numeric.cc b/src/indexes/numeric.cc index 450a8105..2f5e7c02 100644 --- a/src/indexes/numeric.cc +++ b/src/indexes/numeric.cc @@ -25,6 +25,7 @@ #include "src/indexes/index_base.h" #include "src/query/predicate.h" #include "src/utils/string_interning.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" namespace valkey_search::indexes { @@ -38,29 +39,48 @@ std::optional ParseNumber(absl::string_view data) { } } // namespace -Numeric::Numeric(const data_model::NumericIndex& numeric_index_proto) - : IndexBase(IndexerType::kNumeric) { +Numeric::Numeric(const data_model::NumericIndex& numeric_index_proto, + MemoryPool& memory_pool) + : IndexBase(IndexerType::kNumeric, memory_pool) { + IsolatedMemoryScope scope{memory_pool}; + + tracked_keys_ = std::make_unique>(); + untracked_keys_ = std::make_unique(); index_ = std::make_unique(); } +Numeric::~Numeric() { + IsolatedMemoryScope scope{memory_pool_}; + + tracked_keys_.reset(); + untracked_keys_.reset(); + index_.reset(); +} + +// NOTE: key should be stored interned string. absl::StatusOr Numeric::AddRecord(const InternedStringPtr& key, absl::string_view data) { + IsolatedMemoryScope scope{memory_pool_}; + auto value = ParseNumber(data); absl::MutexLock lock(&index_mutex_); if (!value.has_value()) { - untracked_keys_.insert(key); + untracked_keys_->insert(key); return false; } - auto [_, succ] = tracked_keys_.insert({key, *value}); + auto [_, succ] = tracked_keys_->insert({key, *value}); if (!succ) { + // NOTE: don't track allocation error. + DisableMemoryTracking disable_tracking; return absl::AlreadyExistsError( absl::StrCat("Key `", key->Str(), "` already exists")); } - untracked_keys_.erase(key); + untracked_keys_->erase(key); index_->Add(key, *value); return true; } +// NOTE: key should be stored interned string. absl::StatusOr Numeric::ModifyRecord(const InternedStringPtr& key, absl::string_view data) { auto value = ParseNumber(data); @@ -69,35 +89,41 @@ absl::StatusOr Numeric::ModifyRecord(const InternedStringPtr& key, RemoveRecord(key, indexes::DeletionType::kIdentifier); return false; } + absl::MutexLock lock(&index_mutex_); - auto it = tracked_keys_.find(key); - if (it == tracked_keys_.end()) { + auto it = tracked_keys_->find(key); + if (it == tracked_keys_->end()) { return absl::NotFoundError( absl::StrCat("Key `", key->Str(), "` not found")); } + IsolatedMemoryScope scope{memory_pool_}; + index_->Modify(it->first, it->second, *value); it->second = *value; return true; } +// NOTE: key should be stored interned string. absl::StatusOr Numeric::RemoveRecord(const InternedStringPtr& key, DeletionType deletion_type) { + IsolatedMemoryScope scope{memory_pool_}; + absl::MutexLock lock(&index_mutex_); if (deletion_type == DeletionType::kRecord) { // If key is DELETED, remove it from untracked_keys_. - untracked_keys_.erase(key); + untracked_keys_->erase(key); } else { // If key doesn't have TAG but exists, insert it to untracked_keys_. - untracked_keys_.insert(key); + untracked_keys_->insert(key); } - auto it = tracked_keys_.find(key); - if (it == tracked_keys_.end()) { + auto it = tracked_keys_->find(key); + if (it == tracked_keys_->end()) { return false; } index_->Remove(it->first, it->second); - tracked_keys_.erase(it); + tracked_keys_->erase(it); return true; } @@ -107,13 +133,13 @@ int Numeric::RespondWithInfo(ValkeyModuleCtx* ctx) const { ValkeyModule_ReplyWithSimpleString(ctx, "size"); absl::MutexLock lock(&index_mutex_); ValkeyModule_ReplyWithCString(ctx, - std::to_string(tracked_keys_.size()).c_str()); + std::to_string(tracked_keys_->size()).c_str()); return 4; } bool Numeric::IsTracked(const InternedStringPtr& key) const { absl::MutexLock lock(&index_mutex_); - return tracked_keys_.contains(key); + return tracked_keys_->contains(key); } std::unique_ptr Numeric::ToProto() const { @@ -126,7 +152,7 @@ std::unique_ptr Numeric::ToProto() const { const double* Numeric::GetValue(const InternedStringPtr& key) const { // Note that the Numeric index is not mutated while the time sliced mutex is // in a read mode and therefor it is safe to skip lock acquiring. - if (auto it = tracked_keys_.find(key); it != tracked_keys_.end()) { + if (auto it = tracked_keys_->find(key); it != tracked_keys_->end()) { return &it->second; } return nullptr; @@ -154,8 +180,8 @@ std::unique_ptr Numeric::Search( ; additional_entries_range.second = btree.end(); return std::make_unique( - entries_range, size + untracked_keys_.size(), additional_entries_range, - &untracked_keys_); + entries_range, size + untracked_keys_->size(), additional_entries_range, + untracked_keys_.get()); } entries_range.first = predicate.IsStartInclusive() @@ -256,7 +282,7 @@ std::unique_ptr Numeric::EntriesFetcher::Begin() { uint64_t Numeric::GetRecordCount() const { absl::MutexLock lock(&index_mutex_); - return tracked_keys_.size(); + return tracked_keys_->size(); } } // namespace valkey_search::indexes diff --git a/src/indexes/numeric.h b/src/indexes/numeric.h index b83984a1..51a71a66 100644 --- a/src/indexes/numeric.h +++ b/src/indexes/numeric.h @@ -29,6 +29,7 @@ #include "src/rdb_serialization.h" #include "src/utils/segment_tree.h" #include "src/utils/string_interning.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" namespace valkey_search::indexes { @@ -81,7 +82,9 @@ class BTreeNumeric { class Numeric : public IndexBase { public: - explicit Numeric(const data_model::NumericIndex& numeric_index_proto); + explicit Numeric(const data_model::NumericIndex& numeric_index_proto, + MemoryPool& memory_pool); + ~Numeric() override; absl::StatusOr AddRecord(const InternedStringPtr& key, absl::string_view data) override ABSL_LOCKS_EXCLUDED(index_mutex_); @@ -100,7 +103,7 @@ class Numeric : public IndexBase { inline void ForEachTrackedKey( absl::AnyInvocable fn) const override { absl::MutexLock lock(&index_mutex_); - for (const auto& [key, _] : tracked_keys_) { + for (const auto& [key, _] : *tracked_keys_) { fn(key); } } @@ -166,9 +169,11 @@ class Numeric : public IndexBase { private: mutable absl::Mutex index_mutex_; - InternedStringMap tracked_keys_ ABSL_GUARDED_BY(index_mutex_); + std::unique_ptr> tracked_keys_ + ABSL_GUARDED_BY(index_mutex_); // untracked keys is needed to support negate filtering - InternedStringSet untracked_keys_ ABSL_GUARDED_BY(index_mutex_); + std::unique_ptr untracked_keys_ + ABSL_GUARDED_BY(index_mutex_); std::unique_ptr index_ ABSL_GUARDED_BY(index_mutex_); }; } // namespace valkey_search::indexes diff --git a/src/indexes/tag.cc b/src/indexes/tag.cc index 16bb7669..87bdce99 100644 --- a/src/indexes/tag.cc +++ b/src/indexes/tag.cc @@ -26,6 +26,7 @@ #include "src/query/predicate.h" #include "src/utils/patricia_tree.h" #include "src/utils/string_interning.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" namespace valkey_search::indexes { @@ -39,8 +40,8 @@ static bool IsValidPrefix(absl::string_view str) { str[str.length() - 2] != '*'; } -Tag::Tag(const data_model::TagIndex& tag_index_proto) - : IndexBase(IndexerType::kTag), +Tag::Tag(const data_model::TagIndex& tag_index_proto, MemoryPool& memory_pool) + : IndexBase(IndexerType::kTag, memory_pool), separator_(tag_index_proto.separator()[0]), case_sensitive_(tag_index_proto.case_sensitive()), tree_(case_sensitive_) {} diff --git a/src/indexes/tag.h b/src/indexes/tag.h index 74da9730..42f15545 100644 --- a/src/indexes/tag.h +++ b/src/indexes/tag.h @@ -25,13 +25,15 @@ #include "src/rdb_serialization.h" #include "src/utils/patricia_tree.h" #include "src/utils/string_interning.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" namespace valkey_search::indexes { class Tag : public IndexBase { public: - explicit Tag(const data_model::TagIndex& tag_index_proto); + explicit Tag(const data_model::TagIndex& tag_index_proto, + MemoryPool& memory_pool); absl::StatusOr AddRecord(const InternedStringPtr& key, absl::string_view data) override ABSL_LOCKS_EXCLUDED(index_mutex_); diff --git a/src/indexes/vector_base.cc b/src/indexes/vector_base.cc index e636bf17..5624cc75 100644 --- a/src/indexes/vector_base.cc +++ b/src/indexes/vector_base.cc @@ -49,6 +49,7 @@ #include "third_party/hnswlib/space_l2.h" #include "vmsdk/src/log.h" #include "vmsdk/src/managed_pointers.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/status/status_macros.h" #include "vmsdk/src/type_conversions.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" @@ -131,6 +132,8 @@ template void VectorBase::Init(int dimensions, valkey_search::data_model::DistanceMetric distance_metric, std::unique_ptr> &space) { + NestedMemoryScope scope{memory_pool_}; + space = CreateSpace(dimensions, distance_metric); distance_metric_ = distance_metric; if (distance_metric == diff --git a/src/indexes/vector_base.h b/src/indexes/vector_base.h index a15af9be..5f0c2b16 100644 --- a/src/indexes/vector_base.h +++ b/src/indexes/vector_base.h @@ -38,6 +38,7 @@ #include "third_party/hnswlib/hnswlib.h" #include "third_party/hnswlib/iostream.h" #include "vmsdk/src/managed_pointers.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" namespace valkey_search::indexes { @@ -150,8 +151,8 @@ class VectorBase : public IndexBase, public hnswlib::VectorTracker { protected: VectorBase(IndexerType indexer_type, int dimensions, data_model::AttributeDataType attribute_data_type, - absl::string_view attribute_identifier) - : IndexBase(indexer_type), + absl::string_view attribute_identifier, MemoryPool& memory_pool) + : IndexBase(indexer_type, memory_pool), dimensions_(dimensions), attribute_identifier_(attribute_identifier), attribute_data_type_(attribute_data_type) diff --git a/src/indexes/vector_flat.cc b/src/indexes/vector_flat.cc index 4e88fbea..4793f698 100644 --- a/src/indexes/vector_flat.cc +++ b/src/indexes/vector_flat.cc @@ -36,6 +36,7 @@ #include "src/utils/cancel.h" #include "src/utils/string_interning.h" #include "vmsdk/src/log.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/status/status_macros.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" @@ -53,13 +54,14 @@ template absl::StatusOr>> VectorFlat::Create( const data_model::VectorIndex &vector_index_proto, absl::string_view attribute_identifier, - data_model::AttributeDataType attribute_data_type) { + data_model::AttributeDataType attribute_data_type, + MemoryPool &memory_pool) { try { - auto index = std::shared_ptr>( - new VectorFlat(vector_index_proto.dimension_count(), - vector_index_proto.distance_metric(), - vector_index_proto.flat_algorithm().block_size(), - attribute_identifier, attribute_data_type)); + auto index = std::shared_ptr>(new VectorFlat( + vector_index_proto.dimension_count(), + vector_index_proto.distance_metric(), + vector_index_proto.flat_algorithm().block_size(), attribute_identifier, + attribute_data_type, memory_pool)); index->Init(vector_index_proto.dimension_count(), vector_index_proto.distance_metric(), index->space_); index->algo_ = std::make_unique>( @@ -100,14 +102,14 @@ template absl::StatusOr>> VectorFlat::LoadFromRDB( ValkeyModuleCtx *ctx, const AttributeDataType *attribute_data_type, const data_model::VectorIndex &vector_index_proto, - absl::string_view attribute_identifier, - SupplementalContentChunkIter &&iter) { + absl::string_view attribute_identifier, SupplementalContentChunkIter &&iter, + MemoryPool &memory_pool) { try { auto index = std::shared_ptr>(new VectorFlat( vector_index_proto.dimension_count(), vector_index_proto.distance_metric(), vector_index_proto.flat_algorithm().block_size(), attribute_identifier, - attribute_data_type->ToProto())); + attribute_data_type->ToProto(), memory_pool)); index->Init(vector_index_proto.dimension_count(), vector_index_proto.distance_metric(), index->space_); index->algo_ = @@ -127,9 +129,9 @@ template VectorFlat::VectorFlat( int dimensions, valkey_search::data_model::DistanceMetric distance_metric, uint32_t block_size, absl::string_view attribute_identifier, - data_model::AttributeDataType attribute_data_type) + data_model::AttributeDataType attribute_data_type, MemoryPool &memory_pool) : VectorBase(IndexerType::kFlat, dimensions, attribute_data_type, - attribute_identifier), + attribute_identifier, memory_pool), block_size_(block_size) {} template diff --git a/src/indexes/vector_flat.h b/src/indexes/vector_flat.h index a9c265d0..791af4c0 100644 --- a/src/indexes/vector_flat.h +++ b/src/indexes/vector_flat.h @@ -27,6 +27,7 @@ #include "src/utils/string_interning.h" #include "third_party/hnswlib/bruteforce.h" #include "third_party/hnswlib/hnswlib.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" namespace valkey_search::indexes { @@ -37,13 +38,14 @@ class VectorFlat : public VectorBase { static absl::StatusOr>> Create( const data_model::VectorIndex& vector_index_proto, absl::string_view attribute_identifier, - data_model::AttributeDataType attribute_data_type) - ABSL_NO_THREAD_SAFETY_ANALYSIS; + data_model::AttributeDataType attribute_data_type, + MemoryPool& memory_pool) ABSL_NO_THREAD_SAFETY_ANALYSIS; static absl::StatusOr>> LoadFromRDB( ValkeyModuleCtx* ctx, const AttributeDataType* attribute_data_type, const data_model::VectorIndex& vector_index_proto, absl::string_view attribute_identifier, - SupplementalContentChunkIter&& iter) ABSL_NO_THREAD_SAFETY_ANALYSIS; + SupplementalContentChunkIter&& iter, + MemoryPool& memory_pool) ABSL_NO_THREAD_SAFETY_ANALYSIS; ~VectorFlat() override = default; size_t GetDataTypeSize() const override { return sizeof(T); } @@ -94,7 +96,8 @@ class VectorFlat : public VectorBase { private: VectorFlat(int dimensions, data_model::DistanceMetric distance_metric, uint32_t block_size, absl::string_view attribute_identifier, - data_model::AttributeDataType attribute_data_type); + data_model::AttributeDataType attribute_data_type, + MemoryPool& memory_pool); std::unique_ptr> algo_ ABSL_GUARDED_BY(resize_mutex_); std::unique_ptr> space_; diff --git a/src/indexes/vector_hnsw.cc b/src/indexes/vector_hnsw.cc index 344332db..ded246fa 100644 --- a/src/indexes/vector_hnsw.cc +++ b/src/indexes/vector_hnsw.cc @@ -38,6 +38,7 @@ #include "src/valkey_search.h" #include "valkey_search_options.h" #include "vmsdk/src/log.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/status/status_macros.h" #include "vmsdk/src/utils.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" @@ -83,11 +84,12 @@ template absl::StatusOr>> VectorHNSW::Create( const data_model::VectorIndex &vector_index_proto, absl::string_view attribute_identifier, - data_model::AttributeDataType attribute_data_type) { + data_model::AttributeDataType attribute_data_type, + MemoryPool &memory_pool) { try { - auto index = std::shared_ptr>( - new VectorHNSW(vector_index_proto.dimension_count(), - attribute_identifier, attribute_data_type)); + auto index = std::shared_ptr>(new VectorHNSW( + vector_index_proto.dimension_count(), attribute_identifier, + attribute_data_type, memory_pool)); index->Init(vector_index_proto.dimension_count(), vector_index_proto.distance_metric(), index->space_); const auto &hnsw_proto = vector_index_proto.hnsw_algorithm(); @@ -140,12 +142,12 @@ template absl::StatusOr>> VectorHNSW::LoadFromRDB( ValkeyModuleCtx *ctx, const AttributeDataType *attribute_data_type, const data_model::VectorIndex &vector_index_proto, - absl::string_view attribute_identifier, - SupplementalContentChunkIter &&iter) { + absl::string_view attribute_identifier, SupplementalContentChunkIter &&iter, + MemoryPool &memory_pool) { try { auto index = std::shared_ptr>(new VectorHNSW( vector_index_proto.dimension_count(), attribute_identifier, - attribute_data_type->ToProto())); + attribute_data_type->ToProto(), memory_pool)); index->Init(vector_index_proto.dimension_count(), vector_index_proto.distance_metric(), index->space_); @@ -175,9 +177,10 @@ absl::StatusOr>> VectorHNSW::LoadFromRDB( template VectorHNSW::VectorHNSW(int dimensions, absl::string_view attribute_identifier, - data_model::AttributeDataType attribute_data_type) + data_model::AttributeDataType attribute_data_type, + MemoryPool &memory_pool) : VectorBase(IndexerType::kHNSW, dimensions, attribute_data_type, - attribute_identifier) {} + attribute_identifier, memory_pool) {} template absl::Status VectorHNSW::AddRecordImpl(uint64_t internal_id, diff --git a/src/indexes/vector_hnsw.h b/src/indexes/vector_hnsw.h index 2bb179be..736a4500 100644 --- a/src/indexes/vector_hnsw.h +++ b/src/indexes/vector_hnsw.h @@ -26,6 +26,7 @@ #include "src/utils/string_interning.h" #include "third_party/hnswlib/hnswalg.h" #include "third_party/hnswlib/hnswlib.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" namespace valkey_search::indexes { @@ -36,13 +37,14 @@ class VectorHNSW : public VectorBase { static absl::StatusOr>> Create( const data_model::VectorIndex& vector_index_proto, absl::string_view attribute_identifier, - data_model::AttributeDataType attribute_data_type) - ABSL_NO_THREAD_SAFETY_ANALYSIS; + data_model::AttributeDataType attribute_data_type, + MemoryPool& memory_pool) ABSL_NO_THREAD_SAFETY_ANALYSIS; static absl::StatusOr>> LoadFromRDB( ValkeyModuleCtx* ctx, const AttributeDataType* attribute_data_type, const data_model::VectorIndex& vector_index_proto, absl::string_view attribute_identifier, - SupplementalContentChunkIter&& iter) ABSL_NO_THREAD_SAFETY_ANALYSIS; + SupplementalContentChunkIter&& iter, + MemoryPool& memory_pool) ABSL_NO_THREAD_SAFETY_ANALYSIS; ~VectorHNSW() override = default; size_t GetDataTypeSize() const override { return sizeof(T); } @@ -104,7 +106,8 @@ class VectorHNSW : public VectorBase { private: VectorHNSW(int dimensions, absl::string_view attribute_identifier, - data_model::AttributeDataType attribute_data_type); + data_model::AttributeDataType attribute_data_type, + MemoryPool& memory_pool); std::unique_ptr> algo_ ABSL_GUARDED_BY(resize_mutex_); std::unique_ptr> space_; diff --git a/testing/CMakeLists.txt b/testing/CMakeLists.txt index 878d849a..757193c3 100644 --- a/testing/CMakeLists.txt +++ b/testing/CMakeLists.txt @@ -42,6 +42,7 @@ target_link_libraries(testing_common_base PUBLIC module) target_link_libraries(testing_common_base PUBLIC utils) target_link_libraries(testing_common_base PUBLIC valkey_module) target_link_libraries(testing_common_base PUBLIC vmsdk_testing_infra) +target_link_libraries(testing_common_base PUBLIC memory_tracker) # Coordinator common library - used by tests that need coordinator functionality add_library(testing_common_coordinator INTERFACE) diff --git a/testing/common.cc b/testing/common.cc index fa196dfb..00bd5dfd 100644 --- a/testing/common.cc +++ b/testing/common.cc @@ -81,7 +81,8 @@ absl::StatusOr> CreateVectorHNSWSchema( CreateHNSWVectorIndexProto(dimensions, data_model::DISTANCE_METRIC_COSINE, 1000, 10, 300, 30), "vector_identifier", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH); + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + test_index_schema->GetMemoryPool()); VMSDK_EXPECT_OK(index); VMSDK_EXPECT_OK(test_index_schema->AddIndex("vector", "vector", *index)); return test_index_schema; diff --git a/testing/common.h b/testing/common.h index 3b9ca012..1b498d5a 100644 --- a/testing/common.h +++ b/testing/common.h @@ -44,6 +44,7 @@ #include "src/vector_externalizer.h" #include "third_party/hnswlib/iostream.h" #include "vmsdk/src/managed_pointers.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/module_config.h" #include "vmsdk/src/status/status_macros.h" #include "vmsdk/src/testing_infra/module.h" @@ -55,7 +56,8 @@ namespace valkey_search { template class IndexTeser : public T { public: - explicit IndexTeser(K proto) : T(K(proto)) {} + explicit IndexTeser(K proto, MemoryPool& memory_pool) + : T(K(proto), memory_pool) {} absl::StatusOr AddRecord(absl::string_view key, absl::string_view data) { auto interned_key = StringInternStore::Intern(key); @@ -80,8 +82,10 @@ class IndexTeser : public T { class MockIndex : public indexes::IndexBase { public: - MockIndex() : indexes::IndexBase(indexes::IndexerType::kNone) {} - MockIndex(indexes::IndexerType type) : indexes::IndexBase(type) {} + MockIndex(MemoryPool& memory_pool) + : indexes::IndexBase(indexes::IndexerType::kNone, memory_pool) {} + MockIndex(indexes::IndexerType type, MemoryPool& memory_pool) + : indexes::IndexBase(type, memory_pool) {} MOCK_METHOD(absl::StatusOr, AddRecord, (const InternedStringPtr& key, absl::string_view data), (override)); diff --git a/testing/filter_test.cc b/testing/filter_test.cc index b445a5a7..f24e5513 100644 --- a/testing/filter_test.cc +++ b/testing/filter_test.cc @@ -17,6 +17,8 @@ #include "src/query/predicate.h" #include "src/utils/string_interning.h" #include "testing/common.h" +#include "vmsdk/src/memory_tracker.h" + namespace valkey_search { namespace { @@ -36,18 +38,19 @@ struct FilterTestCase { class FilterTest : public ValkeySearchTestWithParam { public: indexes::InlineVectorEvaluator evaluator_; + MemoryPool memory_pool_; }; -void InitIndexSchema(MockIndexSchema *index_schema) { +void InitIndexSchema(MockIndexSchema *index_schema, MemoryPool &memory_pool) { data_model::NumericIndex numeric_index_proto; auto numeric_index_1_5 = std::make_shared>( - numeric_index_proto); + numeric_index_proto, memory_pool); auto numeric_index_2_0 = std::make_shared>( - numeric_index_proto); + numeric_index_proto, memory_pool); VMSDK_EXPECT_OK(numeric_index_1_5->AddRecord("key1", "1.5")); VMSDK_EXPECT_OK(numeric_index_2_0->AddRecord("key1", "2.0")); VMSDK_EXPECT_OK(index_schema->AddIndex("num_field_1.5", "num_field_1.5", @@ -60,19 +63,19 @@ void InitIndexSchema(MockIndexSchema *index_schema) { tag_index_proto.set_case_sensitive(true); auto tag_index_1 = std::make_shared>( - tag_index_proto); + tag_index_proto, memory_pool); VMSDK_EXPECT_OK(tag_index_1->AddRecord("key1", "tag1")); VMSDK_EXPECT_OK( index_schema->AddIndex("tag_field_1", "tag_field_1", tag_index_1)); auto tag_index_1_2 = std::make_shared>( - tag_index_proto); + tag_index_proto, memory_pool); VMSDK_EXPECT_OK(tag_index_1_2->AddRecord("key1", "tag2,tag1")); VMSDK_EXPECT_OK( index_schema->AddIndex("tag_field_1_2", "tag_field_1_2", tag_index_1_2)); auto tag_index_with_space = std::make_shared>( - tag_index_proto); + tag_index_proto, memory_pool); VMSDK_EXPECT_OK(tag_index_with_space->AddRecord("key1", "tag 1 ,tag 2")); VMSDK_EXPECT_OK(index_schema->AddIndex( "tag_field_with_space", "tag_field_with_space", tag_index_with_space)); @@ -82,7 +85,7 @@ void InitIndexSchema(MockIndexSchema *index_schema) { tag_case_insensitive_proto.set_case_sensitive(false); auto tag_field_case_insensitive = std::make_shared>( - tag_case_insensitive_proto); + tag_case_insensitive_proto, memory_pool); VMSDK_EXPECT_OK(tag_field_case_insensitive->AddRecord("key1", "tag1")); VMSDK_EXPECT_OK(index_schema->AddIndex("tag_field_case_insensitive", "tag_field_case_insensitive", @@ -92,7 +95,7 @@ void InitIndexSchema(MockIndexSchema *index_schema) { TEST_P(FilterTest, ParseParams) { const FilterTestCase &test_case = GetParam(); auto index_schema = CreateIndexSchema("index_schema_name").value(); - InitIndexSchema(index_schema.get()); + InitIndexSchema(index_schema.get(), memory_pool_); EXPECT_CALL(*index_schema, GetIdentifier(::testing::_)) .Times(::testing::AnyNumber()); FilterParser parser(*index_schema, test_case.filter); diff --git a/testing/ft_search_parser_test.cc b/testing/ft_search_parser_test.cc index 14e88bab..d7f8dd0b 100644 --- a/testing/ft_search_parser_test.cc +++ b/testing/ft_search_parser_test.cc @@ -126,19 +126,21 @@ void DoVectorSearchParserTest(const FTSearchParserTestCase &test_case, flat_algorithm_proto.release()); auto index = indexes::VectorFlat::Create( vector_index_proto, "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + index_schema->GetMemoryPool()) .value(); VMSDK_EXPECT_OK( index_schema->AddIndex(test_case.attribute_alias, "id1", index)); } else { // Non Vector index setup data_model::NumericIndex numeric_index_proto; - auto numeric_index = - std::make_shared(numeric_index_proto); + auto numeric_index = std::make_shared( + numeric_index_proto, index_schema->GetMemoryPool()); VMSDK_EXPECT_OK( index_schema->AddIndex("attribute_identifier_1", "id1", numeric_index)); data_model::TagIndex tag_index_proto; - auto tag_index = std::make_shared(tag_index_proto); + auto tag_index = std::make_shared( + tag_index_proto, index_schema->GetMemoryPool()); VMSDK_EXPECT_OK( index_schema->AddIndex("attribute_identifier_2", "id2", tag_index)); } diff --git a/testing/index_schema_test.cc b/testing/index_schema_test.cc index a35fc099..eeef9ea9 100644 --- a/testing/index_schema_test.cc +++ b/testing/index_schema_test.cc @@ -125,7 +125,8 @@ TEST_P(IndexSchemaSubscriptionTest, OnKeyspaceNotificationTest) { .value(); EXPECT_TRUE( KeyspaceEventManager::Instance().HasSubscription(index_schema.get())); - auto mock_index = std::make_shared(test_case.index_type); + auto mock_index = std::make_shared( + test_case.index_type, index_schema->GetMemoryPool()); VMSDK_EXPECT_OK(index_schema->AddIndex("attribute_name", test_case.hash_field, mock_index)); @@ -606,7 +607,8 @@ TEST_P(IndexSchemaSubscriptionSimpleTest, DropIndexPrematurely) { .value(); EXPECT_TRUE( KeyspaceEventManager::Instance().HasSubscription(index_schema.get())); - auto mock_index = std::make_shared(); + auto mock_index = + std::make_shared(index_schema->GetMemoryPool()); VMSDK_EXPECT_OK( index_schema->AddIndex("attribute_name", "vector", mock_index)); @@ -723,7 +725,7 @@ TEST_P(IndexSchemaSubscriptionSimpleTest, IndexSchemaInDifferentDBTest) { std::make_unique(), use_thread_pool ? &mutations_thread_pool : nullptr) .value(); - auto mock_index = std::make_shared(); + auto mock_index = std::make_shared(index_schema->GetMemoryPool()); VMSDK_EXPECT_OK( index_schema->AddIndex("attribute_name", "test_identifier", mock_index)); @@ -751,7 +753,7 @@ TEST_P(IndexSchemaSubscriptionSimpleTest, std::make_unique(), use_thread_pool ? &mutations_thread_pool : nullptr) .value(); - auto mock_index = std::make_shared(); + auto mock_index = std::make_shared(index_schema->GetMemoryPool()); VMSDK_EXPECT_OK( index_schema->AddIndex("attribute_name", "test_identifier", mock_index)); @@ -782,7 +784,7 @@ TEST_P(IndexSchemaSubscriptionSimpleTest, KeyspaceNotificationWithNullptrTest) { std::make_unique(), use_thread_pool ? &mutations_thread_pool : nullptr) .value(); - auto mock_index = std::make_shared(); + auto mock_index = std::make_shared(index_schema->GetMemoryPool()); VMSDK_EXPECT_OK( index_schema->AddIndex("attribute_name", "test_identifier", mock_index)); EXPECT_CALL(*kMockValkeyModule, OpenKey(&fake_ctx_, testing::_, testing::_)) @@ -882,7 +884,7 @@ TEST_P(IndexSchemaBackfillTest, PerformBackfillTest) { std::make_unique(), use_thread_pool ? &thread_pool : nullptr) .value(); - auto mock_index = std::make_shared(); + auto mock_index = std::make_shared(index_schema->GetMemoryPool()); VMSDK_EXPECT_OK( index_schema->AddIndex("attribute_name", "test_identifier", mock_index)); @@ -1187,7 +1189,8 @@ TEST_F(IndexSchemaRDBTest, SaveAndLoad) ABSL_NO_THREAD_SAFETY_ANALYSIS { CreateHNSWVectorIndexProto(dimensions, distance_metric, initial_cap, m, ef_construction, ef_runtime), "hnsw_attribute", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + index_schema->GetMemoryPool()) .value(); VMSDK_EXPECT_OK(index_schema->AddIndex("hnsw_attribute", "hnsw_identifier", hnsw_index)); @@ -1210,7 +1213,8 @@ TEST_F(IndexSchemaRDBTest, SaveAndLoad) ABSL_NO_THREAD_SAFETY_ANALYSIS { CreateFlatVectorIndexProto(dimensions, distance_metric, initial_cap, block_size), "flat_identifier", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + index_schema->GetMemoryPool()) .value(); VMSDK_EXPECT_OK(index_schema->AddIndex("flat_attribute", "flat_identifier", flat_index)); @@ -1270,7 +1274,8 @@ TEST_F(IndexSchemaRDBTest, LoadEndedDeletesOrphanedKeys) { vmsdk::ThreadPool mutations_thread_pool("writer-thread-pool-", 1); mutations_thread_pool.StartWorkers(); for (bool use_thread_pool : {true, false}) { - auto mock_index = std::make_shared(); + MemoryPool memory_pool{}; + auto mock_index = std::make_shared(memory_pool); absl::flat_hash_map keys_in_index = { {"key1", 1}, {"key2", 2}, {"key3", 3}}; EXPECT_CALL(*mock_index, ForEachTrackedKey(testing::_)) @@ -1340,7 +1345,8 @@ class IndexSchemaFriendTest : public ValkeySearchTest { CreateHNSWVectorIndexProto(dimensions, distance_metric, initial_cap, m, ef_construction, ef_runtime), attribute_identifier, - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + index_schema->GetMemoryPool()) .value(); VMSDK_EXPECT_OK(index_schema->AddIndex(attribute_identifier, "hnsw_identifier", hnsw_index)); @@ -1646,7 +1652,8 @@ TEST_F(IndexSchemaRDBTest, ComprehensiveSkipLoadTest) { CreateHNSWVectorIndexProto(dimensions, distance_metric, initial_cap, m, ef_construction, ef_runtime), "embedding", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + index_schema->GetMemoryPool()) .value(); VMSDK_EXPECT_OK(index_schema->AddIndex("embedding", "emb_id", hnsw_index)); @@ -1792,18 +1799,19 @@ TEST_F(IndexSchemaRDBTest, ComprehensiveSkipLoadTest) { CreateHNSWVectorIndexProto(dimensions, distance_metric, initial_cap, m, ef_construction, ef_runtime), "embedding", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + index_schema->GetMemoryPool()) .value(); VMSDK_EXPECT_OK(index_schema->AddIndex("embedding", "emb_id", hnsw_index)); // Add numeric index - auto numeric_index = - std::make_shared(CreateNumericIndexProto()); + auto numeric_index = std::make_shared( + CreateNumericIndexProto(), index_schema->GetMemoryPool()); VMSDK_EXPECT_OK(index_schema->AddIndex("price", "price_id", numeric_index)); // Add tag index - auto tag_index = - std::make_shared(CreateTagIndexProto(",", false)); + auto tag_index = std::make_shared( + CreateTagIndexProto(",", false), index_schema->GetMemoryPool()); VMSDK_EXPECT_OK(index_schema->AddIndex("category", "cat_id", tag_index)); // Add test data for all indexes @@ -1984,7 +1992,8 @@ TEST_F(IndexSchemaRDBTest, ComprehensiveSkipLoadTest) { CreateHNSWVectorIndexProto(dimensions, distance_metric, initial_cap, m, ef_construction, ef_runtime), "embedding1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + index_schema->GetMemoryPool()) .value(); VMSDK_EXPECT_OK( index_schema->AddIndex("embedding1", "emb1_id", hnsw_index1)); @@ -1994,7 +2003,8 @@ TEST_F(IndexSchemaRDBTest, ComprehensiveSkipLoadTest) { CreateHNSWVectorIndexProto(dimensions, distance_metric, initial_cap, m, ef_construction, ef_runtime), "embedding2", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + index_schema->GetMemoryPool()) .value(); VMSDK_EXPECT_OK( index_schema->AddIndex("embedding2", "emb2_id", hnsw_index2)); @@ -2004,7 +2014,8 @@ TEST_F(IndexSchemaRDBTest, ComprehensiveSkipLoadTest) { CreateFlatVectorIndexProto(dimensions, distance_metric, initial_cap, block_size), "embedding3", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + index_schema->GetMemoryPool()) .value(); VMSDK_EXPECT_OK( index_schema->AddIndex("embedding3", "emb3_id", flat_index)); diff --git a/testing/multi_exec_test.cc b/testing/multi_exec_test.cc index 59912079..809b4e04 100644 --- a/testing/multi_exec_test.cc +++ b/testing/multi_exec_test.cc @@ -43,7 +43,7 @@ class MultiExecTest : public ValkeySearchTest { index_schema = CreateVectorHNSWSchema(index_schema_name_str, &fake_ctx_, mutations_thread_pool) .value(); - mock_index = std::make_shared(); + mock_index = std::make_shared(index_schema->GetMemoryPool()); const char *identifier = "test_identifier"; VMSDK_EXPECT_OK( index_schema->AddIndex("attribute_name", identifier, mock_index)); diff --git a/testing/numeric_index_test.cc b/testing/numeric_index_test.cc index cfaf29c3..f9c26cc5 100644 --- a/testing/numeric_index_test.cc +++ b/testing/numeric_index_test.cc @@ -9,6 +9,7 @@ #include #include +#include "absl/container/btree_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "gmock/gmock.h" @@ -16,8 +17,12 @@ #include "src/indexes/index_base.h" #include "src/indexes/numeric.h" #include "src/query/predicate.h" +#include "src/utils/segment_tree.h" #include "testing/common.h" +#include "vmsdk/src/memory_allocation.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/testing_infra/utils.h" +#include "vmsdk/src/valkey_module_api/valkey_module.h" namespace valkey_search::indexes { @@ -26,7 +31,9 @@ namespace { class NumericIndexTest : public vmsdk::ValkeyTest { protected: data_model::NumericIndex numeric_index_proto; - IndexTeser index{numeric_index_proto}; + MemoryPool memory_pool; + IndexTeser index{numeric_index_proto, + memory_pool}; }; std::vector Fetch( @@ -312,6 +319,271 @@ TEST_F(NumericIndexTest, DeletedKeysNegativeSearchTest) { EXPECT_THAT(Fetch(*entries_fetcher), testing::UnorderedElementsAre("doc0")); } +#ifndef SAN_BUILD +TEST_F(NumericIndexTest, MemoryTrackingAddRecord) { + auto key = absl::string_view{"key"}; + auto record = absl::string_view{"1.5"}; + + static auto track_malloc_size = [](void* ptr) -> size_t { return 16; }; + + vmsdk::test_utils::SetTestSystemMallocSizeFunction(track_malloc_size); + + memory_pool.Reset(); + int64_t initial_memory = memory_pool.GetUsage(); + + EXPECT_TRUE(index.AddRecord(key, record).value()); + int64_t after_first_add = memory_pool.GetUsage(); + EXPECT_GT(after_first_add, initial_memory); + + vmsdk::test_utils::ClearTestSystemMallocSizeFunction(); + + EXPECT_TRUE(index.RemoveRecord(key).ok()); +} + +TEST_F(NumericIndexTest, MemoryTrackingAddDuplicatedRecord) { + auto key = absl::string_view{"key"}; + auto record1 = absl::string_view{"1.5"}; + auto record2 = absl::string_view{"2.5"}; + + static auto track_malloc_size = [](void* ptr) -> size_t { return 16; }; + + vmsdk::test_utils::SetTestSystemMallocSizeFunction(track_malloc_size); + + memory_pool.Reset(); + int64_t initial_memory = memory_pool.GetUsage(); + + EXPECT_TRUE(index.AddRecord(key, record1).value()); + int64_t after_first_add = memory_pool.GetUsage(); + + auto status = index.AddRecord(key, record2); + EXPECT_EQ(status.status().code(), absl::StatusCode::kAlreadyExists); + int64_t after_duplicate_add = memory_pool.GetUsage(); + EXPECT_EQ(after_duplicate_add, after_first_add); + + vmsdk::test_utils::ClearTestSystemMallocSizeFunction(); + + EXPECT_TRUE(index.RemoveRecord(key).ok()); +} + +TEST_F(NumericIndexTest, MemoryTrackingAddInvalidRecord) { + auto key = absl::string_view{"key"}; + auto invalid_record = absl::string_view{"not_a_number"}; + + static auto track_malloc_size = [](void* ptr) -> size_t { return 16; }; + + vmsdk::test_utils::SetTestSystemMallocSizeFunction(track_malloc_size); + + memory_pool.Reset(); + int64_t initial_memory = memory_pool.GetUsage(); + + EXPECT_FALSE(index.AddRecord(key, invalid_record).value()); + int64_t after_non_numeric = memory_pool.GetUsage(); + // Memory might increase due to untracked_keys_ expansion + EXPECT_GE(after_non_numeric, initial_memory); + + vmsdk::test_utils::ClearTestSystemMallocSizeFunction(); + + EXPECT_TRUE(index.RemoveRecord(key, DeletionType::kRecord).ok()); +} + +TEST_F(NumericIndexTest, MemoryTrackingAddReplaceInvalidRecord) { + auto key = absl::string_view{"key"}; + auto invalid_record = absl::string_view{"not_a_number"}; + auto valid_record = absl::string_view{"1.5"}; + + static auto track_malloc_size = [](void* ptr) -> size_t { return 16; }; + + vmsdk::test_utils::SetTestSystemMallocSizeFunction(track_malloc_size); + + memory_pool.Reset(); + int64_t initial_memory = memory_pool.GetUsage(); + + EXPECT_FALSE(index.AddRecord(key, invalid_record).value()); + int64_t after_non_numeric = memory_pool.GetUsage(); + + EXPECT_TRUE(index.AddRecord(key, valid_record).value()); + int64_t after_valid_add = memory_pool.GetUsage(); + EXPECT_GT(after_valid_add, after_non_numeric); + + vmsdk::test_utils::ClearTestSystemMallocSizeFunction(); + + EXPECT_TRUE(index.RemoveRecord(key).ok()); +} + +TEST_F(NumericIndexTest, MemoryTrackingModifyRecord) { + auto key = absl::string_view{"key"}; + auto record1 = absl::string_view{"1.5"}; + auto record2 = absl::string_view{"2.5"}; + + static auto track_malloc_size = [](void* ptr) -> size_t { return 16; }; + + vmsdk::test_utils::SetTestSystemMallocSizeFunction(track_malloc_size); + + memory_pool.Reset(); + int64_t initial_memory = memory_pool.GetUsage(); + + EXPECT_TRUE(index.AddRecord(key, record1).value()); + int64_t after_add = memory_pool.GetUsage(); + EXPECT_GT(after_add, initial_memory); + + EXPECT_TRUE(index.ModifyRecord(key, record2).value()); + int64_t after_modify = memory_pool.GetUsage(); + EXPECT_EQ(after_modify, after_add); + + vmsdk::test_utils::ClearTestSystemMallocSizeFunction(); + + EXPECT_TRUE(index.RemoveRecord(key).ok()); +} + +TEST_F(NumericIndexTest, MemoryTrackingModifyRecordNotFound) { + auto key = absl::string_view{"key"}; + auto record = absl::string_view{"1.5"}; + + static auto track_malloc_size = [](void* ptr) -> size_t { return 16; }; + + vmsdk::test_utils::SetTestSystemMallocSizeFunction(track_malloc_size); + + memory_pool.Reset(); + int64_t initial_memory = memory_pool.GetUsage(); + + auto status = index.ModifyRecord(key, record); + EXPECT_EQ(status.status().code(), absl::StatusCode::kNotFound); + int64_t after_modify = memory_pool.GetUsage(); + EXPECT_EQ(after_modify, initial_memory); + + vmsdk::test_utils::ClearTestSystemMallocSizeFunction(); +} + +TEST_F(NumericIndexTest, MemoryTrackingModifyRecordInvalid) { + auto key = absl::string_view{"key"}; + auto invalid_record = absl::string_view{"not_a_number"}; + + static auto track_malloc_size = [](void* ptr) -> size_t { return 16; }; + + vmsdk::test_utils::SetTestSystemMallocSizeFunction(track_malloc_size); + + memory_pool.Reset(); + int64_t initial_memory = memory_pool.GetUsage(); + + EXPECT_FALSE(index.ModifyRecord(key, invalid_record).value()); + int64_t after_invalid_modify = memory_pool.GetUsage(); + // Memory might increase due to untracked_keys_ expansion + EXPECT_GE(after_invalid_modify, initial_memory); + + vmsdk::test_utils::ClearTestSystemMallocSizeFunction(); + + EXPECT_TRUE(index.RemoveRecord(key, DeletionType::kRecord).ok()); +} + +TEST_F(NumericIndexTest, MemoryTrackingRemoveRecord) { + auto key = absl::string_view{"key"}; + auto record = absl::string_view{"1.5"}; + + static auto track_malloc_size = [](void* ptr) -> size_t { return 16; }; + + vmsdk::test_utils::SetTestSystemMallocSizeFunction(track_malloc_size); + + memory_pool.Reset(); + int64_t initial_memory = memory_pool.GetUsage(); + + EXPECT_TRUE(index.AddRecord(key, record).value()); + int64_t after_add = memory_pool.GetUsage(); + EXPECT_GT(after_add, initial_memory); + + EXPECT_TRUE(index.RemoveRecord(key).value()); + int64_t after_remove = memory_pool.GetUsage(); + EXPECT_LT(after_remove, after_add); + + vmsdk::test_utils::ClearTestSystemMallocSizeFunction(); +} + +TEST_F(NumericIndexTest, MemoryTrackingRemoveUntracked) { + auto key = absl::string_view{"key"}; + auto invalid_record = absl::string_view{"not_a_number"}; + + static auto track_malloc_size = [](void* ptr) -> size_t { return 16; }; + + vmsdk::test_utils::SetTestSystemMallocSizeFunction(track_malloc_size); + + memory_pool.Reset(); + int64_t initial_memory = memory_pool.GetUsage(); + + EXPECT_FALSE(index.AddRecord(key, invalid_record).value()); + int64_t after_add_invalid = memory_pool.GetUsage(); + + EXPECT_FALSE(index.RemoveRecord(key).value()); + int64_t after_remove_untracked = memory_pool.GetUsage(); + EXPECT_LE(after_remove_untracked, after_add_invalid); + + vmsdk::test_utils::ClearTestSystemMallocSizeFunction(); +} + +TEST_F(NumericIndexTest, MemoryTrackingRemoveWithDeletionTypes) { + auto key1 = absl::string_view{"key1"}; + auto key2 = absl::string_view{"key2"}; + auto record = absl::string_view{"1.5"}; + + static auto track_malloc_size = [](void* ptr) -> size_t { return 16; }; + + vmsdk::test_utils::SetTestSystemMallocSizeFunction(track_malloc_size); + + memory_pool.Reset(); + + EXPECT_TRUE(index.AddRecord(key1, record).value()); + EXPECT_TRUE(index.AddRecord(key2, record).value()); + int64_t after_add = memory_pool.GetUsage(); + + EXPECT_TRUE(index.RemoveRecord(key1, DeletionType::kIdentifier).value()); + int64_t after_soft_delete = memory_pool.GetUsage(); + // Memory might stay similar or increase slightly due to untracked_keys_ + // insertion + EXPECT_LE(after_soft_delete, after_add); + + vmsdk::test_utils::ClearTestSystemMallocSizeFunction(); +} + +TEST_F(NumericIndexTest, MemoryTrackingDestructor) { + static auto track_malloc_size = [](void* ptr) -> size_t { return 16; }; + + vmsdk::test_utils::SetTestSystemMallocSizeFunction(track_malloc_size); + + memory_pool.Reset(); + int64_t initial_memory = memory_pool.GetUsage(); + + // Keep references to interned strings outside the scope to prevent + // deallocation + std::vector string_refs; + std::unique_ptr index_ptr; + { + data_model::NumericIndex local_numeric_proto; + index_ptr = std::make_unique(local_numeric_proto, memory_pool); + + auto key1 = StringInternStore::Intern("key1"); + auto key2 = StringInternStore::Intern("key2"); + auto key3 = StringInternStore::Intern("key3"); + + string_refs.push_back(key1); + string_refs.push_back(key2); + string_refs.push_back(key3); + + EXPECT_TRUE(index_ptr->AddRecord(key1, "1.5").value()); + EXPECT_TRUE(index_ptr->AddRecord(key2, "2.5").value()); + EXPECT_TRUE(index_ptr->AddRecord(key3, "3.5").value()); + + int64_t memory_with_records = memory_pool.GetUsage(); + EXPECT_GT(memory_with_records, initial_memory); + } + + index_ptr.reset(); + + int64_t memory_after_destructor = memory_pool.GetUsage(); + EXPECT_EQ(memory_after_destructor, initial_memory); + + vmsdk::test_utils::ClearTestSystemMallocSizeFunction(); +} + +#endif + } // namespace } // namespace valkey_search::indexes diff --git a/testing/query/fanout_test.cc b/testing/query/fanout_test.cc index 941b4606..03694ab0 100644 --- a/testing/query/fanout_test.cc +++ b/testing/query/fanout_test.cc @@ -385,12 +385,15 @@ TEST_P(FanoutTest, TestFanout) { tag_index.set_separator(","); tag_index.set_case_sensitive(false); VMSDK_EXPECT_OK(schema.value()->AddIndex( - "tag_alias", "tag_id", std::make_shared(tag_index))); + "tag_alias", "tag_id", + std::make_shared(tag_index, + schema.value()->GetMemoryPool()))); data_model::NumericIndex numeric_index; VMSDK_EXPECT_OK(schema.value()->AddIndex( "numeric_alias", "numeric_id", - std::make_shared(numeric_index))); + std::make_shared(numeric_index, + schema.value()->GetMemoryPool()))); InitThreadPools(5, 0); auto mock_coordinator_client_pool = diff --git a/testing/search_test.cc b/testing/search_test.cc index dc99d694..73f9420f 100644 --- a/testing/search_test.cc +++ b/testing/search_test.cc @@ -42,6 +42,7 @@ #include "src/utils/string_interning.h" #include "testing/common.h" #include "vmsdk/src/managed_pointers.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/type_conversions.h" namespace valkey_search { @@ -68,8 +69,9 @@ auto VectorToStr = [](const std::vector &v) { class MockNumeric : public indexes::Numeric { public: - MockNumeric(const data_model::NumericIndex &numeric_index_proto) - : indexes::Numeric(numeric_index_proto) {} + MockNumeric(const data_model::NumericIndex &numeric_index_proto, + MemoryPool &memory_pool) + : indexes::Numeric(numeric_index_proto, memory_pool) {} MOCK_METHOD(std::unique_ptr, Search, (const query::NumericPredicate &predicate, bool negate), (const, override)); @@ -124,8 +126,8 @@ class TestedNumericEntriesFetcher : public indexes::Numeric::EntriesFetcher { class MockTag : public indexes::Tag { public: - MockTag(const data_model::TagIndex &tag_index_proto) - : indexes::Tag(tag_index_proto) {} + MockTag(const data_model::TagIndex &tag_index_proto, MemoryPool &memory_pool) + : indexes::Tag(tag_index_proto, memory_pool) {} MOCK_METHOD(std::unique_ptr, Search, (const query::TagPredicate &predicate, bool negate), (const, override)); @@ -167,10 +169,10 @@ void InitIndexSchema(MockIndexSchema *index_schema) { EXPECT_CALL(*index_schema, GetIdentifier(::testing::_)) .Times(::testing::AnyNumber()); - auto numeric_index_100_10 = - std::make_shared(numeric_index_proto); - auto numeric_index_100_30 = - std::make_shared(numeric_index_proto); + auto numeric_index_100_10 = std::make_shared( + numeric_index_proto, index_schema->GetMemoryPool()); + auto numeric_index_100_30 = std::make_shared( + numeric_index_proto, index_schema->GetMemoryPool()); VMSDK_EXPECT_OK(index_schema->AddIndex( "numeric_index_100_10", "numeric_index_100_10", numeric_index_100_10)); VMSDK_EXPECT_OK(index_schema->AddIndex( @@ -193,7 +195,8 @@ void InitIndexSchema(MockIndexSchema *index_schema) { data_model::TagIndex tag_index_proto; tag_index_proto.set_separator(","); tag_index_proto.set_case_sensitive(false); - auto tag_index_100_15 = std::make_shared(tag_index_proto); + auto tag_index_100_15 = + std::make_shared(tag_index_proto, index_schema->GetMemoryPool()); VMSDK_EXPECT_OK(index_schema->AddIndex("tag_index_100_15", "tag_index_100_15", tag_index_100_15)); @@ -331,7 +334,8 @@ std::shared_ptr CreateIndexSchemaWithMultipleAttributes( kVectorDimensions, data_model::DISTANCE_METRIC_L2, 1000, 10, 300, 30), "vector_attribute_identifier", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + index_schema->GetMemoryPool()) .value(); } else { vector_index = @@ -339,7 +343,8 @@ std::shared_ptr CreateIndexSchemaWithMultipleAttributes( CreateFlatVectorIndexProto( kVectorDimensions, data_model::DISTANCE_METRIC_L2, 1000, 250), "vector_attribute_identifier", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + index_schema->GetMemoryPool()) .value(); } VMSDK_EXPECT_OK(index_schema->AddIndex(kVectorAttributeAlias, @@ -347,14 +352,16 @@ std::shared_ptr CreateIndexSchemaWithMultipleAttributes( // Add numeric index data_model::NumericIndex numeric_index_proto; - auto numeric_index = std::make_shared(numeric_index_proto); + auto numeric_index = std::make_shared( + numeric_index_proto, index_schema->GetMemoryPool()); VMSDK_EXPECT_OK(index_schema->AddIndex("numeric", "numeric", numeric_index)); // Add tag index data_model::TagIndex tag_index_proto; tag_index_proto.set_separator(","); tag_index_proto.set_case_sensitive(false); - auto tag_index = std::make_shared(tag_index_proto); + auto tag_index = std::make_shared( + tag_index_proto, index_schema->GetMemoryPool()); VMSDK_EXPECT_OK(index_schema->AddIndex("tag", "tag", tag_index)); // Add records @@ -787,7 +794,8 @@ TEST_P(IndexedContentTest, MaybeAddIndexedContentTest) { auto vector_index = indexes::VectorHNSW::Create( vector_index_proto, "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + index_schema->GetMemoryPool()) .value(); VMSDK_EXPECT_OK(index_schema->AddIndex( index.attribute_alias, index.attribute_identifier, vector_index)); @@ -800,7 +808,8 @@ TEST_P(IndexedContentTest, MaybeAddIndexedContentTest) { auto flat_index = indexes::VectorFlat::Create( vector_index_proto, "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, + index_schema->GetMemoryPool()) .value(); VMSDK_EXPECT_OK(index_schema->AddIndex( index.attribute_alias, index.attribute_identifier, flat_index)); @@ -811,7 +820,8 @@ TEST_P(IndexedContentTest, MaybeAddIndexedContentTest) { data_model::TagIndex tag_index_proto; tag_index_proto.set_separator(","); tag_index_proto.set_case_sensitive(false); - auto tag_index = std::make_shared(tag_index_proto); + auto tag_index = std::make_shared( + tag_index_proto, index_schema->GetMemoryPool()); VMSDK_EXPECT_OK(index_schema->AddIndex( index.attribute_alias, index.attribute_identifier, tag_index)); index_base = tag_index; @@ -819,8 +829,8 @@ TEST_P(IndexedContentTest, MaybeAddIndexedContentTest) { } case IndexerType::kNumeric: { data_model::NumericIndex numeric_index_proto; - auto numeric_index = - std::make_shared(numeric_index_proto); + auto numeric_index = std::make_shared( + numeric_index_proto, index_schema->GetMemoryPool()); VMSDK_EXPECT_OK(index_schema->AddIndex( index.attribute_alias, index.attribute_identifier, numeric_index)); index_base = numeric_index; diff --git a/testing/tag_index_test.cc b/testing/tag_index_test.cc index 111cc4b0..709b1408 100644 --- a/testing/tag_index_test.cc +++ b/testing/tag_index_test.cc @@ -16,6 +16,7 @@ #include "src/indexes/tag.h" #include "src/query/predicate.h" #include "testing/common.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/testing_infra/utils.h" namespace valkey_search::indexes { @@ -29,8 +30,9 @@ class TagIndexTest : public vmsdk::ValkeyTest { data_model::TagIndex tag_index_proto; tag_index_proto.set_separator(","); tag_index_proto.set_case_sensitive(false); + MemoryPool memory_pool{}; index = std::make_unique>( - tag_index_proto); + tag_index_proto, memory_pool); } std::unique_ptr> index; std::string identifier = "attribute_id"; diff --git a/testing/vector_test.cc b/testing/vector_test.cc index e545f657..ed11cab1 100644 --- a/testing/vector_test.cc +++ b/testing/vector_test.cc @@ -35,6 +35,7 @@ #include "third_party/hnswlib/space_ip.h" #include "third_party/hnswlib/space_l2.h" #include "vmsdk/src/managed_pointers.h" +#include "vmsdk/src/memory_tracker.h" #include "vmsdk/src/type_conversions.h" namespace valkey_search::indexes { @@ -65,16 +66,16 @@ class VectorIndexTest : public ValkeySearchTest { HashAttributeDataType hash_attribute_data_type_; }; -void TestInitializationHNSW(int dimensions, - data_model::DistanceMetric distance_metric, - const std::string& distance_metric_name, - int initial_cap, int m, int ef_construction, - size_t ef_runtime) ABSL_NO_THREAD_SAFETY_ANALYSIS { +void TestInitializationHNSW( + int dimensions, data_model::DistanceMetric distance_metric, + const std::string& distance_metric_name, int initial_cap, int m, + int ef_construction, size_t ef_runtime, + MemoryPool& memory_pool) ABSL_NO_THREAD_SAFETY_ANALYSIS { auto index = VectorHNSW::Create( CreateHNSWVectorIndexProto(dimensions, distance_metric, initial_cap, m, ef_construction, ef_runtime), "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH); + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, memory_pool); auto* space = index.value()->GetSpace(); EXPECT_EQ(distance_metric_name, typeid(*space).name()); EXPECT_EQ(index.value()->GetDimensions(), dimensions); @@ -88,18 +89,20 @@ void TestInitializationHNSW(int dimensions, TEST_F(VectorIndexTest, InitializationHNSW) { for (auto& distance_metric : kExpectedSpaces) { + MemoryPool memory_pool{}; TestInitializationHNSW(kDimensions, distance_metric.first, distance_metric.second, kInitialCap, kM, - kEFConstruction, kEFRuntime); + kEFConstruction, kEFRuntime, memory_pool); } } TEST_F(VectorIndexTest, InitializationFlat) ABSL_NO_THREAD_SAFETY_ANALYSIS { for (auto& distance_metric : kExpectedSpaces) { + MemoryPool memory_pool{}; auto index = VectorFlat::Create( CreateFlatVectorIndexProto(kDimensions, distance_metric.first, kInitialCap, kBlockSize), "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH); + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, memory_pool); auto* space = index.value()->GetSpace(); EXPECT_EQ(distance_metric.second, typeid(*space).name()); EXPECT_EQ(index.value()->GetDimensions(), kDimensions); @@ -228,12 +231,13 @@ class NormalizeStringRecordTest TEST_P(NormalizeStringRecordTest, NormalizeStringRecord) { auto& params = GetParam(); + MemoryPool memory_pool{}; auto index = VectorHNSW::Create( CreateHNSWVectorIndexProto(kDimensions, data_model::DISTANCE_METRIC_L2, kInitialCap, kM, kEFConstruction, kEFRuntime), "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH); + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, memory_pool); auto record = vmsdk::MakeUniqueValkeyString(params.record); auto norm_record = index.value()->NormalizeStringRecord(std::move(record)); if (!params.success) { @@ -279,11 +283,12 @@ INSTANTIATE_TEST_SUITE_P( TEST_F(VectorIndexTest, BasicHNSW) { for (auto& distance_metric : {data_model::DISTANCE_METRIC_COSINE, data_model::DISTANCE_METRIC_L2}) { + MemoryPool memory_pool{}; auto index = VectorHNSW::Create( CreateHNSWVectorIndexProto(kDimensions, distance_metric, kInitialCap, kM, kEFConstruction, kEFRuntime), "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH); + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, memory_pool); TestIndex>(index->get(), kDimensions, 100); } } @@ -291,11 +296,12 @@ TEST_F(VectorIndexTest, BasicHNSW) { TEST_F(VectorIndexTest, BasicFlat) { for (auto& distance_metric : {data_model::DISTANCE_METRIC_COSINE, data_model::DISTANCE_METRIC_L2}) { + MemoryPool memory_pool{}; auto index = VectorFlat::Create( CreateFlatVectorIndexProto(kDimensions, distance_metric, kInitialCap, kBlockSize), "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH); + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, memory_pool); TestIndex>(index->get(), kDimensions, 100); } } @@ -304,11 +310,12 @@ TEST_F(VectorIndexTest, ResizeHNSW) ABSL_NO_THREAD_SAFETY_ANALYSIS { for (auto& distance_metric : {data_model::DISTANCE_METRIC_COSINE, data_model::DISTANCE_METRIC_L2}) { const int initial_cap = 10; + MemoryPool memory_pool{}; auto index = VectorHNSW::Create( CreateHNSWVectorIndexProto(kDimensions, distance_metric, initial_cap, kM, kEFConstruction, kEFRuntime), "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH); + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, memory_pool); ValkeySearch::Instance().SetHNSWBlockSize(1024); uint32_t block_size = ValkeySearch::Instance().GetHNSWBlockSize(); EXPECT_EQ(index.value()->GetCapacity(), initial_cap); @@ -338,11 +345,12 @@ TEST_F(VectorIndexTest, ResizeFlat) ABSL_NO_THREAD_SAFETY_ANALYSIS { for (auto& distance_metric : {data_model::DISTANCE_METRIC_COSINE, data_model::DISTANCE_METRIC_L2}) { const int initial_cap = 10; + MemoryPool memory_pool{}; auto index = VectorFlat::Create( CreateFlatVectorIndexProto(kDimensions, distance_metric, initial_cap, kBlockSize), "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH); + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, memory_pool); auto vectors = DeterministicallyGenerateVectors( initial_cap + kBlockSize + 100, kDimensions, 10.0); EXPECT_EQ(index.value()->GetCapacity(), initial_cap); @@ -388,11 +396,12 @@ TEST_F(VectorIndexTest, EfRuntimeRecall) { for (auto& distance_metric : {data_model::DISTANCE_METRIC_L2}) { // Use a large cap to make sure chunked array is properly exercised const int initial_cap = 31000; + MemoryPool memory_pool{}; auto index_hnsw = VectorHNSW::Create( CreateHNSWVectorIndexProto(kDimensions, distance_metric, initial_cap, kM, kEFConstruction, kEFRuntime), "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH); + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, memory_pool); auto vectors = DeterministicallyGenerateVectors(1000, kDimensions, 2.2); for (size_t i = 0; i < vectors.size(); ++i) { VerifyAdd(index_hnsw->get(), vectors, i, ExpectedResults::kSuccess); @@ -402,7 +411,7 @@ TEST_F(VectorIndexTest, EfRuntimeRecall) { CreateFlatVectorIndexProto(kDimensions, distance_metric, initial_cap, kBlockSize), "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH); + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, memory_pool); for (size_t i = 0; i < vectors.size(); ++i) { VerifyAdd(index_flat->get(), vectors, i, ExpectedResults::kSuccess); } @@ -424,6 +433,7 @@ TEST_F(VectorIndexTest, SaveAndLoadHnsw) { {data_model::DISTANCE_METRIC_COSINE, data_model::DISTANCE_METRIC_L2}) { const int initial_cap = 1000; const uint64_t k = 10; + MemoryPool memory_pool{}; FakeSafeRDB rdb; auto vectors = DeterministicallyGenerateVectors(1000, kDimensions, 2.2); // Load the vectors into a Flat index. This will be used for computing the @@ -432,7 +442,7 @@ TEST_F(VectorIndexTest, SaveAndLoadHnsw) { CreateFlatVectorIndexProto(kDimensions, distance_metric, initial_cap, kBlockSize), "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH); + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, memory_pool); VMSDK_EXPECT_OK(index_flat); for (size_t i = 0; i < vectors.size(); ++i) { VerifyAdd(index_flat->get(), vectors, i, ExpectedResults::kSuccess); @@ -445,7 +455,7 @@ TEST_F(VectorIndexTest, SaveAndLoadHnsw) { { auto index_hnsw = VectorHNSW::Create( hnsw_proto, "attribute_identifier_2", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH); + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, memory_pool); VMSDK_EXPECT_OK(index_hnsw); if (distance_metric == data_model::DISTANCE_METRIC_COSINE) { EXPECT_TRUE((*index_hnsw)->GetNormalize()); @@ -460,7 +470,8 @@ TEST_F(VectorIndexTest, SaveAndLoadHnsw) { { auto loaded_index_hnsw = VectorHNSW::LoadFromRDB( &fake_ctx_, &hash_attribute_data_type_, hnsw_proto, - "attribute_identifier_3", SupplementalContentChunkIter(&rdb)); + "attribute_identifier_3", SupplementalContentChunkIter(&rdb), + memory_pool); VMSDK_EXPECT_OK(loaded_index_hnsw); VMSDK_EXPECT_OK( (*loaded_index_hnsw) @@ -485,7 +496,8 @@ TEST_F(VectorIndexTest, SaveAndLoadHnsw) { { auto loaded_index_hnsw = VectorHNSW::LoadFromRDB( &fake_ctx_, &hash_attribute_data_type_, hnsw_proto, - "attribute_identifier_4", SupplementalContentChunkIter(&rdb)); + "attribute_identifier_4", SupplementalContentChunkIter(&rdb), + memory_pool); VMSDK_EXPECT_OK(loaded_index_hnsw); VMSDK_EXPECT_OK( (*loaded_index_hnsw) @@ -504,6 +516,7 @@ TEST_F(VectorIndexTest, SaveAndLoadFlat) { {data_model::DISTANCE_METRIC_COSINE, data_model::DISTANCE_METRIC_L2}) { const int initial_cap = 1000; const uint64_t k = 10; + MemoryPool memory_pool{}; FakeSafeRDB rdb; auto vectors = DeterministicallyGenerateVectors(1000, kDimensions, 2.2); auto search_vectors = @@ -516,7 +529,7 @@ TEST_F(VectorIndexTest, SaveAndLoadFlat) { { auto index = VectorFlat::Create( flat_proto, "attribute_identifier_1", - data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH); + data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH, memory_pool); if (distance_metric == data_model::DISTANCE_METRIC_COSINE) { EXPECT_TRUE(index.value()->GetNormalize()); } @@ -529,7 +542,8 @@ TEST_F(VectorIndexTest, SaveAndLoadFlat) { { auto index_pr = VectorFlat::LoadFromRDB( &fake_ctx_, &hash_attribute_data_type_, flat_proto, - "attribute_identifier_2", SupplementalContentChunkIter(&rdb)); + "attribute_identifier_2", SupplementalContentChunkIter(&rdb), + memory_pool); VMSDK_EXPECT_OK(index_pr); auto index = std::move(index_pr.value()); VMSDK_EXPECT_OK( @@ -553,7 +567,8 @@ TEST_F(VectorIndexTest, SaveAndLoadFlat) { { auto index_pr = VectorFlat::LoadFromRDB( &fake_ctx_, &hash_attribute_data_type_, flat_proto, - "attribute_identifier_3", SupplementalContentChunkIter(&rdb)); + "attribute_identifier_3", SupplementalContentChunkIter(&rdb), + memory_pool); VMSDK_EXPECT_OK(index_pr); auto index = std::move(index_pr.value()); VMSDK_EXPECT_OK( diff --git a/vmsdk/src/memory_allocation_overrides.cc b/vmsdk/src/memory_allocation_overrides.cc index 1a1b541d..17d7677c 100644 --- a/vmsdk/src/memory_allocation_overrides.cc +++ b/vmsdk/src/memory_allocation_overrides.cc @@ -130,13 +130,32 @@ void* PerformAndTrackAlignedAlloc(size_t align, size_t size, } return ptr; } + +namespace test_utils { + +thread_local size_t (*test_malloc_size_fn)(void*) = nullptr; + +void SetTestSystemMallocSizeFunction(size_t (*fn)(void*)) { + test_malloc_size_fn = fn; +} + +void ClearTestSystemMallocSizeFunction() { + test_malloc_size_fn = nullptr; +} + +} // namespace test_utils } // namespace vmsdk extern "C" { -// Our allocator doesn't support tracking system memory size, so we just -// return 0. +// Basically our allocator doesn't support tracking system memory size, so we just +// return 0. But if test_malloc_size_fn is set, tracking system memory size is possible. // NOLINTNEXTLINE -__attribute__((weak)) size_t empty_usable_size(void* ptr) noexcept { return 0; } +__attribute__((weak)) size_t usable_size(void* ptr) noexcept { + if (vmsdk::test_utils::test_malloc_size_fn) { + return vmsdk::test_utils::test_malloc_size_fn(ptr); + } + return 0; +} // For Valkey allocation - we need to ensure alignment by taking advantage of // jemalloc alignment properties, as there is no aligned malloc module @@ -152,7 +171,7 @@ size_t AlignSize(size_t size, int alignment = 16) { void* __wrap_malloc(size_t size) noexcept { if (!vmsdk::IsUsingValkeyAlloc()) { auto ptr = - vmsdk::PerformAndTrackMalloc(size, __real_malloc, empty_usable_size); + vmsdk::PerformAndTrackMalloc(size, __real_malloc, usable_size); vmsdk::SystemAllocTracker::GetInstance().TrackPointer(ptr); return ptr; } @@ -172,7 +191,7 @@ void __wrap_free(void* ptr) noexcept { // another DSO which doesn't have our wrapped symbols (namely libc.so). For // this reason, we bypass the tracking during the bootstrap phase. if (was_tracked || !vmsdk::IsUsingValkeyAlloc()) { - vmsdk::PerformAndTrackFree(ptr, __real_free, empty_usable_size); + vmsdk::PerformAndTrackFree(ptr, __real_free, usable_size); } else { vmsdk::PerformAndTrackFree(ptr, ValkeyModule_Free, ValkeyModule_MallocUsableSize); @@ -182,7 +201,7 @@ void __wrap_free(void* ptr) noexcept { void* __wrap_calloc(size_t __nmemb, size_t size) noexcept { if (!vmsdk::IsUsingValkeyAlloc()) { auto ptr = vmsdk::PerformAndTrackCalloc(__nmemb, size, __real_calloc, - empty_usable_size); + usable_size); vmsdk::SystemAllocTracker::GetInstance().TrackPointer(ptr); return ptr; } @@ -203,7 +222,7 @@ void* __wrap_realloc(void* ptr, size_t size) noexcept { ValkeyModule_MallocUsableSize); } else { auto new_ptr = vmsdk::PerformAndTrackRealloc(ptr, size, __real_realloc, - empty_usable_size); + usable_size); vmsdk::SystemAllocTracker::GetInstance().TrackPointer(new_ptr); return new_ptr; } @@ -212,7 +231,7 @@ void* __wrap_realloc(void* ptr, size_t size) noexcept { void* __wrap_aligned_alloc(size_t __alignment, size_t __size) noexcept { if (!vmsdk::IsUsingValkeyAlloc()) { auto ptr = vmsdk::PerformAndTrackAlignedAlloc( - __alignment, __size, __real_aligned_alloc, empty_usable_size); + __alignment, __size, __real_aligned_alloc, usable_size); vmsdk::SystemAllocTracker::GetInstance().TrackPointer(ptr); return ptr; } @@ -224,7 +243,7 @@ void* __wrap_aligned_alloc(size_t __alignment, size_t __size) noexcept { int __wrap_malloc_usable_size(void* ptr) noexcept { if (vmsdk::SystemAllocTracker::GetInstance().IsTracked(ptr)) { - return empty_usable_size(ptr); + return usable_size(ptr); } return ValkeyModule_MallocUsableSize(ptr); } diff --git a/vmsdk/src/memory_allocation_overrides.h b/vmsdk/src/memory_allocation_overrides.h index d082b0d4..65655e05 100644 --- a/vmsdk/src/memory_allocation_overrides.h +++ b/vmsdk/src/memory_allocation_overrides.h @@ -36,7 +36,7 @@ WEAK_SYMBOL int (*__real_posix_memalign)(void**, size_t, // NOLINTNEXTLINE WEAK_SYMBOL void* (*__real_valloc)(size_t) = valloc; // NOLINTNEXTLINE -__attribute__((weak)) size_t empty_usable_size(void* ptr) noexcept; +__attribute__((weak)) size_t usable_size(void* ptr) noexcept; } // extern "C" // Different exception specifier between CLANG & GCC @@ -108,4 +108,23 @@ void operator delete[](void* p, std::align_val_t alignment, void operator delete[](void* p, size_t size, std::align_val_t alignment) noexcept; #endif // !SAN_BUILD + +namespace vmsdk { +namespace test_utils { + +// Set a custom malloc size function for testing purposes. +// This allows tests to provide their own implementation of malloc_usable_size +// for system allocations, enabling accurate memory tracking in tests. +// The function pointer is thread-local, so it only affects the calling thread. +// +// @param fn Function pointer that takes a void* and returns the allocated size. +// Pass nullptr to clear the test function. +void SetTestSystemMallocSizeFunction(size_t (*fn)(void*)); + +// Clear the test malloc size function, reverting to default behavior. +void ClearTestSystemMallocSizeFunction(); + +} // namespace test_utils +} // namespace vmsdk + #endif // VMSDK_SRC_MEMORY_ALLOCATION_OVERRIDES_H_ diff --git a/vmsdk/src/memory_tracker.cc b/vmsdk/src/memory_tracker.cc index 9d1b4b67..5f2a396c 100644 --- a/vmsdk/src/memory_tracker.cc +++ b/vmsdk/src/memory_tracker.cc @@ -52,3 +52,10 @@ NestedMemoryScope::~NestedMemoryScope() { int64_t net_change = current_delta - baseline_memory_; target_pool_.Add(net_change); } + +DisableMemoryTracking::DisableMemoryTracking() + : saved_delta_(vmsdk::GetMemoryDelta()) {} + +DisableMemoryTracking::~DisableMemoryTracking() { + vmsdk::SetMemoryDelta(saved_delta_); +} diff --git a/vmsdk/src/memory_tracker.h b/vmsdk/src/memory_tracker.h index 02f46871..d001f9f9 100644 --- a/vmsdk/src/memory_tracker.h +++ b/vmsdk/src/memory_tracker.h @@ -100,4 +100,25 @@ class NestedMemoryScope final : public MemoryScope { VMSDK_NON_COPYABLE_NON_MOVABLE(NestedMemoryScope); }; +/** + * A helper class that temporarily disables memory tracking. + * + * When this class is created, it saves the current memory delta. When + * destroyed, it restores the saved delta, effectively discarding any memory + * allocations that occurred during its lifetime from being tracked. + * + * Use this when you need to allocate memory that should not be counted as part + * of the tracked memory usage (e.g., error status objects). + */ +class DisableMemoryTracking final { + public: + DisableMemoryTracking(); + ~DisableMemoryTracking(); + + VMSDK_NON_COPYABLE_NON_MOVABLE(DisableMemoryTracking); + + private: + int64_t saved_delta_; +}; + #endif // VMSDK_SRC_MEMORY_TRACKER_H_ diff --git a/vmsdk/testing/memory_allocation_test.cc b/vmsdk/testing/memory_allocation_test.cc index 476d7fd8..f0746340 100644 --- a/vmsdk/testing/memory_allocation_test.cc +++ b/vmsdk/testing/memory_allocation_test.cc @@ -421,9 +421,9 @@ TEST_F(MemoryAllocationTest, IsolatedMemoryScopeAllocationIsolation) { { IsolatedMemoryScope outer_scope{outer_pool}; - EXPECT_CALL(*kMockRedisModule, Alloc(112)) + EXPECT_CALL(*kMockValkeyModule, Alloc(112)) .WillOnce(testing::Return(reinterpret_cast(0x1000))); - EXPECT_CALL(*kMockRedisModule, + EXPECT_CALL(*kMockValkeyModule, MallocUsableSize(reinterpret_cast(0x1000))) .WillRepeatedly(testing::Return(128)); outer_ptr = __wrap_malloc(100); @@ -435,9 +435,9 @@ TEST_F(MemoryAllocationTest, IsolatedMemoryScopeAllocationIsolation) { { IsolatedMemoryScope inner_scope{inner_pool}; - EXPECT_CALL(*kMockRedisModule, Alloc(80)) + EXPECT_CALL(*kMockValkeyModule, Alloc(80)) .WillOnce(testing::Return(reinterpret_cast(0x2000))); - EXPECT_CALL(*kMockRedisModule, + EXPECT_CALL(*kMockValkeyModule, MallocUsableSize(reinterpret_cast(0x2000))) .WillRepeatedly(testing::Return(96)); inner_ptr = __wrap_malloc(75); @@ -478,9 +478,9 @@ TEST_F(MemoryAllocationTest, IsolatedMemoryScopeFreeIsolation) { { IsolatedMemoryScope scope{outer_pool}; - EXPECT_CALL(*kMockRedisModule, Alloc(112)) + EXPECT_CALL(*kMockValkeyModule, Alloc(112)) .WillOnce(testing::Return(reinterpret_cast(0x1000))); - EXPECT_CALL(*kMockRedisModule, + EXPECT_CALL(*kMockValkeyModule, MallocUsableSize(reinterpret_cast(0x1000))) .WillRepeatedly(testing::Return(128)); outer_ptr = __wrap_malloc(100); @@ -500,9 +500,9 @@ TEST_F(MemoryAllocationTest, IsolatedMemoryScopeFreeIsolation) { { IsolatedMemoryScope scope{inner_pool}; - EXPECT_CALL(*kMockRedisModule, Alloc(80)) + EXPECT_CALL(*kMockValkeyModule, Alloc(80)) .WillOnce(testing::Return(reinterpret_cast(0x2000))); - EXPECT_CALL(*kMockRedisModule, + EXPECT_CALL(*kMockValkeyModule, MallocUsableSize(reinterpret_cast(0x2000))) .WillRepeatedly(testing::Return(96)); inner_ptr = __wrap_malloc(75); @@ -524,7 +524,7 @@ TEST_F(MemoryAllocationTest, IsolatedMemoryScopeFreeIsolation) { { IsolatedMemoryScope inner_scope{inner_pool}; - EXPECT_CALL(*kMockRedisModule, Free(reinterpret_cast(0x2000))) + EXPECT_CALL(*kMockValkeyModule, Free(reinterpret_cast(0x2000))) .Times(1); __wrap_free(inner_ptr); @@ -539,7 +539,7 @@ TEST_F(MemoryAllocationTest, IsolatedMemoryScopeFreeIsolation) { EXPECT_EQ(outer_pool.GetUsage(), 128); EXPECT_EQ(inner_pool.GetUsage(), 0); - EXPECT_CALL(*kMockRedisModule, Free(reinterpret_cast(0x1000))) + EXPECT_CALL(*kMockValkeyModule, Free(reinterpret_cast(0x1000))) .Times(1); __wrap_free(outer_ptr); @@ -570,9 +570,9 @@ TEST_F(MemoryAllocationTest, NestedMemoryScopeAllocation) { { NestedMemoryScope outer_scope{outer_pool}; - EXPECT_CALL(*kMockRedisModule, Alloc(112)) + EXPECT_CALL(*kMockValkeyModule, Alloc(112)) .WillOnce(testing::Return(reinterpret_cast(0x1000))); - EXPECT_CALL(*kMockRedisModule, + EXPECT_CALL(*kMockValkeyModule, MallocUsableSize(reinterpret_cast(0x1000))) .WillRepeatedly(testing::Return(128)); outer_ptr = __wrap_malloc(100); @@ -584,9 +584,9 @@ TEST_F(MemoryAllocationTest, NestedMemoryScopeAllocation) { { NestedMemoryScope inner_scope{inner_pool}; - EXPECT_CALL(*kMockRedisModule, Alloc(80)) + EXPECT_CALL(*kMockValkeyModule, Alloc(80)) .WillOnce(testing::Return(reinterpret_cast(0x2000))); - EXPECT_CALL(*kMockRedisModule, + EXPECT_CALL(*kMockValkeyModule, MallocUsableSize(reinterpret_cast(0x2000))) .WillRepeatedly(testing::Return(96)); inner_ptr = __wrap_malloc(75); @@ -625,9 +625,9 @@ TEST_F(MemoryAllocationTest, NestedMemoryScopeFree) { { NestedMemoryScope scope{outer_pool}; - EXPECT_CALL(*kMockRedisModule, Alloc(112)) + EXPECT_CALL(*kMockValkeyModule, Alloc(112)) .WillOnce(testing::Return(reinterpret_cast(0x1000))); - EXPECT_CALL(*kMockRedisModule, + EXPECT_CALL(*kMockValkeyModule, MallocUsableSize(reinterpret_cast(0x1000))) .WillRepeatedly(testing::Return(128)); outer_ptr = __wrap_malloc(100); @@ -642,9 +642,9 @@ TEST_F(MemoryAllocationTest, NestedMemoryScopeFree) { { NestedMemoryScope scope{inner_pool}; - EXPECT_CALL(*kMockRedisModule, Alloc(80)) + EXPECT_CALL(*kMockValkeyModule, Alloc(80)) .WillOnce(testing::Return(reinterpret_cast(0x2000))); - EXPECT_CALL(*kMockRedisModule, + EXPECT_CALL(*kMockValkeyModule, MallocUsableSize(reinterpret_cast(0x2000))) .WillRepeatedly(testing::Return(96)); inner_ptr = __wrap_malloc(75); @@ -661,7 +661,7 @@ TEST_F(MemoryAllocationTest, NestedMemoryScopeFree) { { NestedMemoryScope inner_scope{inner_pool}; - EXPECT_CALL(*kMockRedisModule, Free(reinterpret_cast(0x2000))) + EXPECT_CALL(*kMockValkeyModule, Free(reinterpret_cast(0x2000))) .Times(1); __wrap_free(inner_ptr); @@ -676,7 +676,7 @@ TEST_F(MemoryAllocationTest, NestedMemoryScopeFree) { EXPECT_EQ(outer_pool.GetUsage(), 128); EXPECT_EQ(inner_pool.GetUsage(), 0); - EXPECT_CALL(*kMockRedisModule, Free(reinterpret_cast(0x1000))) + EXPECT_CALL(*kMockValkeyModule, Free(reinterpret_cast(0x1000))) .Times(1); __wrap_free(outer_ptr); @@ -692,6 +692,101 @@ TEST_F(MemoryAllocationTest, NestedMemoryScopeFree) { EXPECT_EQ(inner_pool.GetUsage(), 0); } +TEST_F(MemoryAllocationTest, DisableMemoryTrackingAllocation) { + vmsdk::UseValkeyAlloc(); + + MemoryPool pool{0}; + + void* first_ptr = nullptr; + void* second_ptr = nullptr; + + EXPECT_EQ(vmsdk::GetUsedMemoryCnt(), 0); + EXPECT_EQ(vmsdk::GetMemoryDelta(), 0); + + { + NestedMemoryScope scope{pool}; + + EXPECT_CALL(*kMockValkeyModule, Alloc(112)) + .WillOnce(testing::Return(reinterpret_cast(0x1000))); + EXPECT_CALL(*kMockValkeyModule, + MallocUsableSize(reinterpret_cast(0x1000))) + .WillRepeatedly(testing::Return(128)); + first_ptr = __wrap_malloc(100); + + EXPECT_EQ(vmsdk::GetUsedMemoryCnt(), 128); + EXPECT_EQ(vmsdk::GetMemoryDelta(), 128); + EXPECT_EQ(pool.GetUsage(), 0); + + DisableMemoryTracking disable_tracking; + + EXPECT_CALL(*kMockValkeyModule, Alloc(80)) + .WillOnce(testing::Return(reinterpret_cast(0x2000))); + EXPECT_CALL(*kMockValkeyModule, + MallocUsableSize(reinterpret_cast(0x2000))) + .WillRepeatedly(testing::Return(96)); + second_ptr = __wrap_malloc(75); + + EXPECT_EQ(vmsdk::GetUsedMemoryCnt(), 224); + EXPECT_EQ(vmsdk::GetMemoryDelta(), 224); + EXPECT_EQ(pool.GetUsage(), 0); + } + + EXPECT_EQ(vmsdk::GetUsedMemoryCnt(), 224); + EXPECT_EQ(vmsdk::GetMemoryDelta(), 128); + EXPECT_EQ(pool.GetUsage(), 128); + + __wrap_free(first_ptr); + __wrap_free(second_ptr); + + vmsdk::ResetValkeyAlloc(); +} + +TEST_F(MemoryAllocationTest, DisableMemoryTrackingFree) { + vmsdk::UseValkeyAlloc(); + + EXPECT_EQ(vmsdk::GetUsedMemoryCnt(), 0); + EXPECT_EQ(vmsdk::GetMemoryDelta(), 0); + + MemoryPool pool{0}; + void* ptr = nullptr; + + // Allocate pool + { + NestedMemoryScope scope{pool}; + + EXPECT_CALL(*kMockValkeyModule, Alloc(112)) + .WillOnce(testing::Return(reinterpret_cast(0x1000))); + EXPECT_CALL(*kMockValkeyModule, + MallocUsableSize(reinterpret_cast(0x1000))) + .WillRepeatedly(testing::Return(128)); + ptr = __wrap_malloc(100); + } + + EXPECT_EQ(vmsdk::GetUsedMemoryCnt(), 128); + EXPECT_EQ(vmsdk::GetMemoryDelta(), 128); + EXPECT_EQ(pool.GetUsage(), 128); + + { + NestedMemoryScope scope{pool}; + + DisableMemoryTracking disable_scope; + + EXPECT_CALL(*kMockValkeyModule, Free(reinterpret_cast(0x1000))) + .Times(1); + __wrap_free(ptr); + + EXPECT_EQ(vmsdk::GetUsedMemoryCnt(), 0); + EXPECT_EQ(vmsdk::GetMemoryDelta(), 0); + EXPECT_EQ(pool.GetUsage(), 128); + } + + EXPECT_EQ(vmsdk::GetUsedMemoryCnt(), 0); + EXPECT_EQ(vmsdk::GetMemoryDelta(), 128); + EXPECT_EQ(pool.GetUsage(), 128); + + vmsdk::ResetValkeyAlloc(); +} + #endif // TESTING_TMP_DISABLED } // namespace