
Commit 5e8a13e

Allow omitting temperature in get_llm for models that do not support the parameter
1 parent be763a5 commit 5e8a13e

2 files changed: +117 -12 lines changed

django_ai_assistant/helpers/assistants.py

Lines changed: 24 additions & 11 deletions
@@ -71,8 +71,12 @@ class AIAssistant(abc.ABC):  # noqa: F821
     Should be a valid model name from OpenAI, because the default `get_llm` method uses OpenAI.\n
     `get_llm` can be overridden to use a different LLM implementation.
     """
-    temperature: float = 1.0
-    """Temperature to use for the assistant LLM model.\nDefaults to `1.0`."""
+    temperature: float | None = 1.0
+    """Temperature to use for the assistant LLM model.\n
+    Defaults to `1.0`.\n
+    When `None`, the temperature parameter is omitted when constructing the BaseChatModel
+    in the `get_llm` method.
+    """
     tool_max_concurrency: int = 1
     """Maximum number of tools to run concurrently / in parallel.\nDefaults to `1` (no concurrency)."""
     has_rag: bool = False
@@ -238,14 +242,16 @@ def get_model(self) -> str:
         """
         return self.model
 
-    def get_temperature(self) -> float:
+    def get_temperature(self) -> float | None:
         """Get the temperature to use for the assistant LLM model.
         By default, this is the `temperature` attribute, which is `1.0` by default.\n
         Used by the `get_llm` method to create the LLM instance.\n
-        Override the `temperature` attribute or this method to use a different temperature.
+        Override the `temperature` attribute or this method to use a different temperature.\n
+        Returning `None` is a valid option, particularly for models that do not support
+        temperature control, allowing the parameter to be omitted in the `get_llm` method.
 
-        Returns:
-            float: The temperature to use for the assistant LLM model.
+        Returns:
+            float | None: The temperature to use for the assistant LLM model.
         """
         return self.temperature
 
@@ -271,11 +277,18 @@ def get_llm(self) -> BaseChatModel:
         model = self.get_model()
         temperature = self.get_temperature()
         model_kwargs = self.get_model_kwargs()
-        return ChatOpenAI(
-            model=model,
-            temperature=temperature,
-            model_kwargs=model_kwargs,
-        )
+
+        if temperature is not None:
+            return ChatOpenAI(
+                model=model,
+                temperature=temperature,
+                model_kwargs=model_kwargs,
+            )
+        else:
+            return ChatOpenAI(
+                model=model,
+                model_kwargs=model_kwargs,
+            )
 
     def get_structured_output_llm(self) -> Runnable:
         """Get the LLM model to use for the structured output.

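To illustrate how application code can take advantage of this change, here is a minimal usage sketch: a subclass sets `temperature = None`, so `get_llm` constructs the `ChatOpenAI` instance without the parameter. The assistant attributes and the model name below are hypothetical placeholders, not part of this commit.

from django_ai_assistant import AIAssistant


class ReasoningAssistant(AIAssistant):
    # Hypothetical example assistant; all attribute values are placeholders.
    id = "reasoning_assistant"  # noqa: A003
    name = "Reasoning Assistant"
    instructions = "You are a helpful assistant."
    model = "o1-mini"  # assumed example of a model that rejects `temperature`
    temperature = None  # get_llm will omit the parameter entirely

With `temperature = None`, `get_temperature()` returns `None`, so `get_llm` takes the new branch and calls `ChatOpenAI(model=..., model_kwargs=...)` with no `temperature` keyword.
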
tests/test_helpers/test_assistants.py

Lines changed: 93 additions & 1 deletion
@@ -121,7 +121,10 @@ def test_AIAssistant_invoke():
             id="4",
         ),
         HumanMessage(
-            content="What about tomorrow?", additional_kwargs={}, response_metadata={}, id="5"
+            content="What about tomorrow?",
+            additional_kwargs={},
+            response_metadata={},
+            id="5",
         ),
         AIMessage(
             content="",
@@ -273,6 +276,95 @@ def tool_a(self, foo: str) -> str:
     ]
 
 
+@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
+def test_AIAssistant_get_llm_default_temperature(mock_chat_openai):
+    class DefaultTempAssistant(AIAssistant):
+        id = "default_temp_assistant"  # noqa: A003
+        name = "Default Temp Assistant"
+        instructions = "Instructions"
+        model = "gpt-test"
+
+    assistant = DefaultTempAssistant()
+    assistant.get_llm()
+
+    mock_chat_openai.assert_called_once_with(
+        model="gpt-test",
+        temperature=assistant.temperature,
+        model_kwargs={},
+    )
+    AIAssistant.clear_cls_registry()
+
+
+@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
+def test_AIAssistant_get_llm_custom_temperature(mock_chat_openai):
+    custom_temperature = 0.5
+
+    class CustomTempAssistant(AIAssistant):
+        id = "custom_temp_assistant"  # noqa: A003
+        name = "Custom Temp Assistant"
+        instructions = "Instructions"
+        model = "gpt-test"
+        temperature = custom_temperature
+
+    assistant = CustomTempAssistant()
+    assistant.get_llm()
+
+    mock_chat_openai.assert_called_once_with(
+        model="gpt-test",
+        temperature=custom_temperature,
+        model_kwargs={},
+    )
+    AIAssistant.clear_cls_registry()
+
+
+@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
+def test_AIAssistant_get_llm_override_get_temperature(mock_chat_openai):
+    custom_temperature = 0.5
+
+    class OverrideGetTempAssistant(AIAssistant):
+        id = "override_temp_assistant"  # noqa: A003
+        name = "Override Temp Assistant"
+        instructions = "Instructions"
+        model = "gpt-test"
+
+        def get_temperature(self) -> float | None:
+            return custom_temperature
+
+    assistant = OverrideGetTempAssistant()
+    assistant.get_llm()
+
+    mock_chat_openai.assert_called_once_with(
+        model="gpt-test",
+        temperature=custom_temperature,
+        model_kwargs={},
+    )
+    AIAssistant.clear_cls_registry()
+
+
+@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
+def test_AIAssistant_get_llm_none_temperature(mock_chat_openai):
+    class NoneTempAssistant(AIAssistant):
+        id = "none_temp_assistant"  # noqa: A003
+        name = "None Temp Assistant"
+        instructions = "Instructions"
+        model = "gpt-test"
+
+        def get_temperature(self) -> float | None:
+            return None
+
+    assistant = NoneTempAssistant()
+    assistant.get_llm()
+
+    mock_chat_openai.assert_called_once_with(
+        model="gpt-test",
+        model_kwargs={},
+    )
+    _, call_kwargs = mock_chat_openai.call_args
+    assert "temperature" not in call_kwargs
+
+    AIAssistant.clear_cls_registry()  # Clean up registry
+
+
 @pytest.mark.vcr
 def test_AIAssistant_pydantic_structured_output():
     from pydantic import BaseModel
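
Building on the tests above, one further usage sketch: a subclass can override `get_temperature` to drop the parameter only for specific models. The set of model names checked here is an assumed heuristic for illustration, not something this commit ships.

from django_ai_assistant import AIAssistant


class AdaptiveTempAssistant(AIAssistant):
    # Hypothetical assistant; attribute values are placeholders.
    id = "adaptive_temp_assistant"  # noqa: A003
    name = "Adaptive Temp Assistant"
    instructions = "Instructions"
    model = "o1-mini"
    temperature = 0.7

    def get_temperature(self) -> float | None:
        # Assumed heuristic: these models reject the temperature parameter,
        # so return None and let get_llm omit it when building ChatOpenAI.
        models_without_temperature = {"o1-mini", "o1-preview"}
        if self.get_model() in models_without_temperature:
            return None
        return self.temperature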
