Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 17 additions & 28 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ use crate::protocol::Submission;
use crate::protocol::TaskCompleteEvent;
use crate::protocol::TokenCountEvent;
use crate::protocol::TokenUsage;
use crate::protocol::TokenUsageInfo;
use crate::protocol::TurnDiffEvent;
use crate::protocol::WebSearchBeginEvent;
use crate::rollout::RolloutRecorder;
Expand Down Expand Up @@ -569,7 +568,7 @@ impl Session {
let event_id = sub_id.clone();
let prev_entry = {
let mut state = self.state.lock().await;
state.pending_approvals.insert(sub_id, tx_approve)
state.insert_pending_approval(sub_id, tx_approve)
};
if prev_entry.is_some() {
warn!("Overwriting existing pending approval for sub_id: {event_id}");
Expand Down Expand Up @@ -601,7 +600,7 @@ impl Session {
let event_id = sub_id.clone();
let prev_entry = {
let mut state = self.state.lock().await;
state.pending_approvals.insert(sub_id, tx_approve)
state.insert_pending_approval(sub_id, tx_approve)
};
if prev_entry.is_some() {
warn!("Overwriting existing pending approval for sub_id: {event_id}");
Expand All @@ -623,7 +622,7 @@ impl Session {
pub async fn notify_approval(&self, sub_id: &str, decision: ReviewDecision) {
let entry = {
let mut state = self.state.lock().await;
state.pending_approvals.remove(sub_id)
state.remove_pending_approval(sub_id)
};
match entry {
Some(tx_approve) => {
Expand All @@ -637,7 +636,7 @@ impl Session {

pub async fn add_approved_command(&self, cmd: Vec<String>) {
let mut state = self.state.lock().await;
state.approved_commands.insert(cmd);
state.add_approved_command(cmd);
}

/// Records input items: always append to conversation history and
Expand Down Expand Up @@ -677,7 +676,7 @@ impl Session {
/// Append ResponseItems to the in-memory conversation history only.
async fn record_into_history(&self, items: &[ResponseItem]) {
let mut state = self.state.lock().await;
state.history.record_items(items.iter());
state.record_items(items.iter());
}

async fn persist_rollout_response_items(&self, items: &[ResponseItem]) {
Expand Down Expand Up @@ -724,12 +723,10 @@ impl Session {
{
let mut state = self.state.lock().await;
if let Some(token_usage) = token_usage {
let info = TokenUsageInfo::new_or_append(
&state.token_info,
&Some(token_usage.clone()),
state.update_token_info_from_usage(
token_usage,
turn_context.client.get_model_context_window(),
);
state.token_info = info;
}
}
self.send_token_count_event(sub_id).await;
Expand All @@ -738,15 +735,15 @@ impl Session {
async fn update_rate_limits(&self, sub_id: &str, new_rate_limits: RateLimitSnapshot) {
{
let mut state = self.state.lock().await;
state.latest_rate_limits = Some(new_rate_limits);
state.set_rate_limits(new_rate_limits);
}
self.send_token_count_event(sub_id).await;
}

async fn send_token_count_event(&self, sub_id: &str) {
let (info, rate_limits) = {
let state = self.state.lock().await;
(state.token_info.clone(), state.latest_rate_limits.clone())
state.token_info_and_rate_limits()
};
let event = Event {
id: sub_id.to_string(),
Expand Down Expand Up @@ -966,7 +963,7 @@ impl Session {
pub async fn turn_input_with_history(&self, extra: Vec<ResponseItem>) -> Vec<ResponseItem> {
let history = {
let state = self.state.lock().await;
state.history.contents()
state.history_snapshot()
};
[history, extra].concat()
}
Expand All @@ -975,7 +972,7 @@ impl Session {
pub async fn inject_input(&self, input: Vec<InputItem>) -> Result<(), Vec<InputItem>> {
let mut state = self.state.lock().await;
if state.current_task.is_some() {
state.pending_input.push(input.into());
state.push_pending_input(input.into());
Ok(())
} else {
Err(input)
Expand All @@ -984,13 +981,7 @@ impl Session {

pub async fn get_pending_input(&self) -> Vec<ResponseInputItem> {
let mut state = self.state.lock().await;
if state.pending_input.is_empty() {
Vec::with_capacity(0)
} else {
let mut ret = Vec::new();
std::mem::swap(&mut ret, &mut state.pending_input);
ret
}
state.take_pending_input()
}

pub async fn call_tool(
Expand All @@ -1008,17 +999,15 @@ impl Session {
pub async fn interrupt_task(&self) {
info!("interrupt received: abort current task, if any");
let mut state = self.state.lock().await;
state.pending_approvals.clear();
state.pending_input.clear();
state.clear_pending();
if let Some(task) = state.current_task.take() {
task.abort(TurnAbortReason::Interrupted);
}
}

fn interrupt_task_sync(&self) {
if let Ok(mut state) = self.state.try_lock() {
state.pending_approvals.clear();
state.pending_input.clear();
state.clear_pending();
if let Some(task) = state.current_task.take() {
task.abort(TurnAbortReason::Interrupted);
}
Expand Down Expand Up @@ -2807,7 +2796,7 @@ async fn handle_container_exec_with_params(
&params.command,
turn_context.approval_policy,
&turn_context.sandbox_policy,
&state.approved_commands,
state.approved_commands_ref(),
params.with_escalated_permissions.unwrap_or(false),
)
};
Expand Down Expand Up @@ -3375,7 +3364,7 @@ mod tests {
}),
));

let actual = tokio_test::block_on(async { session.state.lock().await.history.contents() });
let actual = tokio_test::block_on(async { session.state.lock().await.history_snapshot() });
assert_eq!(expected, actual);
}

Expand All @@ -3388,7 +3377,7 @@ mod tests {
session.record_initial_history(&turn_context, InitialHistory::Forked(rollout_items)),
);

let actual = tokio_test::block_on(async { session.state.lock().await.history.contents() });
let actual = tokio_test::block_on(async { session.state.lock().await.history_snapshot() });
assert_eq!(expected, actual);
}

Expand Down
6 changes: 3 additions & 3 deletions codex-rs/core/src/codex/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ async fn run_compact_task_inner(
}
let history_snapshot = {
let state = sess.state.lock().await;
state.history.contents()
state.history_snapshot()
};
let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
let user_messages = collect_user_messages(&history_snapshot);
let initial_context = sess.build_initial_context(turn_context.as_ref());
let new_history = build_compacted_history(initial_context, &user_messages, &summary_text);
{
let mut state = sess.state.lock().await;
state.history.replace(new_history);
state.replace_history(new_history);
}

let rollout_item = RolloutItem::Compacted(CompactedItem {
Expand Down Expand Up @@ -271,7 +271,7 @@ async fn drain_to_completed(
match event {
Ok(ResponseEvent::OutputItemDone(item)) => {
let mut state = sess.state.lock().await;
state.history.record_items(std::slice::from_ref(&item));
state.record_items(std::slice::from_ref(&item));
}
Ok(ResponseEvent::Completed { .. }) => {
return Ok(());
Expand Down
86 changes: 86 additions & 0 deletions codex-rs/core/src/state/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ use std::collections::HashMap;
use std::collections::HashSet;

use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use tokio::sync::oneshot;

use crate::codex::AgentTask;
use crate::conversation_history::ConversationHistory;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::ReviewDecision;
use crate::protocol::TokenUsage;
use crate::protocol::TokenUsageInfo;

/// Persistent, session-scoped state previously stored directly on `Session`.
Expand All @@ -32,4 +34,88 @@ impl SessionState {
..Default::default()
}
}

// History helpers
pub(crate) fn record_items<I>(&mut self, items: I)
where
I: IntoIterator,
I::Item: std::ops::Deref<Target = ResponseItem>,
{
self.history.record_items(items)
}

pub(crate) fn history_snapshot(&self) -> Vec<ResponseItem> {
self.history.contents()
}

pub(crate) fn replace_history(&mut self, items: Vec<ResponseItem>) {
self.history.replace(items);
}

// Approved command helpers
pub(crate) fn add_approved_command(&mut self, cmd: Vec<String>) {
self.approved_commands.insert(cmd);
}

pub(crate) fn approved_commands_ref(&self) -> &HashSet<Vec<String>> {
&self.approved_commands
}

// Token/rate limit helpers
pub(crate) fn update_token_info_from_usage(
&mut self,
usage: &TokenUsage,
model_context_window: Option<u64>,
) {
self.token_info = TokenUsageInfo::new_or_append(
&self.token_info,
&Some(usage.clone()),
model_context_window,
);
}

pub(crate) fn set_rate_limits(&mut self, snapshot: RateLimitSnapshot) {
self.latest_rate_limits = Some(snapshot);
}

pub(crate) fn token_info_and_rate_limits(
&self,
) -> (Option<TokenUsageInfo>, Option<RateLimitSnapshot>) {
(self.token_info.clone(), self.latest_rate_limits.clone())
}

// Pending input/approval helpers
pub(crate) fn insert_pending_approval(
&mut self,
key: String,
tx: oneshot::Sender<ReviewDecision>,
) -> Option<oneshot::Sender<ReviewDecision>> {
self.pending_approvals.insert(key, tx)
}

pub(crate) fn remove_pending_approval(
&mut self,
key: &str,
) -> Option<oneshot::Sender<ReviewDecision>> {
self.pending_approvals.remove(key)
}

pub(crate) fn clear_pending(&mut self) {
self.pending_approvals.clear();
self.pending_input.clear();
}

pub(crate) fn push_pending_input(&mut self, input: ResponseInputItem) {
self.pending_input.push(input);
}

pub(crate) fn take_pending_input(&mut self) -> Vec<ResponseInputItem> {
if self.pending_input.is_empty() {
Vec::with_capacity(0)
} else {
let mut ret = Vec::new();
std::mem::swap(&mut ret, &mut self.pending_input);
ret
}
}
}
Loading