Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 251 additions & 10 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ use codex_protocol::models::ShellToolCallParams;
use codex_protocol::protocol::InitialHistory;

mod compact;
use self::compact::build_compacted_history;
use self::compact::collect_user_messages;

// A convenience extension trait for acquiring mutex locks where poisoning is
// unrecoverable and should abort the program. This avoids scattered `.unwrap()`
Expand Down Expand Up @@ -205,7 +207,7 @@ impl Codex {
config.clone(),
auth_manager.clone(),
tx_event.clone(),
conversation_history.clone(),
conversation_history,
)
.await
.map_err(|e| {
Expand Down Expand Up @@ -564,9 +566,10 @@ impl Session {
let persist = matches!(conversation_history, InitialHistory::Forked(_));

// Always add response items to conversation history
let response_items = conversation_history.get_response_items();
if !response_items.is_empty() {
self.record_into_history(&response_items);
let reconstructed_history =
self.reconstruct_history_from_rollout(turn_context, &rollout_items);
if !reconstructed_history.is_empty() {
self.record_into_history(&reconstructed_history);
}

// If persisting, persist all rollout items as-is (recorder filters)
Expand Down Expand Up @@ -678,6 +681,33 @@ impl Session {
self.persist_rollout_response_items(items).await;
}

fn reconstruct_history_from_rollout(
&self,
turn_context: &TurnContext,
rollout_items: &[RolloutItem],
) -> Vec<ResponseItem> {
let mut history = ConversationHistory::new();
for item in rollout_items {
match item {
RolloutItem::ResponseItem(response_item) => {
history.record_items(std::iter::once(response_item));
}
RolloutItem::Compacted(compacted) => {
let snapshot = history.contents();
let user_messages = collect_user_messages(&snapshot);
let rebuilt = build_compacted_history(
self.build_initial_context(turn_context),
&user_messages,
&compacted.message,
);
history.replace(rebuilt);
}
_ => {}
}
}
history.contents()
}

/// Append ResponseItems to the in-memory conversation history only.
fn record_into_history(&self, items: &[ResponseItem]) {
self.state
Expand Down Expand Up @@ -3220,18 +3250,59 @@ async fn exit_review_mode(
#[cfg(test)]
mod tests {
use super::*;
use crate::config::ConfigOverrides;
use crate::config::ConfigToml;
use crate::protocol::CompactedItem;
use crate::protocol::InitialHistory;
use crate::protocol::ResumedHistory;
use codex_protocol::models::ContentItem;
use mcp_types::ContentBlock;
use mcp_types::TextContent;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration as StdDuration;

fn text_block(s: &str) -> ContentBlock {
ContentBlock::TextContent(TextContent {
annotations: None,
text: s.to_string(),
r#type: "text".to_string(),
})
#[test]
fn reconstruct_history_matches_live_compactions() {
let (session, turn_context) = make_session_and_context();
let (rollout_items, expected) = sample_rollout(&session, &turn_context);

let reconstructed = session.reconstruct_history_from_rollout(&turn_context, &rollout_items);

assert_eq!(expected, reconstructed);
}

#[test]
fn record_initial_history_reconstructs_resumed_transcript() {
let (session, turn_context) = make_session_and_context();
let (rollout_items, expected) = sample_rollout(&session, &turn_context);

tokio_test::block_on(session.record_initial_history(
&turn_context,
InitialHistory::Resumed(ResumedHistory {
conversation_id: ConversationId::default(),
history: rollout_items,
rollout_path: PathBuf::from("/tmp/resume.jsonl"),
}),
));

let actual = session.state.lock_unchecked().history.contents();
assert_eq!(expected, actual);
}

#[test]
fn record_initial_history_reconstructs_forked_transcript() {
let (session, turn_context) = make_session_and_context();
let (rollout_items, expected) = sample_rollout(&session, &turn_context);

tokio_test::block_on(
session.record_initial_history(&turn_context, InitialHistory::Forked(rollout_items)),
);

let actual = session.state.lock_unchecked().history.contents();
assert_eq!(expected, actual);
}

#[test]
Expand Down Expand Up @@ -3386,4 +3457,174 @@ mod tests {

assert_eq!(expected, got);
}

fn text_block(s: &str) -> ContentBlock {
ContentBlock::TextContent(TextContent {
annotations: None,
text: s.to_string(),
r#type: "text".to_string(),
})
}

fn make_session_and_context() -> (Session, TurnContext) {
let (tx_event, _rx_event) = async_channel::unbounded();
let codex_home = tempfile::tempdir().expect("create temp dir");
let config = Config::load_from_base_config_with_overrides(
ConfigToml::default(),
ConfigOverrides::default(),
codex_home.path().to_path_buf(),
)
.expect("load default test config");
let config = Arc::new(config);
let conversation_id = ConversationId::default();
let client = ModelClient::new(
config.clone(),
None,
config.model_provider.clone(),
config.model_reasoning_effort,
config.model_reasoning_summary,
conversation_id,
);
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_family: &config.model_family,
approval_policy: config.approval_policy,
sandbox_policy: config.sandbox_policy.clone(),
include_plan_tool: config.include_plan_tool,
include_apply_patch_tool: config.include_apply_patch_tool,
include_web_search_request: config.tools_web_search_request,
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
include_view_image_tool: config.include_view_image_tool,
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
});
let turn_context = TurnContext {
client,
cwd: config.cwd.clone(),
base_instructions: config.base_instructions.clone(),
user_instructions: config.user_instructions.clone(),
approval_policy: config.approval_policy,
sandbox_policy: config.sandbox_policy.clone(),
shell_environment_policy: config.shell_environment_policy.clone(),
tools_config,
is_review_mode: false,
};
let session = Session {
conversation_id,
tx_event,
mcp_connection_manager: McpConnectionManager::default(),
session_manager: ExecSessionManager::default(),
unified_exec_manager: UnifiedExecSessionManager::default(),
notify: None,
rollout: Mutex::new(None),
state: Mutex::new(State {
history: ConversationHistory::new(),
..Default::default()
}),
codex_linux_sandbox_exe: None,
user_shell: shell::Shell::Unknown,
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
};
(session, turn_context)
}

fn sample_rollout(
session: &Session,
turn_context: &TurnContext,
) -> (Vec<RolloutItem>, Vec<ResponseItem>) {
let mut rollout_items = Vec::new();
let mut live_history = ConversationHistory::new();

let initial_context = session.build_initial_context(turn_context);
for item in &initial_context {
rollout_items.push(RolloutItem::ResponseItem(item.clone()));
}
live_history.record_items(initial_context.iter());

let user1 = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "first user".to_string(),
}],
};
live_history.record_items(std::iter::once(&user1));
rollout_items.push(RolloutItem::ResponseItem(user1.clone()));

let assistant1 = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "assistant reply one".to_string(),
}],
};
live_history.record_items(std::iter::once(&assistant1));
rollout_items.push(RolloutItem::ResponseItem(assistant1.clone()));

let summary1 = "summary one";
let snapshot1 = live_history.contents();
let user_messages1 = collect_user_messages(&snapshot1);
let rebuilt1 = build_compacted_history(
session.build_initial_context(turn_context),
&user_messages1,
summary1,
);
live_history.replace(rebuilt1);
rollout_items.push(RolloutItem::Compacted(CompactedItem {
message: summary1.to_string(),
}));

let user2 = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "second user".to_string(),
}],
};
live_history.record_items(std::iter::once(&user2));
rollout_items.push(RolloutItem::ResponseItem(user2.clone()));

