
Commit f760597

Merge pull request #197 from vintasoftware/feat/get_llm-models-without-temperature
Allow omitting `temperature` in `get_llm` for models that do not support it
2 parents 2006aae + 7305397
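
In practice, the change enables assistants that skip the parameter entirely. Below is a minimal usage sketch, not taken from this PR's diff: the import path matches the module patched in the new tests, while the model name is a hypothetical stand-in for a model that rejects `temperature`.

```python
# Minimal sketch of the new behavior (assumptions noted inline).
from django_ai_assistant.helpers.assistants import AIAssistant


class NoTempAssistant(AIAssistant):
    id = "no_temp_assistant"  # noqa: A003
    name = "No Temp Assistant"
    instructions = "You are a helpful assistant."
    model = "o1-mini"  # hypothetical: a model that does not support temperature
    temperature = None  # new in this PR: get_llm() omits the parameter


llm = NoTempAssistant().get_llm()  # ChatOpenAI is constructed without `temperature`
```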

2 files changed: +141 -11 lines
django_ai_assistant/helpers/assistants.py

Lines changed: 23 additions & 10 deletions
```diff
@@ -71,8 +71,12 @@ class AIAssistant(abc.ABC):  # noqa: F821
     Should be a valid model name from OpenAI, because the default `get_llm` method uses OpenAI.\n
     `get_llm` can be overridden to use a different LLM implementation.
     """
-    temperature: float = 1.0
-    """Temperature to use for the assistant LLM model.\nDefaults to `1.0`."""
+    temperature: float | None = 1.0
+    """Temperature to use for the assistant LLM model.\n
+    Defaults to `1.0`.\n
+    When `None`, the temperature parameter is omitted when constructing the BaseChatModel
+    in the `get_llm` method.
+    """
     tool_max_concurrency: int = 1
     """Maximum number of tools to run concurrently / in parallel.\nDefaults to `1` (no concurrency)."""
     has_rag: bool = False
@@ -238,14 +242,16 @@ def get_model(self) -> str:
         """
         return self.model
 
-    def get_temperature(self) -> float:
+    def get_temperature(self) -> float | None:
         """Get the temperature to use for the assistant LLM model.
         By default, this is the `temperature` attribute, which is `1.0` by default.\n
         Used by the `get_llm` method to create the LLM instance.\n
-        Override the `temperature` attribute or this method to use a different temperature.
+        Override the `temperature` attribute or this method to use a different temperature.\n
+        Returning `None` is a valid option, particularly for models that do not support
+        temperature control, allowing the parameter to be omitted in the `get_llm` method.\n
 
         Returns:
-            float: The temperature to use for the assistant LLM model.
+            float | None: The temperature to use for the assistant LLM model.
         """
         return self.temperature
 
@@ -271,11 +277,18 @@ def get_llm(self) -> BaseChatModel:
         model = self.get_model()
         temperature = self.get_temperature()
         model_kwargs = self.get_model_kwargs()
-        return ChatOpenAI(
-            model=model,
-            temperature=temperature,
-            model_kwargs=model_kwargs,
-        )
+
+        if temperature is not None:
+            return ChatOpenAI(
+                model=model,
+                temperature=temperature,
+                model_kwargs=model_kwargs,
+            )
+        else:
+            return ChatOpenAI(
+                model=model,
+                model_kwargs=model_kwargs,
+            )
 
     def get_structured_output_llm(self) -> Runnable:
         """Get the LLM model to use for the structured output.
```

tests/test_helpers/test_assistants.py

Lines changed: 118 additions & 1 deletion
```diff
@@ -121,7 +121,10 @@ def test_AIAssistant_invoke():
             id="4",
         ),
         HumanMessage(
-            content="What about tomorrow?", additional_kwargs={}, response_metadata={}, id="5"
+            content="What about tomorrow?",
+            additional_kwargs={},
+            response_metadata={},
+            id="5",
         ),
         AIMessage(
             content="",
@@ -273,6 +276,120 @@ def tool_a(self, foo: str) -> str:
     ]
 
 
+@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
+def test_AIAssistant_get_llm_default_temperature(mock_chat_openai):
+    class DefaultTempAssistant(AIAssistant):
+        id = "default_temp_assistant"  # noqa: A003
+        name = "Default Temp Assistant"
+        instructions = "Instructions"
+        model = "gpt-test"
+
+    assistant = DefaultTempAssistant()
+    assistant.get_llm()
+
+    mock_chat_openai.assert_called_once_with(
+        model="gpt-test",
+        temperature=1.0,
+        model_kwargs={},
+    )
+
+    AIAssistant.clear_cls_registry()
+
+
+@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
+def test_AIAssistant_get_llm_custom_float_temperature(mock_chat_openai):
+    custom_temperature = 0.5
+
+    class CustomFloatTempAssistant(AIAssistant):
+        id = "custom_float_temp_assistant"  # noqa: A003
+        name = "Custom Float Temp Assistant"
+        instructions = "Instructions"
+        model = "gpt-test"
+        temperature = custom_temperature
+
+    assistant = CustomFloatTempAssistant()
+    assistant.get_llm()
+
+    mock_chat_openai.assert_called_once_with(
+        model="gpt-test",
+        temperature=custom_temperature,
+        model_kwargs={},
+    )
+
+    AIAssistant.clear_cls_registry()
+
+
+@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
+def test_AIAssistant_get_llm_override_get_temperature_with_float(mock_chat_openai):
+    custom_temperature = 0.5
+
+    class OverrideGetFloatTempAssistant(AIAssistant):
+        id = "override_get_float_temp_assistant"  # noqa: A003
+        name = "Override Get Float Temp Assistant"
+        instructions = "Instructions"
+        model = "gpt-test"
+
+        def get_temperature(self) -> float | None:
+            return custom_temperature
+
+    assistant = OverrideGetFloatTempAssistant()
+    assistant.get_llm()
+
+    mock_chat_openai.assert_called_once_with(
+        model="gpt-test",
+        temperature=custom_temperature,
+        model_kwargs={},
+    )
+
+    AIAssistant.clear_cls_registry()
+
+
+@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
+def test_AIAssistant_get_llm_custom_none_temperature(mock_chat_openai):
+    class CustomNoneTempAssistant(AIAssistant):
+        id = "custom_none_temp_assistant"  # noqa: A003
+        name = "Custom None Temp Assistant"
+        instructions = "Instructions"
+        model = "gpt-test"
+        temperature = None
+
+    assistant = CustomNoneTempAssistant()
+    assistant.get_llm()
+
+    mock_chat_openai.assert_called_once_with(
+        model="gpt-test",
+        model_kwargs={},
+    )
+    _, call_kwargs = mock_chat_openai.call_args
+    assert "temperature" not in call_kwargs
+
+    AIAssistant.clear_cls_registry()
+
+
+@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
+def test_AIAssistant_get_llm_override_get_temperature_with_none(mock_chat_openai):
+    class OverrideGetNoneTempAssistant(AIAssistant):
+        id = "override_get_none_temp_assistant"  # noqa: A003
+        name = "Override Get None Temp Assistant"
+        instructions = "Instructions"
+        model = "gpt-test"
+
+        def get_temperature(self) -> float | None:
+            return None
+
+    assistant = OverrideGetNoneTempAssistant()
+    assistant.get_llm()
+
+    mock_chat_openai.assert_called_once_with(
+        model="gpt-test",
+        model_kwargs={},
+    )
+    _, call_kwargs = mock_chat_openai.call_args
+    assert "temperature" not in call_kwargs
+
+    AIAssistant.clear_cls_registry()
+
+
 @pytest.mark.vcr
 def test_AIAssistant_pydantic_structured_output():
     from pydantic import BaseModel
```
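
The five new tests share one shape: patch `ChatOpenAI`, build an assistant, call `get_llm`, and assert on the constructor kwargs. A condensed, hypothetical equivalent using `pytest.mark.parametrize` (not part of this PR; names and structure are assumptions layered on the PR's pattern):

```python
# Hypothetical condensation of the five tests above into one parametrized test.
from unittest.mock import patch

import pytest

from django_ai_assistant.helpers.assistants import AIAssistant


@pytest.mark.parametrize("temp,expects_kwarg", [(1.0, True), (0.5, True), (None, False)])
@patch("django_ai_assistant.helpers.assistants.ChatOpenAI")
def test_get_llm_temperature(mock_chat_openai, temp, expects_kwarg):
    class TempAssistant(AIAssistant):
        id = "temp_assistant"  # noqa: A003
        name = "Temp Assistant"
        instructions = "Instructions"
        model = "gpt-test"
        temperature = temp  # class body reads the enclosing parametrized value

    TempAssistant().get_llm()

    _, call_kwargs = mock_chat_openai.call_args
    assert ("temperature" in call_kwargs) == expects_kwarg

    AIAssistant.clear_cls_registry()
```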
