diff --git a/integration/indexes.py b/integration/indexes.py
index 1c85b5e2..81ef9f88 100644
--- a/integration/indexes.py
+++ b/integration/indexes.py
@@ -8,7 +8,9 @@
 import valkey
 from ft_info_parser import FTInfoParser
 import logging, json
-import struct
+import struct, threading
+from typing import Union
+from valkey_search_test_case import ValkeySearchClusterTestCase
 
 
 def float_to_bytes(flt: list[float]) -> bytes:
@@ -151,9 +153,21 @@ def create(self, client: valkey.client):
         print(f"Creating Index: {cmd}")
         client.execute_command(*cmd)
 
-    def load_data(self, client: valkey.client, rows: int):
-        print("Loading data to ", client)
-        for i in range(0, rows):
+    def load_data(self, client: Union[valkey.client, ValkeySearchClusterTestCase], rows: int):
+        print("Loading data to ", client, " rows:", rows)
+        if not isinstance(client, ValkeySearchClusterTestCase):
+            self.load_data_inner(client, 0, rows, 1)
+        else:
+            NUM_CONNECTIONS = 20
+            threads = []
+            for i in range(NUM_CONNECTIONS):
+                threads.append(threading.Thread(target=self.load_data_inner, args=(client.new_cluster_client(), i, rows, NUM_CONNECTIONS)))
+                threads[-1].start()
+            for i in range(NUM_CONNECTIONS):
+                threads[i].join()
+
+    def load_data_inner(self, client: valkey.client, start: int, end: int , incr: int):
+        for i in range(start, end, incr):
             data = self.make_data(i)
             if self.type == "HASH":
                 #print("Loading ", self.keyname(i), data)
diff --git a/integration/test_shard_down.py b/integration/test_shard_down.py
new file mode 100644
index 00000000..5e0015eb
--- /dev/null
+++ b/integration/test_shard_down.py
@@ -0,0 +1,76 @@
+from valkey_search_test_case import *
+import valkey, time
+import pytest
+from valkeytestframework.conftest import resource_port_tracker
+from indexes import *
+from valkeytestframework.util import waiters
+from valkey.cluster import ValkeyCluster, ClusterNode
+
+def search_command(index: str) -> list[str]:
+    return [
+        "FT.SEARCH",
+        index,
+        "*=>[KNN 10 @v $BLOB]",
+        "PARAMS",
+        "2",
+        "BLOB",
+        float_to_bytes([10.0, 10.0, 10.0]),
+    ]
+
+class TestShardDown(ValkeySearchClusterTestCaseDebugMode):
+    @pytest.mark.parametrize(
+        "setup_test", [{"replica_count": 1}], indirect=True
+    )
+    def test_shard_down(self):
+        """
+        Validate that query logic works when shard down
+        """
+        client = self.new_cluster_client()
+        index = Index("test", [Vector("v", 3, type="FLAT")], type="HASH")
+        index.create(client)
+        index.load_data(self, 100)
+        for n in self.nodes:
+            n.client.execute_command("config set search.enable-partial-results no")
+
+        #
+        # Mark one shard down.
+        #
+        rg = self.get_replication_group(2)
+        shard_nodes = [rg.primary] + rg.replicas
+        for n in shard_nodes:
+            n.client.execute_command("FT._DEBUG PAUSEPOINT SET Search.gRPC")
+        #
+        # Execute command
+        #
+        with pytest.raises(ResponseError):
+            r_result = self.get_replication_group(0).get_primary_connection().execute_command(*search_command(index.name))
+            print("Result: ", r_result)
+
+        sum = 0
+        for n in shard_nodes:
+            t = n.client.execute_command("FT._DEBUG PAUSEPOINT TEST Search.gRPC")
+            n.client.execute_command("FT._DEBUG PAUSEPOINT RESET Search.gRPC")
+            sum += t
+        assert sum > 0
+
+        #
+        # Flip partial results
+        #
+        for n in self.nodes:
+            n.client.execute_command("config set search.enable-partial-results yes")
+
+        for n in shard_nodes:
+            n.client.execute_command("FT._DEBUG PAUSEPOINT SET Search.gRPC")
+
+        #
+        # Execute command, no exception
+        #
+        r_result = self.get_replication_group(0).get_primary_connection().execute_command(*search_command(index.name))
+        print("Result: ", r_result)
+
+        sum = 0
+        for n in shard_nodes:
+            t = n.client.execute_command("FT._DEBUG PAUSEPOINT TEST Search.gRPC")
+            n.client.execute_command("FT._DEBUG PAUSEPOINT RESET Search.gRPC")
+            sum += t
+        assert sum > 0
diff --git a/src/coordinator/server.cc b/src/coordinator/server.cc
index c747c802..9b29b816 100644
--- a/src/coordinator/server.cc
+++ b/src/coordinator/server.cc
@@ -32,11 +32,11 @@
 #include "src/index_schema.h"
 #include "src/indexes/vector_base.h"
 #include "src/metrics.h"
-#include "src/query/fanout_operation_base.h"
 #include "src/query/response_generator.h"
 #include "src/query/search.h"
 #include "src/schema_manager.h"
-#include "valkey_search_options.h"
+#include "src/valkey_search_options.h"
+#include "vmsdk/src/debug.h"
 #include "vmsdk/src/latency_sampler.h"
 #include "vmsdk/src/log.h"
 #include "vmsdk/src/managed_pointers.h"
@@ -112,6 +112,7 @@ grpc::ServerUnaryReactor* Service::SearchIndexPartition(
     grpc::CallbackServerContext* context,
     const SearchIndexPartitionRequest* request,
     SearchIndexPartitionResponse* response) {
+  PAUSEPOINT("Search.gRPC");
   GRPCSuspensionGuard guard(GRPCSuspender::Instance());
   auto latency_sample = SAMPLE_EVERY_N(100);
   grpc::ServerUnaryReactor* reactor = context->DefaultReactor();
diff --git a/vmsdk/src/debug.cc b/vmsdk/src/debug.cc
index f2a833c0..21f4fc0f 100644
--- a/vmsdk/src/debug.cc
+++ b/vmsdk/src/debug.cc
@@ -69,7 +69,7 @@ void PausePoint(absl::string_view point, std::source_location location) {
     }
   }
   if (absl::Now() > message_time) {
-    VMSDK_IO_LOG_EVERY_N_SEC(WARNING, nullptr, 10)
+    VMSDK_LOG_EVERY_N_SEC(WARNING, nullptr, 10)
         << "Waiting > 10 seconds at pause point " << point
         << " Location:" << ToString(location);
   }