Skip to content

Commit 44f4ba8

Browse files
tweak pytest config
1 parent 94f27f6 commit 44f4ba8

File tree

3 files changed

+73
-52
lines changed

3 files changed

+73
-52
lines changed

scripts.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,20 @@ def check_mypy():
3030

3131

3232
def test():
33+
test_cmd = ["python", "-m", "pytest", "-n", "auto", "--log-level=CRITICAL"]
3334
# Get any extra arguments passed to the script
3435
extra_args = sys.argv[1:]
35-
if not extra_args:
36-
test_cmd = ["python", "-m", "pytest", "-n", "auto", "--log-level=CRITICAL"]
37-
else:
38-
test_cmd = ["python", "-m", "pytest", "-n", "auto", "--log-level=CRITICAL"] + extra_args
36+
if extra_args:
37+
test_cmd.extend(extra_args)
3938
subprocess.run(test_cmd, check=True)
4039

4140

4241
def test_verbose():
42+
test_cmd = ["python", "-m", "pytest", "-n", "auto", "-vv", "-s", "--log-level=CRITICAL"]
4343
# Get any extra arguments passed to the script
4444
extra_args = sys.argv[1:]
45-
if not extra_args:
46-
test_cmd = ["python", "-m", "pytest", "-n", "auto", "-vv", "-s", "--log-level=CRITICAL"]
47-
else:
48-
test_cmd = ["python", "-m", "pytest", "-n", "auto", "-vv", "-s", "--log-level=CRITICAL"] + extra_args
45+
if extra_args:
46+
test_cmd.extend(extra_args)
4947
subprocess.run(test_cmd, check=True)
5048

5149

tests/conftest.py

Lines changed: 62 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,15 @@ async def async_client(redis_url):
5454
"""
5555
An async Redis client that uses the dynamic `redis_url`.
5656
"""
57-
client = await RedisConnectionFactory.get_async_redis_connection(redis_url)
58-
yield client
59-
try:
60-
await client.aclose()
61-
except RuntimeError as e:
62-
if "Event loop is closed" not in str(e):
63-
raise
57+
async with await RedisConnectionFactory.get_async_redis_connection(
58+
redis_url
59+
) as client:
60+
yield client
61+
# try:
62+
# await client.aclose()
63+
# except RuntimeError as e:
64+
# if "Event loop is closed" not in str(e):
65+
# raise
6466

6567

6668
@pytest.fixture
@@ -70,51 +72,51 @@ def client(redis_url):
7072
"""
7173
conn = RedisConnectionFactory.get_redis_connection(redis_url)
7274
yield conn
73-
conn.close()
75+
# conn.close()
7476

7577

76-
@pytest.fixture
77-
def openai_key():
78-
return os.getenv("OPENAI_API_KEY")
78+
# @pytest.fixture
79+
# def openai_key():
80+
# return os.getenv("OPENAI_API_KEY")
7981

8082

81-
@pytest.fixture
82-
def openai_version():
83-
return os.getenv("OPENAI_API_VERSION")
83+
# @pytest.fixture
84+
# def openai_version():
85+
# return os.getenv("OPENAI_API_VERSION")
8486

8587

86-
@pytest.fixture
87-
def azure_endpoint():
88-
return os.getenv("AZURE_OPENAI_ENDPOINT")
88+
# @pytest.fixture
89+
# def azure_endpoint():
90+
# return os.getenv("AZURE_OPENAI_ENDPOINT")
8991

9092

91-
@pytest.fixture
92-
def cohere_key():
93-
return os.getenv("COHERE_API_KEY")
93+
# @pytest.fixture
94+
# def cohere_key():
95+
# return os.getenv("COHERE_API_KEY")
9496

9597

96-
@pytest.fixture
97-
def mistral_key():
98-
return os.getenv("MISTRAL_API_KEY")
98+
# @pytest.fixture
99+
# def mistral_key():
100+
# return os.getenv("MISTRAL_API_KEY")
99101

100102

101-
@pytest.fixture
102-
def gcp_location():
103-
return os.getenv("GCP_LOCATION")
103+
# @pytest.fixture
104+
# def gcp_location():
105+
# return os.getenv("GCP_LOCATION")
104106

105107

106-
@pytest.fixture
107-
def gcp_project_id():
108-
return os.getenv("GCP_PROJECT_ID")
108+
# @pytest.fixture
109+
# def gcp_project_id():
110+
# return os.getenv("GCP_PROJECT_ID")
109111

110112

111-
@pytest.fixture
112-
def aws_credentials():
113-
return {
114-
"aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
115-
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
116-
"aws_region": os.getenv("AWS_REGION", "us-east-1"),
117-
}
113+
# @pytest.fixture
114+
# def aws_credentials():
115+
# return {
116+
# "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
117+
# "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
118+
# "aws_region": os.getenv("AWS_REGION", "us-east-1"),
119+
# }
118120

119121

120122
@pytest.fixture
@@ -179,13 +181,29 @@ def sample_data():
179181
]
180182

181183

182-
@pytest.fixture
183-
def clear_db(redis):
184-
redis.flushall()
185-
yield
186-
redis.flushall()
184+
def pytest_addoption(parser: pytest.Parser) -> None:
185+
parser.addoption(
186+
"--run-api-tests",
187+
action="store_true",
188+
default=False,
189+
help="Run tests that require API keys",
190+
)
187191

188192

189-
@pytest.fixture
190-
def app_name():
191-
return "test_app"
193+
def pytest_configure(config: pytest.Config) -> None:
194+
config.addinivalue_line(
195+
"markers", "requires_api_keys: mark test as requiring API keys"
196+
)
197+
198+
199+
def pytest_collection_modifyitems(
200+
config: pytest.Config, items: list[pytest.Item]
201+
) -> None:
202+
if config.getoption("--run-api-tests"):
203+
return
204+
skip_api = pytest.mark.skip(
205+
reason="Skipping test because API keys are not provided. Use --run-api-tests to run these tests."
206+
)
207+
for item in items:
208+
if item.get_closest_marker("requires_api_keys"):
209+
item.add_marker(skip_api)

tests/integration/test_session_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer
1313

1414

15+
@pytest.fixture
16+
def app_name():
17+
return "test_app"
18+
19+
1520
@pytest.fixture
1621
def standard_session(app_name, client):
1722
session = StandardSessionManager(app_name, redis_client=client)

0 commit comments

Comments
 (0)