Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 247 additions & 4 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ use codex_protocol::models::ShellToolCallParams;
use codex_protocol::protocol::InitialHistory;

mod compact;
use self::compact::build_compacted_history;
use self::compact::collect_user_messages;

// A convenience extension trait for acquiring mutex locks where poisoning is
// unrecoverable and should abort the program. This avoids scattered `.unwrap()`
Expand Down Expand Up @@ -202,7 +204,7 @@ impl Codex {
config.clone(),
auth_manager.clone(),
tx_event.clone(),
conversation_history.clone(),
conversation_history,
)
.await
.map_err(|e| {
Expand Down Expand Up @@ -559,9 +561,10 @@ impl Session {
let persist = matches!(conversation_history, InitialHistory::Forked(_));

// Always add response items to conversation history
let response_items = conversation_history.get_response_items();
if !response_items.is_empty() {
self.record_into_history(&response_items);
let reconstructed_history =
self.reconstruct_history_from_rollout(turn_context, &rollout_items);
if !reconstructed_history.is_empty() {
self.record_into_history(&reconstructed_history);
}

// If persisting, persist all rollout items as-is (recorder filters)
Expand Down Expand Up @@ -673,6 +676,33 @@ impl Session {
self.persist_rollout_response_items(items).await;
}

fn reconstruct_history_from_rollout(
&self,
turn_context: &TurnContext,
rollout_items: &[RolloutItem],
) -> Vec<ResponseItem> {
let mut history = ConversationHistory::new();
for item in rollout_items {
match item {
RolloutItem::ResponseItem(response_item) => {
history.record_items(std::iter::once(response_item));
}
RolloutItem::Compacted(compacted) => {
let snapshot = history.contents();
let user_messages = collect_user_messages(&snapshot);
let rebuilt = build_compacted_history(
self.build_initial_context(turn_context),
&user_messages,
&compacted.message,
);
history.replace(rebuilt);
}
_ => {}
}
}
history.contents()
}

/// Append ResponseItems to the in-memory conversation history only.
fn record_into_history(&self, items: &[ResponseItem]) {
self.state
Expand Down Expand Up @@ -2991,10 +3021,18 @@ fn convert_call_tool_result_to_function_call_output_payload(
#[cfg(test)]
mod tests {
use super::*;
use crate::config::ConfigOverrides;
use crate::config::ConfigToml;
use crate::protocol::CompactedItem;
use crate::protocol::InitialHistory;
use crate::protocol::ResumedHistory;
use codex_protocol::models::ContentItem;
use mcp_types::ContentBlock;
use mcp_types::TextContent;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration as StdDuration;

fn text_block(s: &str) -> ContentBlock {
Expand All @@ -3005,6 +3043,211 @@ mod tests {
})
}

fn make_session_and_context() -> (Session, TurnContext) {
let (tx_event, _rx_event) = async_channel::unbounded();
let codex_home = tempfile::tempdir().expect("create temp dir");
let config = Config::load_from_base_config_with_overrides(
ConfigToml::default(),
ConfigOverrides::default(),
codex_home.path().to_path_buf(),
)
.expect("load default test config");
let config = Arc::new(config);
let conversation_id = ConversationId::default();
let client = ModelClient::new(
config.clone(),
None,
config.model_provider.clone(),
config.model_reasoning_effort,
config.model_reasoning_summary,
conversation_id,
);
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_family: &config.model_family,
approval_policy: config.approval_policy,
sandbox_policy: config.sandbox_policy.clone(),
include_plan_tool: config.include_plan_tool,
include_apply_patch_tool: config.include_apply_patch_tool,
include_web_search_request: config.tools_web_search_request,
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
include_view_image_tool: config.include_view_image_tool,
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
});
let turn_context = TurnContext {
client,
cwd: config.cwd.clone(),
base_instructions: config.base_instructions.clone(),
user_instructions: config.user_instructions.clone(),
approval_policy: config.approval_policy,
sandbox_policy: config.sandbox_policy.clone(),
shell_environment_policy: config.shell_environment_policy.clone(),
tools_config,
};
let session = Session {
conversation_id,
tx_event,
mcp_connection_manager: McpConnectionManager::default(),
session_manager: ExecSessionManager::default(),
unified_exec_manager: UnifiedExecSessionManager::default(),
notify: None,
rollout: Mutex::new(None),
state: Mutex::new(State {
history: ConversationHistory::new(),
..Default::default()
}),
codex_linux_sandbox_exe: None,
user_shell: shell::Shell::Unknown,
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
};
(session, turn_context)
}

fn sample_rollout(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned elsewhere, we should put helper functions under tests.

session: &Session,
turn_context: &TurnContext,
) -> (Vec<RolloutItem>, Vec<ResponseItem>) {
let mut rollout_items = Vec::new();
let mut live_history = ConversationHistory::new();

let initial_context = session.build_initial_context(turn_context);
for item in &initial_context {
rollout_items.push(RolloutItem::ResponseItem(item.clone()));
}
live_history.record_items(initial_context.iter());

let user1 = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "first user".to_string(),
}],
};
live_history.record_items(std::iter::once(&user1));
rollout_items.push(RolloutItem::ResponseItem(user1.clone()));

let assistant1 = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "assistant reply one".to_string(),
}],
};
live_history.record_items(std::iter::once(&assistant1));
rollout_items.push(RolloutItem::ResponseItem(assistant1.clone()));

