Commit 3f5d61f

Get the LLM interface to naively work with OpenAI.
Also added token-usage tracking.
1 parent 78b56af commit 3f5d61f

4 files changed: 96 additions & 20 deletions

InterSpec/LlmConversationHistory.h

Lines changed: 15 additions & 1 deletion
@@ -28,6 +28,7 @@
 #include <vector>
 #include <string>
 #include <chrono>
+#include <optional>
 
 #include "external_libs/SpecUtils/3rdparty/nlohmann/json.hpp"
 
@@ -81,6 +82,13 @@ struct LlmConversationStart {
   std::chrono::system_clock::time_point timestamp;
   std::string conversationId; // ID for the entire conversation thread
 
+  // Token usage tracking from LLM API responses
+  // Using optional<size_t> because some models/APIs don't provide token usage information,
+  // making it clear when this data is unavailable rather than defaulting to 0
+  std::optional<size_t> promptTokens;     // Tokens used in the input/prompt
+  std::optional<size_t> completionTokens; // Tokens generated in the response
+  std::optional<size_t> totalTokens;      // Total tokens used (prompt + completion)
+
   // Nested follow-up responses (assistant responses, tool calls, tool results)
   std::vector<LlmConversationResponse> responses;
 
@@ -131,6 +139,12 @@ class LlmConversationHistory {
   /** Add a follow-up response to a specific conversation by conversation ID */
   void addFollowUpResponse(const std::string& conversationId, const LlmConversationResponse& response);
 
+  /** Add token usage to a specific conversation by conversation ID (accumulates across API calls) */
+  void addTokenUsage(const std::string& conversationId,
+                     std::optional<int> promptTokens,
+                     std::optional<int> completionTokens,
+                     std::optional<int> totalTokens);
+
   /** Find a conversation start by conversation ID */
   LlmConversationStart* findConversationByConversationId(const std::string& conversationId);
 
@@ -185,4 +199,4 @@ class LlmConversationHistory {
   static LlmConversationResponse::Type stringToResponseType(const std::string& str);
 };
 
-#endif // LLM_CONVERSATION_HISTORY_H
+#endif // LLM_CONVERSATION_HISTORY_H
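The token counters are optional on purpose: a caller can distinguish "no usage reported by the API" from "zero tokens". A minimal sketch of how the accumulated counters might be read back for display; the formatTokenSummary helper below is hypothetical and not part of this commit (it only needs <string> and <optional>, which the header already includes):

  // Hypothetical helper: format accumulated token usage for display/logging.
  // Counters the API never reported stay "n/a" instead of silently reading as 0.
  std::string formatTokenSummary(const LlmConversationStart &conv) {
    auto fmt = [](const std::optional<size_t> &v) {
      return v.has_value() ? std::to_string(v.value()) : std::string("n/a");
    };
    return "prompt=" + fmt(conv.promptTokens)
         + ", completion=" + fmt(conv.completionTokens)
         + ", total=" + fmt(conv.totalTokens);
  }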

src/AnalystChecks.cpp

Lines changed: 1 addition & 1 deletion
@@ -234,7 +234,7 @@ namespace AnalystChecks
 
   // get fit peak
 
-  if( options.source.has_value() )
+  if( options.source.has_value() && !options.source.value().empty() )
   {
     const string source = options.source.value();
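The added check means an empty source string is now treated the same as an absent one before the peak fit is attempted. A minimal illustration of the guard pattern, using a simplified stand-in for the real options struct (the type and names below are hypothetical, not the project's actual definitions):

  #include <optional>
  #include <string>

  struct PeakFitOptions { std::optional<std::string> source; };  // simplified stand-in

  // Present *and* non-empty; "" now takes the same path as std::nullopt.
  bool hasUsableSource(const PeakFitOptions &options) {
    return options.source.has_value() && !options.source.value().empty();
  }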

src/LlmConversationHistory.cpp

Lines changed: 37 additions & 3 deletions
@@ -147,6 +147,39 @@ void LlmConversationHistory::addFollowUpResponse(const std::string& conversation
   }
 }
 
+void LlmConversationHistory::addTokenUsage(const std::string& conversationId,
+                                           std::optional<int> promptTokens,
+                                           std::optional<int> completionTokens,
+                                           std::optional<int> totalTokens) {
+  LlmConversationStart* conversation = findConversationByConversationId(conversationId);
+  if (conversation) {
+    // Accumulate token usage across API calls within this conversation
+    if (promptTokens.has_value() && (promptTokens.value() > 0)) {
+      if (conversation->promptTokens.has_value()) {
+        conversation->promptTokens = conversation->promptTokens.value() + promptTokens.value();
+      } else {
+        conversation->promptTokens = static_cast<size_t>( promptTokens.value() );
+      }
+    }
+
+    if (completionTokens.has_value() && (completionTokens.value() > 0)) {
+      if (conversation->completionTokens.has_value()) {
+        conversation->completionTokens = conversation->completionTokens.value() + completionTokens.value();
+      } else {
+        conversation->completionTokens = static_cast<size_t>( completionTokens.value() );
+      }
+    }
+
+    if (totalTokens.has_value() && (totalTokens.value() > 0)) {
+      if (conversation->totalTokens.has_value()) {
+        conversation->totalTokens = conversation->totalTokens.value() + totalTokens.value();
+      } else {
+        conversation->totalTokens = static_cast<size_t>( totalTokens.value() );
+      }
+    }
+  }
+}
+
 LlmConversationStart* LlmConversationHistory::findConversationByConversationId(const std::string& conversationId) {
   for (auto& conv : *m_conversations) {
     if (conv.conversationId == conversationId) {
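All three counters are accumulated with the same has_value/add-or-initialize pattern. A possible refactor, not part of this commit, is a small helper that folds a reported optional<int> sample into an optional<size_t> running total with the same behavior as the code above:

  // Hypothetical helper: add a reported count into an optional running total,
  // treating missing or non-positive samples as "no data".
  static void accumulateTokens(std::optional<size_t> &total, const std::optional<int> sample) {
    if (!sample.has_value() || (sample.value() <= 0))
      return;
    total = total.value_or(0) + static_cast<size_t>(sample.value());
  }

  // Inside addTokenUsage() the three blocks would then reduce to:
  //   accumulateTokens(conversation->promptTokens, promptTokens);
  //   accumulateTokens(conversation->completionTokens, completionTokens);
  //   accumulateTokens(conversation->totalTokens, totalTokens);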
@@ -216,7 +249,8 @@ nlohmann::json LlmConversationHistory::toApiFormat() const {
         responseMsg["role"] = "assistant";
         responseMsg["tool_calls"] = json::array();
         json toolCall;
-        toolCall["id"] = conv.conversationId + ":" + response.invocationId;
+        // Use just the invocationId to keep within OpenAI's 40-character limit
+        toolCall["id"] = response.invocationId;
         toolCall["type"] = "function";
         toolCall["function"]["name"] = response.toolName;
         toolCall["function"]["arguments"] = response.toolParameters.dump();
@@ -226,7 +260,7 @@ nlohmann::json LlmConversationHistory::toApiFormat() const {
 
       case LlmConversationResponse::Type::ToolResult:
         responseMsg["role"] = "tool";
-        responseMsg["tool_call_id"] = conv.conversationId + ":" + response.invocationId;
+        responseMsg["tool_call_id"] = response.invocationId;
         responseMsg["content"] = response.content;
         break;
 
@@ -464,4 +498,4 @@ LlmConversationResponse::Type LlmConversationHistory::stringToResponseType(const
   return LlmConversationResponse::Type::Assistant; // Default fallback
 }
 
-#endif // USE_LLM_INTERFACE
+#endif // USE_LLM_INTERFACE
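The two id changes above have to stay in sync: OpenAI matches a "tool" message to the call that produced it by comparing its tool_call_id against the id in the assistant message's tool_calls array, and the id is length-limited (the 40-character limit noted in the code comment), which the old conversationId + ":" prefix could exceed. An illustrative pair of messages built the same way as in toApiFormat (the identifier and tool name below are made up):

  // Assistant turn announcing a tool call; "call_abc123" stands in for invocationId.
  nlohmann::json toolCall;
  toolCall["id"] = "call_abc123";             // bare invocationId, well under 40 chars
  toolCall["type"] = "function";
  toolCall["function"]["name"] = "someTool";  // hypothetical tool name
  toolCall["function"]["arguments"] = "{}";

  nlohmann::json assistantMsg;
  assistantMsg["role"] = "assistant";
  assistantMsg["tool_calls"] = nlohmann::json::array({toolCall});

  // Follow-up tool result; tool_call_id must equal toolCall["id"] exactly.
  nlohmann::json toolResultMsg;
  toolResultMsg["role"] = "tool";
  toolResultMsg["tool_call_id"] = "call_abc123";
  toolResultMsg["content"] = "...";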

src/LlmInterface.cpp

Lines changed: 43 additions & 15 deletions
@@ -387,6 +387,30 @@ void LlmInterface::handleApiResponse(const std::string& response) {
   try {
     json responseJson = json::parse(response);
 
+    // Parse and accumulate token usage information if available
+    if (responseJson.contains("usage") && m_history && !m_currentConversationId.empty()) {
+      const auto& usage = responseJson["usage"];
+
+      std::optional<int> promptTokens, completionTokens, totalTokens;
+      if (usage.contains("prompt_tokens") && usage["prompt_tokens"].is_number())
+        promptTokens = usage["prompt_tokens"].get<int>();
+      if (usage.contains("completion_tokens") && usage["completion_tokens"].is_number())
+        completionTokens = usage["completion_tokens"].get<int>();
+      if (usage.contains("total_tokens") && usage["total_tokens"].is_number())
+        totalTokens = usage["total_tokens"].get<int>();
+
+      // Accumulate token usage for this conversation
+      m_history->addTokenUsage(m_currentConversationId, promptTokens, completionTokens, totalTokens);
+
+      if (completionTokens.has_value()) {
+        cout << "=== Token Usage This Call ===" << endl;
+        cout << "Prompt tokens: " << (promptTokens.has_value() ? std::to_string(promptTokens.value()) : "N/A") << endl;
+        cout << "Completion tokens: " << completionTokens.value() << endl;
+        cout << "Total tokens: " << (totalTokens.has_value() ? std::to_string(totalTokens.value()) : "N/A") << endl;
+        cout << "=============================" << endl;
+      }
+    }
+
     if (responseJson.contains("choices") && !responseJson["choices"].empty()) {
       json choice = responseJson["choices"][0];
       if (choice.contains("message")) {
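For context, the usage block this code reads is the standard OpenAI chat-completions shape; when a backend omits it, or omits individual fields, the corresponding optionals stay unset and addTokenUsage() skips them. A small self-contained check with made-up numbers, assuming the same bundled nlohmann::json header used elsewhere in the project:

  #include <iostream>
  #include <optional>
  #include <string>
  #include "external_libs/SpecUtils/3rdparty/nlohmann/json.hpp"

  int main() {
    // Typical OpenAI-style usage object; the numbers are made up.
    const auto responseJson = nlohmann::json::parse(R"({
      "usage": { "prompt_tokens": 1234, "completion_tokens": 56, "total_tokens": 1290 }
    })");

    std::optional<int> completionTokens;
    const auto &usage = responseJson["usage"];
    if (usage.contains("completion_tokens") && usage["completion_tokens"].is_number())
      completionTokens = usage["completion_tokens"].get<int>();

    // Prints "completion_tokens: 56"; a missing field would print "n/a".
    std::cout << "completion_tokens: "
              << (completionTokens ? std::to_string(*completionTokens) : std::string("n/a"))
              << std::endl;
    return 0;
  }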
@@ -396,23 +420,15 @@
         if( message.contains("content") && message["content"].is_string() )
           content = message["content"];
 
-
         if (role == "assistant") {
           // Extract thinking content and clean content
           auto [cleanContent, thinkingContent] = extractThinkingAndContent(content);
-
-          cout
-          << "=== Start Cleaned Response Content ===" << endl
-          << cleanContent
-          << "\n=== End Cleaned Response Content ===" << endl
-          << endl;
-
+
           // Add assistant message to history with thinking content
           m_history->addAssistantMessageWithThinking(cleanContent, thinkingContent, m_currentConversationId);
 
           // Handle structured tool calls first (OpenAI format)
           if (message.contains("tool_calls")) {
-            cout << "Found structured tool_calls" << endl;
             executeToolCalls(message["tool_calls"]);
           } else {
             // Parse content for text-based tool requests (use cleaned content)
@@ -668,7 +684,15 @@ nlohmann::json LlmInterface::buildMessagesArray(const std::string& userMessage,
 
   json request;
   request["model"] = m_config->llmApi.model;
-  request["max_tokens"] = m_config->llmApi.maxTokens;
+
+  // Use max_completion_tokens for newer OpenAI models, max_tokens for others
+  string modelName = m_config->llmApi.model;
+  if (modelName.find("gpt-4") != string::npos || modelName.find("gpt-3.5") != string::npos ||
+      modelName.find("o1") != string::npos || modelName.find("gpt-5") != string::npos) {
+    request["max_completion_tokens"] = m_config->llmApi.maxTokens;
+  } else {
+    request["max_tokens"] = m_config->llmApi.maxTokens;
+  }
 
   json messages = json::array();
 
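The model-name branching reflects OpenAI's move from max_tokens to max_completion_tokens on its newer chat models, while other OpenAI-compatible endpoints generally still take max_tokens. A possible refactor, not in this commit, that isolates the substring check so the model list is easier to extend:

  // Hypothetical helper mirroring the check above; the substrings are the ones
  // hard-coded in this commit and would need updating as model names change.
  static bool useMaxCompletionTokens(const std::string &modelName) {
    for (const char *m : {"gpt-4", "gpt-3.5", "o1", "gpt-5"}) {
      if (modelName.find(m) != std::string::npos)
        return true;
    }
    return false;
  }

  // In buildMessagesArray() the branch then collapses to:
  //   request[useMaxCompletionTokens(modelName) ? "max_completion_tokens" : "max_tokens"]
  //       = m_config->llmApi.maxTokens;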

@@ -684,8 +708,10 @@
   if (!m_history->isEmpty()) {
     json historyMessages = m_history->toApiFormat();
     cout << "=== Including " << historyMessages.size() << " history messages in request ===" << endl;
+
     for (size_t i = 0; i < historyMessages.size(); ++i) {
       const auto& msg = historyMessages[i];
+
       cout << " " << i << ". " << msg["role"].get<string>() << ": "
            << (msg.contains("content") ? msg["content"].get<string>().substr(0, 50) + "..." : "tool_call") << endl;
       messages.push_back(msg);
@@ -754,7 +780,7 @@ void LlmInterface::setupJavaScriptBridge() {
   // Set up the JavaScript function to handle HTTP requests
   string jsCode = R"(
     window.llmHttpRequest = function(endpoint, requestJsonString, bearerToken, requestId) {
-      console.log('LLM HTTP Request to:', endpoint, 'For requestID', requestId);
+      //console.log('LLM HTTP Request to:', endpoint, 'For requestID', requestId);
       //console.log('Request data:', requestJsonString);
 
      var headers = {
@@ -779,12 +805,12 @@ void LlmInterface::setupJavaScriptBridge() {
        signal: controller.signal
      })
      .then(function(response) {
-        console.log('LLM Response status:', response.status);
+        //console.log('LLM Response status:', response.status);
        return response.text();
      })
      .then(function(responseText) {
        //console.log('LLM Response:', responseText);
-        console.log( 'Got LLM Response text' );
+        //console.log( 'Got LLM Response text', responseText );
 
        // Clear the timeout since we got a response
        clearTimeout(timeoutId);
@@ -820,7 +846,7 @@ void LlmInterface::setupJavaScriptBridge() {
   // Set up the response callback using JSignal to emit signal to C++
   string callbackJs =
     "window.llmResponseCallback = function(response, requestId) { "
-    "console.log('Emitting signal to C++ with response length:', response.length, 'requestId:', requestId); "
+    //"console.log('Emitting signal to C++ with response length:', response.length, 'requestId:', requestId); "
    "" + m_responseSignal->createCall("response", "requestId") + ";"
    "};";
 
@@ -830,6 +856,8 @@ void LlmInterface::setupJavaScriptBridge() {
 }
 
 void LlmInterface::handleJavaScriptResponse(std::string response, int requestId) {
+
+  cout << ":handleJavaScriptResponse: " << response << endl << endl << endl;
   try {
     // Find and remove the pending request
     PendingRequest pendingRequest;
@@ -842,7 +870,7 @@ void LlmInterface::handleJavaScriptResponse(std::string response, int requestId)
 
     // Check for errors first
     json responseJson = json::parse(response);
-    if (responseJson.contains("error")) {
+    if (responseJson.contains("error") && !responseJson["error"].is_null()) {
      string errorMsg = "LLM API Error: " + responseJson["error"].dump(2);
      cout << errorMsg << endl;
 
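The loosened error check matters because nlohmann::json's contains() only tests for the key, not whether the value is meaningful, and some OpenAI-compatible backends appear to include an explicit "error": null on successful responses, which the old check would have reported as an API error. A quick illustration:

  // contains() is true even when the value is null:
  const auto j = nlohmann::json::parse(R"({ "error": null, "choices": [] })");
  // j.contains("error")   -> true   (key is present)
  // j["error"].is_null()  -> true   (but there is no actual error)
  // Old check: treated this as "LLM API Error: null"; the new check falls
  // through to normal response handling.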
