@@ -121,7 +121,10 @@ def test_AIAssistant_invoke():
121121 id = "4" ,
122122 ),
123123 HumanMessage (
124- content = "What about tomorrow?" , additional_kwargs = {}, response_metadata = {}, id = "5"
124+ content = "What about tomorrow?" ,
125+ additional_kwargs = {},
126+ response_metadata = {},
127+ id = "5" ,
125128 ),
126129 AIMessage (
127130 content = "" ,
@@ -273,6 +276,120 @@ def tool_a(self, foo: str) -> str:
273276 ]
274277
275278
279+ @patch ("django_ai_assistant.helpers.assistants.ChatOpenAI" )
280+ def test_AIAssistant_get_llm_default_temperature (mock_chat_openai ):
281+ class DefaultTempAssistant (AIAssistant ):
282+ id = "default_temp_assistant" # noqa: A003
283+ name = "Default Temp Assistant"
284+ instructions = "Instructions"
285+ model = "gpt-test"
286+
287+ assistant = DefaultTempAssistant ()
288+ assistant .get_llm ()
289+
290+ mock_chat_openai .assert_called_once_with (
291+ model = "gpt-test" ,
292+ temperature = 1.0 ,
293+ model_kwargs = {},
294+ )
295+
296+ AIAssistant .clear_cls_registry ()
297+
298+
299+ @patch ("django_ai_assistant.helpers.assistants.ChatOpenAI" )
300+ def test_AIAssistant_get_llm_custom_float_temperature (mock_chat_openai ):
301+ custom_temperature = 0.5
302+
303+ class CustomFloatTempAssistant (AIAssistant ):
304+ id = "custom_float_temp_assistant" # noqa: A003
305+ name = "Custom Float Temp Assistant"
306+ instructions = "Instructions"
307+ model = "gpt-test"
308+ temperature = custom_temperature
309+
310+ assistant = CustomFloatTempAssistant ()
311+ assistant .get_llm ()
312+
313+ mock_chat_openai .assert_called_once_with (
314+ model = "gpt-test" ,
315+ temperature = custom_temperature ,
316+ model_kwargs = {},
317+ )
318+
319+ AIAssistant .clear_cls_registry ()
320+
321+
322+ @patch ("django_ai_assistant.helpers.assistants.ChatOpenAI" )
323+ def test_AIAssistant_get_llm_override_get_temperature_with_float (mock_chat_openai ):
324+ custom_temperature = 0.5
325+
326+ class OverrideGetFloatTempAssistant (AIAssistant ):
327+ id = "override_get_float_temp_assistant" # noqa: A003
328+ name = "Override Get Float Temp Assistant"
329+ instructions = "Instructions"
330+ model = "gpt-test"
331+
332+ def get_temperature (self ) -> float | None :
333+ return custom_temperature
334+
335+ assistant = OverrideGetFloatTempAssistant ()
336+ assistant .get_llm ()
337+
338+ mock_chat_openai .assert_called_once_with (
339+ model = "gpt-test" ,
340+ temperature = custom_temperature ,
341+ model_kwargs = {},
342+ )
343+
344+ AIAssistant .clear_cls_registry ()
345+
346+
347+ @patch ("django_ai_assistant.helpers.assistants.ChatOpenAI" )
348+ def test_AIAssistant_get_llm_custom_none_temperature (mock_chat_openai ):
349+ class CustomNoneTempAssistant (AIAssistant ):
350+ id = "custom_none_temp_assistant" # noqa: A003
351+ name = "Custom None Temp Assistant"
352+ instructions = "Instructions"
353+ model = "gpt-test"
354+ temperature = None
355+
356+ assistant = CustomNoneTempAssistant ()
357+ assistant .get_llm ()
358+
359+ mock_chat_openai .assert_called_once_with (
360+ model = "gpt-test" ,
361+ model_kwargs = {},
362+ )
363+ _ , call_kwargs = mock_chat_openai .call_args
364+ assert "temperature" not in call_kwargs
365+
366+ AIAssistant .clear_cls_registry ()
367+
368+
369+ @patch ("django_ai_assistant.helpers.assistants.ChatOpenAI" )
370+ def test_AIAssistant_get_llm_override_get_temperature_with_none (mock_chat_openai ):
371+ class OverrideGetNoneTempAssistant (AIAssistant ):
372+ id = "override_get_none_temp_assistant" # noqa: A003
373+ name = "Override Get None Temp Assistant"
374+ instructions = "Instructions"
375+ model = "gpt-test"
376+
377+ def get_temperature (self ) -> float | None :
378+ return None
379+
380+ assistant = OverrideGetNoneTempAssistant ()
381+ assistant .get_llm ()
382+
383+ mock_chat_openai .assert_called_once_with (
384+ model = "gpt-test" ,
385+ model_kwargs = {},
386+ )
387+ _ , call_kwargs = mock_chat_openai .call_args
388+ assert "temperature" not in call_kwargs
389+
390+ AIAssistant .clear_cls_registry ()
391+
392+
276393@pytest .mark .vcr
277394def test_AIAssistant_pydantic_structured_output ():
278395 from pydantic import BaseModel
0 commit comments