Skip to content

Commit 261a35a

Browse files
committed
Add none temp param test case
1 parent 5e8a13e commit 261a35a

File tree

1 file changed

+42
-17
lines changed

1 file changed

+42
-17
lines changed

tests/test_helpers/test_assistants.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -289,70 +289,95 @@ class DefaultTempAssistant(AIAssistant):
289289

290290
mock_chat_openai.assert_called_once_with(
291291
model="gpt-test",
292-
temperature=assistant.temperature,
292+
temperature=1.0,
293293
model_kwargs={},
294294
)
295+
295296
AIAssistant.clear_cls_registry()
296297

297298

298299
@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
299-
def test_AIAssistant_get_llm_custom_temperature(mock_chat_openai):
300+
def test_AIAssistant_get_llm_custom_float_temperature(mock_chat_openai):
300301
custom_temperature = 0.5
301302

302-
class CustomTempAssistant(AIAssistant):
303-
id = "custom_temp_assistant" # noqa: A003
304-
name = "Custom Temp Assistant"
303+
class CustomFloatTempAssistant(AIAssistant):
304+
id = "custom_float_temp_assistant" # noqa: A003
305+
name = "Custom Float Temp Assistant"
305306
instructions = "Instructions"
306307
model = "gpt-test"
307308
temperature = custom_temperature
308309

309-
assistant = CustomTempAssistant()
310+
assistant = CustomFloatTempAssistant()
310311
assistant.get_llm()
311312

312313
mock_chat_openai.assert_called_once_with(
313314
model="gpt-test",
314315
temperature=custom_temperature,
315316
model_kwargs={},
316317
)
318+
317319
AIAssistant.clear_cls_registry()
318320

319321

320322
@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
321-
def test_AIAssistant_get_llm_override_get_temperature(mock_chat_openai):
323+
def test_AIAssistant_get_llm_override_get_temperature_with_float(mock_chat_openai):
322324
custom_temperature = 0.5
323325

324-
class OverrideGetTempAssistant(AIAssistant):
325-
id = "override_temp_assistant" # noqa: A003
326-
name = "Override Temp Assistant"
326+
class OverrideGetFloatTempAssistant(AIAssistant):
327+
id = "override_get_float_temp_assistant" # noqa: A003
328+
name = "Override Get Float Temp Assistant"
327329
instructions = "Instructions"
328330
model = "gpt-test"
329331

330332
def get_temperature(self) -> float | None:
331333
return custom_temperature
332334

333-
assistant = OverrideGetTempAssistant()
335+
assistant = OverrideGetFloatTempAssistant()
334336
assistant.get_llm()
335337

336338
mock_chat_openai.assert_called_once_with(
337339
model="gpt-test",
338340
temperature=custom_temperature,
339341
model_kwargs={},
340342
)
343+
344+
AIAssistant.clear_cls_registry()
345+
346+
347+
@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
348+
def test_AIAssistant_get_llm_custom_none_temperature(mock_chat_openai):
349+
class CustomNoneTempAssistant(AIAssistant):
350+
id = "custom_none_temp_assistant" # noqa: A003
351+
name = "Custom None Temp Assistant"
352+
instructions = "Instructions"
353+
model = "gpt-test"
354+
temperature = None
355+
356+
assistant = CustomNoneTempAssistant()
357+
assistant.get_llm()
358+
359+
mock_chat_openai.assert_called_once_with(
360+
model="gpt-test",
361+
model_kwargs={},
362+
)
363+
_, call_kwargs = mock_chat_openai.call_args
364+
assert "temperature" not in call_kwargs
365+
341366
AIAssistant.clear_cls_registry()
342367

343368

344369
@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
345-
def test_AIAssistant_get_llm_none_temperature(mock_chat_openai):
346-
class NoneTempAssistant(AIAssistant):
347-
id = "none_temp_assistant" # noqa: A003
348-
name = "None Temp Assistant"
370+
def test_AIAssistant_get_llm_override_get_temperature_with_none(mock_chat_openai):
371+
class OverrideGetNoneTempAssistant(AIAssistant):
372+
id = "override_get_none_temp_assistant" # noqa: A003
373+
name = "Override Get None Temp Assistant"
349374
instructions = "Instructions"
350375
model = "gpt-test"
351376

352377
def get_temperature(self) -> float | None:
353378
return None
354379

355-
assistant = NoneTempAssistant()
380+
assistant = OverrideGetNoneTempAssistant()
356381
assistant.get_llm()
357382

358383
mock_chat_openai.assert_called_once_with(
@@ -362,7 +387,7 @@ def get_temperature(self) -> float | None:
362387
_, call_kwargs = mock_chat_openai.call_args
363388
assert "temperature" not in call_kwargs
364389

365-
AIAssistant.clear_cls_registry() # Clean up registry
390+
AIAssistant.clear_cls_registry()
366391

367392

368393
@pytest.mark.vcr

0 commit comments

Comments
 (0)