let summary1 = "summary one";
let snapshot1 = live_history.contents();
let user_messages1 = collect_user_messages(&snapshot1);
let rebuilt1 = build_compacted_history(
session.build_initial_context(turn_context),
&user_messages1,
summary1,
);
live_history.replace(rebuilt1);
rollout_items.push(RolloutItem::Compacted(CompactedItem {
message: summary1.to_string(),
}));

let user2 = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "second user".to_string(),
}],
};
live_history.record_items(std::iter::once(&user2));
rollout_items.push(RolloutItem::ResponseItem(user2.clone()));

let assistant2 = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "assistant reply two".to_string(),
}],
};
live_history.record_items(std::iter::once(&assistant2));
rollout_items.push(RolloutItem::ResponseItem(assistant2.clone()));

let summary2 = "summary two";
let snapshot2 = live_history.contents();
let user_messages2 = collect_user_messages(&snapshot2);
let rebuilt2 = build_compacted_history(
session.build_initial_context(turn_context),
&user_messages2,
summary2,
);
live_history.replace(rebuilt2);
rollout_items.push(RolloutItem::Compacted(CompactedItem {
message: summary2.to_string(),
}));

let user3 = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "third user".to_string(),
}],
};
live_history.record_items(std::iter::once(&user3));
rollout_items.push(RolloutItem::ResponseItem(user3.clone()));

let assistant3 = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "assistant reply three".to_string(),
}],
};
live_history.record_items(std::iter::once(&assistant3));
rollout_items.push(RolloutItem::ResponseItem(assistant3.clone()));

(rollout_items, live_history.contents())
}

#[test]
fn reconstruct_history_matches_live_compactions() {
let (session, turn_context) = make_session_and_context();
let (rollout_items, expected) = sample_rollout(&session, &turn_context);

let reconstructed = session.reconstruct_history_from_rollout(&turn_context, &rollout_items);

assert_eq!(expected, reconstructed);
}

#[test]
fn record_initial_history_reconstructs_resumed_transcript() {
let (session, turn_context) = make_session_and_context();
let (rollout_items, expected) = sample_rollout(&session, &turn_context);

tokio_test::block_on(session.record_initial_history(
&turn_context,
InitialHistory::Resumed(ResumedHistory {
conversation_id: ConversationId::default(),
history: rollout_items,
rollout_path: PathBuf::from("/tmp/resume.jsonl"),
}),
));

let actual = session.state.lock_unchecked().history.contents();
assert_eq!(expected, actual);
}

#[test]
fn record_initial_history_reconstructs_forked_transcript() {
let (session, turn_context) = make_session_and_context();
let (rollout_items, expected) = sample_rollout(&session, &turn_context);

tokio_test::block_on(
session.record_initial_history(
&turn_context,
InitialHistory::Forked(rollout_items.clone()),
),
);

let actual = session.state.lock_unchecked().history.contents();
assert_eq!(expected, actual);
}

#[test]
fn prefers_structured_content_when_present() {
let ctr = CallToolResult {
Expand Down
13 changes: 6 additions & 7 deletions codex-rs/core/src/codex/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ async fn run_compact_task_inner(
};
let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
let user_messages = collect_user_messages(&history_snapshot);
let new_history =
build_compacted_history(&sess, turn_context.as_ref(), &user_messages, &summary_text);
let initial_context = sess.build_initial_context(turn_context.as_ref());
let new_history = build_compacted_history(initial_context, &user_messages, &summary_text);
{
let mut state = sess.state.lock_unchecked();
state.history.replace(new_history);
Expand Down Expand Up @@ -223,7 +223,7 @@ fn content_items_to_text(content: &[ContentItem]) -> Option<String> {
}
}

fn collect_user_messages(items: &[ResponseItem]) -> Vec<String> {
pub(crate) fn collect_user_messages(items: &[ResponseItem]) -> Vec<String> {
items
.iter()
.filter_map(|item| match item {
Expand All @@ -243,13 +243,12 @@ fn is_session_prefix_message(text: &str) -> bool {
)
}

fn build_compacted_history(
sess: &Session,
turn_context: &TurnContext,
pub(crate) fn build_compacted_history(
initial_context: Vec<ResponseItem>,
user_messages: &[String],
summary_text: &str,
) -> Vec<ResponseItem> {
let mut history = sess.build_initial_context(turn_context);
let mut history = initial_context;
let user_messages_text = if user_messages.is_empty() {
"(none)".to_string()
} else {
Expand Down
Loading
Loading