Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions integration/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import valkey
from ft_info_parser import FTInfoParser
import logging, json
import struct
import struct, threading
from typing import Union
from valkey_search_test_case import ValkeySearchClusterTestCase


def float_to_bytes(flt: list[float]) -> bytes:
Expand Down Expand Up @@ -151,9 +153,21 @@ def create(self, client: valkey.client):
print(f"Creating Index: {cmd}")
client.execute_command(*cmd)

def load_data(self, client: valkey.client, rows: int):
print("Loading data to ", client)
for i in range(0, rows):
def load_data(self, client: Union[valkey.client, ValkeySearchClusterTestCase], rows: int):
    """Load `rows` generated entries for this index into the keyspace.

    Args:
        client: Either a plain valkey client (rows are written sequentially
            on that single connection) or a cluster test case, in which case
            the load is striped across NUM_CONNECTIONS worker threads, each
            with its own cluster client, to speed up bulk loading.
        rows: Total number of rows to generate and write.
    """
    print("Loading data to ", client, " rows:", rows)
    if not isinstance(client, ValkeySearchClusterTestCase):
        # Single-connection path: one thread writes every row.
        self.load_data_inner(client, 0, rows, 1)
        return
    NUM_CONNECTIONS = 20
    # Thread i writes rows i, i+N, i+2N, ... so the full range is covered
    # exactly once with no per-row coordination between threads.
    threads = [
        threading.Thread(
            target=self.load_data_inner,
            args=(client.new_cluster_client(), i, rows, NUM_CONNECTIONS),
        )
        for i in range(NUM_CONNECTIONS)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

def load_data_inner(self, client: valkey.client, start: int, end: int , incr: int):
for i in range(start, end, incr):
data = self.make_data(i)
if self.type == "HASH":
#print("Loading ", self.keyname(i), data)
Expand Down
76 changes: 76 additions & 0 deletions integration/test_shard_down.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from valkey_search_test_case import *
import valkey, time
import pytest
from valkeytestframework.conftest import resource_port_tracker
from indexes import *
from valkeytestframework.util import waiters
from valkey.cluster import ValkeyCluster, ClusterNode

def search_command(index: str) -> list[str]:
    """Build an FT.SEARCH KNN query (top 10 nearest on @v) for `index`."""
    knn_query = "*=>[KNN 10 @v $BLOB]"
    query_vector = float_to_bytes([10.0, 10.0, 10.0])
    return ["FT.SEARCH", index, knn_query, "PARAMS", "2", "BLOB", query_vector]

class TestShardDown(ValkeySearchClusterTestCaseDebugMode):
    @pytest.mark.parametrize(
        "setup_test", [{"replica_count": 1}], indirect=True
    )
    def test_shard_down(self):
        """
        Validate query behavior when one shard is down.

        With partial results disabled, a fanout query must fail while a
        shard's gRPC handler is paused; with partial results enabled, the
        same query must succeed.
        """
        client = self.new_cluster_client()
        index = Index("test", [Vector("v", 3, type="FLAT")], type="HASH")
        index.create(client)
        index.load_data(self, 100)
        for n in self.nodes:
            n.client.execute_command("config set search.enable-partial-results no")

        #
        # Mark one shard down by pausing its gRPC search handler.
        #
        rg = self.get_replication_group(2)
        shard_nodes = [rg.primary] + rg.replicas
        self._set_pausepoints(shard_nodes)

        #
        # Execute command: partial results are disabled, so the query
        # must raise. (No result to print here — on success of this
        # check the command never returns a value.)
        #
        with pytest.raises(ResponseError):
            self.get_replication_group(0).get_primary_connection().execute_command(*search_command(index.name))

        # At least one node must actually have hit the pause point,
        # otherwise the failure above proved nothing.
        assert self._drain_pausepoints(shard_nodes) > 0

        #
        # Flip partial results on and pause the shard again.
        #
        for n in self.nodes:
            n.client.execute_command("config set search.enable-partial-results yes")
        self._set_pausepoints(shard_nodes)

        #
        # Execute command: with partial results enabled, no exception.
        #
        r_result = self.get_replication_group(0).get_primary_connection().execute_command(*search_command(index.name))
        print("Result: ", r_result)

        assert self._drain_pausepoints(shard_nodes) > 0

    def _set_pausepoints(self, nodes):
        # Arm the coordinator gRPC pause point on every given node.
        for n in nodes:
            n.client.execute_command("FT._DEBUG PAUSEPOINT SET Search.gRPC")

    def _drain_pausepoints(self, nodes) -> int:
        # Count how many times the pause point was hit across `nodes`,
        # then clear it so any paused queries can proceed.
        total = 0
        for n in nodes:
            hits = n.client.execute_command("FT._DEBUG PAUSEPOINT TEST Search.gRPC")
            n.client.execute_command("FT._DEBUG PAUSEPOINT RESET Search.gRPC")
            total += hits
        return total
5 changes: 3 additions & 2 deletions src/coordinator/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@
#include "src/index_schema.h"
#include "src/indexes/vector_base.h"
#include "src/metrics.h"
#include "src/query/fanout_operation_base.h"
#include "src/query/response_generator.h"
#include "src/query/search.h"
#include "src/schema_manager.h"
#include "valkey_search_options.h"
#include "src/valkey_search_options.h"
#include "vmsdk/src/debug.h"
#include "vmsdk/src/latency_sampler.h"
#include "vmsdk/src/log.h"
#include "vmsdk/src/managed_pointers.h"
Expand Down Expand Up @@ -112,6 +112,7 @@ grpc::ServerUnaryReactor* Service::SearchIndexPartition(
grpc::CallbackServerContext* context,
const SearchIndexPartitionRequest* request,
SearchIndexPartitionResponse* response) {
PAUSEPOINT("Search.gRPC");
GRPCSuspensionGuard guard(GRPCSuspender::Instance());
auto latency_sample = SAMPLE_EVERY_N(100);
grpc::ServerUnaryReactor* reactor = context->DefaultReactor();
Expand Down
2 changes: 1 addition & 1 deletion vmsdk/src/debug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void PausePoint(absl::string_view point, std::source_location location) {
}
}
if (absl::Now() > message_time) {
VMSDK_IO_LOG_EVERY_N_SEC(WARNING, nullptr, 10)
VMSDK_LOG_EVERY_N_SEC(WARNING, nullptr, 10)
<< "Waiting > 10 seconds at pause point " << point
<< " Location:" << ToString(location);
}
Expand Down
Loading