|
1 | 1 | from typing import TYPE_CHECKING, Any, cast |
2 | 2 |
|
3 | 3 | import pytest |
| 4 | +from openai.types.chat import ChatCompletionMessage |
| 5 | +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function |
4 | 6 |
|
5 | 7 | from cleanlab_tlm.utils.chat import ( |
6 | 8 | _form_prompt_chat_completions_api, |
@@ -1366,13 +1368,13 @@ def test_form_response_string_chat_completions_api_empty_arguments() -> None: |
1366 | 1368 |
|
1367 | 1369 | def test_form_response_string_chat_completions_api_invalid_input() -> None: |
1368 | 1370 | """Test form_response_string_chat_completions_api raises TypeError for invalid input.""" |
1369 | | - with pytest.raises(TypeError, match="Expected response to be a dict, got str"): |
| 1371 | + with pytest.raises(TypeError, match="Expected response to be a dict or ChatCompletionMessage object, got str"): |
1370 | 1372 | form_response_string_chat_completions_api("not a dict") # type: ignore[arg-type] |
1371 | 1373 |
|
1372 | | - with pytest.raises(TypeError, match="Expected response to be a dict, got list"): |
| 1374 | + with pytest.raises(TypeError, match="Expected response to be a dict or ChatCompletionMessage object, got list"): |
1373 | 1375 | form_response_string_chat_completions_api([]) # type: ignore[arg-type] |
1374 | 1376 |
|
1375 | | - with pytest.raises(TypeError, match="Expected response to be a dict, got NoneType"): |
| 1377 | + with pytest.raises(TypeError, match="Expected response to be a dict or ChatCompletionMessage object, got NoneType"): |
1376 | 1378 | form_response_string_chat_completions_api(None) # type: ignore[arg-type] |
1377 | 1379 |
|
1378 | 1380 |
|
@@ -1406,3 +1408,184 @@ def test_form_response_string_chat_completions_api_malformed_tool_calls() -> Non |
1406 | 1408 | with pytest.warns(UserWarning, match="Error formatting tool_calls in response.*Returning content only"): |
1407 | 1409 | result = form_response_string_chat_completions_api(response) |
1408 | 1410 | assert result == "Let me check that." |
| 1411 | + |
| 1412 | + |
| 1413 | +############## ChatCompletionMessage tests ############## |
| 1414 | + |
| 1415 | + |
| 1416 | +def test_form_response_string_chat_completions_api_chatcompletion_message_just_content() -> None: |
| 1417 | + """Test form_response_string_chat_completions_api with ChatCompletionMessage containing just content.""" |
| 1418 | + |
| 1419 | + content = "Hello, how can I help you today?" |
| 1420 | + message = ChatCompletionMessage( |
| 1421 | + role="assistant", |
| 1422 | + content=content, |
| 1423 | + ) |
| 1424 | + result = form_response_string_chat_completions_api(message) |
| 1425 | + assert result == content |
| 1426 | + |
| 1427 | + |
| 1428 | +def test_form_response_string_chat_completions_api_chatcompletion_message_just_tool_calls() -> None: |
| 1429 | + """Test form_response_string_chat_completions_api with ChatCompletionMessage containing just tool calls.""" |
| 1430 | + message = ChatCompletionMessage( |
| 1431 | + role="assistant", |
| 1432 | + content=None, |
| 1433 | + tool_calls=[ |
| 1434 | + ChatCompletionMessageToolCall( |
| 1435 | + id="call_123", |
| 1436 | + function=Function( |
| 1437 | + name="search_restaurants", |
| 1438 | + arguments='{"city": "Tokyo", "cuisine_type": "sushi", "max_price": 150, "dietary_restrictions": ["vegetarian", "gluten-free"], "open_now": true}', |
| 1439 | + ), |
| 1440 | + type="function", |
| 1441 | + ) |
| 1442 | + ], |
| 1443 | + ) |
| 1444 | + expected = ( |
| 1445 | + "<tool_call>\n" |
| 1446 | + "{\n" |
| 1447 | + ' "name": "search_restaurants",\n' |
| 1448 | + ' "arguments": {\n' |
| 1449 | + ' "city": "Tokyo",\n' |
| 1450 | + ' "cuisine_type": "sushi",\n' |
| 1451 | + ' "max_price": 150,\n' |
| 1452 | + ' "dietary_restrictions": [\n' |
| 1453 | + ' "vegetarian",\n' |
| 1454 | + ' "gluten-free"\n' |
| 1455 | + " ],\n" |
| 1456 | + ' "open_now": true\n' |
| 1457 | + " }\n" |
| 1458 | + "}\n" |
| 1459 | + "</tool_call>" |
| 1460 | + ) |
| 1461 | + result = form_response_string_chat_completions_api(message) |
| 1462 | + assert result == expected |
| 1463 | + |
| 1464 | + |
| 1465 | +def test_form_response_string_chat_completions_api_chatcompletion_message_content_and_tool_calls() -> None: |
| 1466 | + """Test form_response_string_chat_completions_api with ChatCompletionMessage containing both content and tool calls.""" |
| 1467 | + message = ChatCompletionMessage( |
| 1468 | + role="assistant", |
| 1469 | + content="I'll check the weather for you.", |
| 1470 | + tool_calls=[ |
| 1471 | + ChatCompletionMessageToolCall( |
| 1472 | + id="call_123", |
| 1473 | + function=Function( |
| 1474 | + name="get_weather", |
| 1475 | + arguments='{"location": "Paris"}', |
| 1476 | + ), |
| 1477 | + type="function", |
| 1478 | + ) |
| 1479 | + ], |
| 1480 | + ) |
| 1481 | + expected = ( |
| 1482 | + "I'll check the weather for you.\n" |
| 1483 | + "<tool_call>\n" |
| 1484 | + "{\n" |
| 1485 | + ' "name": "get_weather",\n' |
| 1486 | + ' "arguments": {\n' |
| 1487 | + ' "location": "Paris"\n' |
| 1488 | + " }\n" |
| 1489 | + "}\n" |
| 1490 | + "</tool_call>" |
| 1491 | + ) |
| 1492 | + result = form_response_string_chat_completions_api(message) |
| 1493 | + assert result == expected |
| 1494 | + |
| 1495 | + |
| 1496 | +def test_form_response_string_chat_completions_api_chatcompletion_message_multiple_tool_calls() -> None: |
| 1497 | + """Test form_response_string_chat_completions_api with ChatCompletionMessage containing multiple tool calls.""" |
| 1498 | + message = ChatCompletionMessage( |
| 1499 | + role="assistant", |
| 1500 | + content="Let me check multiple things for you.", |
| 1501 | + tool_calls=[ |
| 1502 | + ChatCompletionMessageToolCall( |
| 1503 | + id="call_123", |
| 1504 | + function=Function( |
| 1505 | + name="get_weather", |
| 1506 | + arguments='{"location": "Paris"}', |
| 1507 | + ), |
| 1508 | + type="function", |
| 1509 | + ), |
| 1510 | + ChatCompletionMessageToolCall( |
| 1511 | + id="call_456", |
| 1512 | + function=Function( |
| 1513 | + name="get_time", |
| 1514 | + arguments='{"timezone": "UTC"}', |
| 1515 | + ), |
| 1516 | + type="function", |
| 1517 | + ), |
| 1518 | + ], |
| 1519 | + ) |
| 1520 | + expected = ( |
| 1521 | + "Let me check multiple things for you.\n" |
| 1522 | + "<tool_call>\n" |
| 1523 | + "{\n" |
| 1524 | + ' "name": "get_weather",\n' |
| 1525 | + ' "arguments": {\n' |
| 1526 | + ' "location": "Paris"\n' |
| 1527 | + " }\n" |
| 1528 | + "}\n" |
| 1529 | + "</tool_call>\n" |
| 1530 | + "<tool_call>\n" |
| 1531 | + "{\n" |
| 1532 | + ' "name": "get_time",\n' |
| 1533 | + ' "arguments": {\n' |
| 1534 | + ' "timezone": "UTC"\n' |
| 1535 | + " }\n" |
| 1536 | + "}\n" |
| 1537 | + "</tool_call>" |
| 1538 | + ) |
| 1539 | + result = form_response_string_chat_completions_api(message) |
| 1540 | + assert result == expected |
| 1541 | + |
| 1542 | + |
| 1543 | +def test_form_response_string_chat_completions_api_chatcompletion_message_empty_content() -> None: |
| 1544 | + """Test form_response_string_chat_completions_api with ChatCompletionMessage containing empty content.""" |
| 1545 | + message = ChatCompletionMessage( |
| 1546 | + role="assistant", |
| 1547 | + content="", |
| 1548 | + ) |
| 1549 | + expected = "" |
| 1550 | + result = form_response_string_chat_completions_api(message) |
| 1551 | + assert result == expected |
| 1552 | + |
| 1553 | + |
| 1554 | +def test_form_response_string_chat_completions_api_chatcompletion_message_empty_arguments() -> None: |
| 1555 | + """Test form_response_string_chat_completions_api with ChatCompletionMessage containing empty arguments.""" |
| 1556 | + message = ChatCompletionMessage( |
| 1557 | + role="assistant", |
| 1558 | + content="Running action", |
| 1559 | + tool_calls=[ |
| 1560 | + ChatCompletionMessageToolCall( |
| 1561 | + id="call_123", |
| 1562 | + function=Function( |
| 1563 | + name="execute_action", |
| 1564 | + arguments="", |
| 1565 | + ), |
| 1566 | + type="function", |
| 1567 | + ) |
| 1568 | + ], |
| 1569 | + ) |
| 1570 | + expected = ( |
| 1571 | + "Running action\n" |
| 1572 | + "<tool_call>\n" |
| 1573 | + "{\n" |
| 1574 | + ' "name": "execute_action",\n' |
| 1575 | + ' "arguments": {}\n' |
| 1576 | + "}\n" |
| 1577 | + "</tool_call>" |
| 1578 | + ) |
| 1579 | + result = form_response_string_chat_completions_api(message) |
| 1580 | + assert result == expected |
| 1581 | + |
| 1582 | + |
| 1583 | +def test_form_response_string_chat_completions_api_chatcompletion_message_none_content() -> None: |
| 1584 | + """Test form_response_string_chat_completions_api with ChatCompletionMessage containing None content.""" |
| 1585 | + message = ChatCompletionMessage( |
| 1586 | + role="assistant", |
| 1587 | + content=None, |
| 1588 | + ) |
| 1589 | + expected = "" |
| 1590 | + result = form_response_string_chat_completions_api(message) |
| 1591 | + assert result == expected |
0 commit comments