Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 10 additions & 74 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::Debug;
use std::path::PathBuf;
Expand Down Expand Up @@ -829,7 +828,7 @@ impl Session {
history.record_items(std::iter::once(response_item));
}
RolloutItem::Compacted(compacted) => {
let snapshot = history.contents();
let snapshot = history.get_history();
let user_messages = collect_user_messages(&snapshot);
let rebuilt = build_compacted_history(
self.build_initial_context(turn_context),
Expand All @@ -841,7 +840,7 @@ impl Session {
_ => {}
}
}
history.contents()
history.get_history()
}

/// Append ResponseItems to the in-memory conversation history only.
Expand Down Expand Up @@ -891,7 +890,7 @@ impl Session {
}

pub(crate) async fn history_snapshot(&self) -> Vec<ResponseItem> {
let state = self.state.lock().await;
let mut state = self.state.lock().await;
state.history_snapshot()
}

Expand Down Expand Up @@ -988,16 +987,6 @@ impl Session {
self.send_event(turn_context, event).await;
}

/// Build the full turn input by concatenating the current conversation
/// history with additional items for this turn.
pub async fn turn_input_with_history(&self, extra: Vec<ResponseItem>) -> Vec<ResponseItem> {
let history = {
let state = self.state.lock().await;
state.history_snapshot()
};
[history, extra].concat()
}

/// Returns the input if there was no task running to inject into
pub async fn inject_input(&self, input: Vec<UserInput>) -> Result<(), Vec<UserInput>> {
let mut active = self.active_turn.lock().await;
Expand Down Expand Up @@ -1500,6 +1489,7 @@ pub(crate) async fn run_task(
// model sees a fresh conversation without the parent session's history.
// For normal turns, continue recording to the session history as before.
let is_review_mode = turn_context.is_review_mode;
// TODO:(aibrahim): review thread should be a conversation history type.
let mut review_thread_history: Vec<ResponseItem> = Vec::new();
if is_review_mode {
// Seed review threads with environment context so the model knows the working directory.
Expand Down Expand Up @@ -1544,7 +1534,7 @@ pub(crate) async fn run_task(
review_thread_history.clone()
} else {
sess.record_conversation_items(&pending_input).await;
sess.turn_input_with_history(pending_input).await
sess.history_snapshot().await
};

let turn_input_messages: Vec<String> = turn_input
Expand Down Expand Up @@ -1901,61 +1891,6 @@ async fn try_run_turn(
task_kind: TaskKind,
cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> {
// call_ids that are part of this response.
let completed_call_ids = prompt
.input
.iter()
.filter_map(|ri| match ri {
ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id),
ResponseItem::LocalShellCall {
call_id: Some(call_id),
..
} => Some(call_id),
ResponseItem::CustomToolCallOutput { call_id, .. } => Some(call_id),
_ => None,
})
.collect::<Vec<_>>();

// call_ids that were pending but are not part of this response.
// This usually happens because the user interrupted the model before we responded to one of its tool calls
// and then the user sent a follow-up message.
let missing_calls = {
prompt
.input
.iter()
.filter_map(|ri| match ri {
ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
ResponseItem::LocalShellCall {
call_id: Some(call_id),
..
} => Some(call_id),
ResponseItem::CustomToolCall { call_id, .. } => Some(call_id),
_ => None,
})
.filter_map(|call_id| {
if completed_call_ids.contains(&call_id) {
None
} else {
Some(call_id.clone())
}
})
.map(|call_id| ResponseItem::CustomToolCallOutput {
call_id,
output: "aborted".to_string(),
})
.collect::<Vec<_>>()
};
let prompt: Cow<Prompt> = if missing_calls.is_empty() {
Cow::Borrowed(prompt)
} else {
// Add the synthetic aborted missing calls to the beginning of the input to ensure all call ids have responses.
let input = [missing_calls, prompt.input.clone()].concat();
Cow::Owned(Prompt {
input,
..prompt.clone()
})
};

let rollout_item = RolloutItem::TurnContext(TurnContextItem {
cwd: turn_context.cwd.clone(),
approval_policy: turn_context.approval_policy,
Expand All @@ -1964,11 +1899,12 @@ async fn try_run_turn(
effort: turn_context.client.get_reasoning_effort(),
summary: turn_context.client.get_reasoning_summary(),
});

sess.persist_rollout_items(&[rollout_item]).await;
let mut stream = turn_context
.client
.clone()
.stream_with_task_kind(prompt.as_ref(), task_kind)
.stream_with_task_kind(prompt, task_kind)
.or_cancel(&cancellation_token)
.await??;

Expand Down Expand Up @@ -2951,7 +2887,7 @@ mod tests {
rollout_items.push(RolloutItem::ResponseItem(assistant1.clone()));

let summary1 = "summary one";
let snapshot1 = live_history.contents();
let snapshot1 = live_history.get_history();
let user_messages1 = collect_user_messages(&snapshot1);
let rebuilt1 = build_compacted_history(
session.build_initial_context(turn_context),
Expand Down Expand Up @@ -2984,7 +2920,7 @@ mod tests {
rollout_items.push(RolloutItem::ResponseItem(assistant2.clone()));

let summary2 = "summary two";
let snapshot2 = live_history.contents();
let snapshot2 = live_history.get_history();
let user_messages2 = collect_user_messages(&snapshot2);
let rebuilt2 = build_compacted_history(
session.build_initial_context(turn_context),
Expand Down Expand Up @@ -3016,7 +2952,7 @@ mod tests {
live_history.record_items(std::iter::once(&assistant3));
rollout_items.push(RolloutItem::ResponseItem(assistant3.clone()));

(rollout_items, live_history.contents())
(rollout_items, live_history.get_history())
}

#[tokio::test]
Expand Down
13 changes: 9 additions & 4 deletions codex-rs/core/src/codex/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::TurnContext;
use super::get_last_assistant_message_from_turn;
use crate::Prompt;
use crate::client_common::ResponseEvent;
use crate::conversation_history::ConversationHistory;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::protocol::AgentMessageEvent;
Expand Down Expand Up @@ -64,9 +65,12 @@ async fn run_compact_task_inner(
input: Vec<UserInput>,
) {
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
let mut turn_input = sess
.turn_input_with_history(vec![initial_input_for_turn.clone().into()])
.await;

let items = sess.history_snapshot().await;

let mut history = ConversationHistory::create_with_items(items);
history.record_items(&[initial_input_for_turn.into()]);

let mut truncated_count = 0usize;

let max_retries = turn_context.client.get_provider().stream_max_retries();
Expand All @@ -83,6 +87,7 @@ async fn run_compact_task_inner(
sess.persist_rollout_items(&[rollout_item]).await;

loop {
let turn_input = history.get_history();
let prompt = Prompt {
input: turn_input.clone(),
..Default::default()
Expand All @@ -107,7 +112,7 @@ async fn run_compact_task_inner(
}
Err(e @ CodexErr::ContextWindowExceeded) => {
if turn_input.len() > 1 {
turn_input.remove(0);
history.remove_last_item();
truncated_count += 1;
retries = 0;
continue;
Expand Down
Loading
Loading