let assistant2 = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "assistant reply two".to_string(),
}],
};
live_history.record_items(std::iter::once(&assistant2));
rollout_items.push(RolloutItem::ResponseItem(assistant2.clone()));

let summary2 = "summary two";
let snapshot2 = live_history.contents();
let user_messages2 = collect_user_messages(&snapshot2);
let rebuilt2 = build_compacted_history(
session.build_initial_context(turn_context),
&user_messages2,
summary2,
);
live_history.replace(rebuilt2);
rollout_items.push(RolloutItem::Compacted(CompactedItem {
message: summary2.to_string(),
}));

let user3 = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "third user".to_string(),
}],
};
live_history.record_items(std::iter::once(&user3));
rollout_items.push(RolloutItem::ResponseItem(user3.clone()));

let assistant3 = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "assistant reply three".to_string(),
}],
};
live_history.record_items(std::iter::once(&assistant3));
rollout_items.push(RolloutItem::ResponseItem(assistant3.clone()));

(rollout_items, live_history.contents())
}
}
13 changes: 6 additions & 7 deletions codex-rs/core/src/codex/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ async fn run_compact_task_inner(
};
let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
let user_messages = collect_user_messages(&history_snapshot);
let new_history =
build_compacted_history(&sess, turn_context.as_ref(), &user_messages, &summary_text);
let initial_context = sess.build_initial_context(turn_context.as_ref());
let new_history = build_compacted_history(initial_context, &user_messages, &summary_text);
{
let mut state = sess.state.lock_unchecked();
state.history.replace(new_history);
Expand Down Expand Up @@ -223,7 +223,7 @@ fn content_items_to_text(content: &[ContentItem]) -> Option<String> {
}
}

fn collect_user_messages(items: &[ResponseItem]) -> Vec<String> {
pub(crate) fn collect_user_messages(items: &[ResponseItem]) -> Vec<String> {
items
.iter()
.filter_map(|item| match item {
Expand All @@ -243,13 +243,12 @@ fn is_session_prefix_message(text: &str) -> bool {
)
}

fn build_compacted_history(
sess: &Session,
turn_context: &TurnContext,
pub(crate) fn build_compacted_history(
initial_context: Vec<ResponseItem>,
user_messages: &[String],
summary_text: &str,
) -> Vec<ResponseItem> {
let mut history = sess.build_initial_context(turn_context);
let mut history = initial_context;
let user_messages_text = if user_messages.is_empty() {
"(none)".to_string()
} else {
Expand Down
Loading
Loading