From c122259012e4773bcf238f2e7efc02a5c0fe136d Mon Sep 17 00:00:00 2001 From: jimmyfraiture Date: Wed, 24 Sep 2025 17:16:59 +0100 Subject: [PATCH] core: add SessionState helper methods and migrate call sites (history, approvals, tokens, rate limits) --- codex-rs/core/src/codex.rs | 45 ++++++---------- codex-rs/core/src/codex/compact.rs | 6 +-- codex-rs/core/src/state/session.rs | 86 ++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 31 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index c9a6668b26..cfafafb827 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -109,7 +109,6 @@ use crate::protocol::Submission; use crate::protocol::TaskCompleteEvent; use crate::protocol::TokenCountEvent; use crate::protocol::TokenUsage; -use crate::protocol::TokenUsageInfo; use crate::protocol::TurnDiffEvent; use crate::protocol::WebSearchBeginEvent; use crate::rollout::RolloutRecorder; @@ -569,7 +568,7 @@ impl Session { let event_id = sub_id.clone(); let prev_entry = { let mut state = self.state.lock().await; - state.pending_approvals.insert(sub_id, tx_approve) + state.insert_pending_approval(sub_id, tx_approve) }; if prev_entry.is_some() { warn!("Overwriting existing pending approval for sub_id: {event_id}"); @@ -601,7 +600,7 @@ impl Session { let event_id = sub_id.clone(); let prev_entry = { let mut state = self.state.lock().await; - state.pending_approvals.insert(sub_id, tx_approve) + state.insert_pending_approval(sub_id, tx_approve) }; if prev_entry.is_some() { warn!("Overwriting existing pending approval for sub_id: {event_id}"); @@ -623,7 +622,7 @@ impl Session { pub async fn notify_approval(&self, sub_id: &str, decision: ReviewDecision) { let entry = { let mut state = self.state.lock().await; - state.pending_approvals.remove(sub_id) + state.remove_pending_approval(sub_id) }; match entry { Some(tx_approve) => { @@ -637,7 +636,7 @@ impl Session { pub async fn add_approved_command(&self, cmd: Vec) { let mut state = self.state.lock().await; - state.approved_commands.insert(cmd); + state.add_approved_command(cmd); } /// Records input items: always append to conversation history and @@ -677,7 +676,7 @@ impl Session { /// Append ResponseItems to the in-memory conversation history only. async fn record_into_history(&self, items: &[ResponseItem]) { let mut state = self.state.lock().await; - state.history.record_items(items.iter()); + state.record_items(items.iter()); } async fn persist_rollout_response_items(&self, items: &[ResponseItem]) { @@ -724,12 +723,10 @@ impl Session { { let mut state = self.state.lock().await; if let Some(token_usage) = token_usage { - let info = TokenUsageInfo::new_or_append( - &state.token_info, - &Some(token_usage.clone()), + state.update_token_info_from_usage( + token_usage, turn_context.client.get_model_context_window(), ); - state.token_info = info; } } self.send_token_count_event(sub_id).await; @@ -738,7 +735,7 @@ impl Session { async fn update_rate_limits(&self, sub_id: &str, new_rate_limits: RateLimitSnapshot) { { let mut state = self.state.lock().await; - state.latest_rate_limits = Some(new_rate_limits); + state.set_rate_limits(new_rate_limits); } self.send_token_count_event(sub_id).await; } @@ -746,7 +743,7 @@ impl Session { async fn send_token_count_event(&self, sub_id: &str) { let (info, rate_limits) = { let state = self.state.lock().await; - (state.token_info.clone(), state.latest_rate_limits.clone()) + state.token_info_and_rate_limits() }; let event = Event { id: sub_id.to_string(), @@ -966,7 +963,7 @@ impl Session { pub async fn turn_input_with_history(&self, extra: Vec) -> Vec { let history = { let state = self.state.lock().await; - state.history.contents() + state.history_snapshot() }; [history, extra].concat() } @@ -975,7 +972,7 @@ impl Session { pub async fn inject_input(&self, input: Vec) -> Result<(), Vec> { let mut state = self.state.lock().await; if state.current_task.is_some() { - state.pending_input.push(input.into()); + state.push_pending_input(input.into()); Ok(()) } else { Err(input) @@ -984,13 +981,7 @@ impl Session { pub async fn get_pending_input(&self) -> Vec { let mut state = self.state.lock().await; - if state.pending_input.is_empty() { - Vec::with_capacity(0) - } else { - let mut ret = Vec::new(); - std::mem::swap(&mut ret, &mut state.pending_input); - ret - } + state.take_pending_input() } pub async fn call_tool( @@ -1008,8 +999,7 @@ impl Session { pub async fn interrupt_task(&self) { info!("interrupt received: abort current task, if any"); let mut state = self.state.lock().await; - state.pending_approvals.clear(); - state.pending_input.clear(); + state.clear_pending(); if let Some(task) = state.current_task.take() { task.abort(TurnAbortReason::Interrupted); } @@ -1017,8 +1007,7 @@ impl Session { fn interrupt_task_sync(&self) { if let Ok(mut state) = self.state.try_lock() { - state.pending_approvals.clear(); - state.pending_input.clear(); + state.clear_pending(); if let Some(task) = state.current_task.take() { task.abort(TurnAbortReason::Interrupted); } @@ -2807,7 +2796,7 @@ async fn handle_container_exec_with_params( ¶ms.command, turn_context.approval_policy, &turn_context.sandbox_policy, - &state.approved_commands, + state.approved_commands_ref(), params.with_escalated_permissions.unwrap_or(false), ) }; @@ -3375,7 +3364,7 @@ mod tests { }), )); - let actual = tokio_test::block_on(async { session.state.lock().await.history.contents() }); + let actual = tokio_test::block_on(async { session.state.lock().await.history_snapshot() }); assert_eq!(expected, actual); } @@ -3388,7 +3377,7 @@ mod tests { session.record_initial_history(&turn_context, InitialHistory::Forked(rollout_items)), ); - let actual = tokio_test::block_on(async { session.state.lock().await.history.contents() }); + let actual = tokio_test::block_on(async { session.state.lock().await.history_snapshot() }); assert_eq!(expected, actual); } diff --git a/codex-rs/core/src/codex/compact.rs b/codex-rs/core/src/codex/compact.rs index 8f213d4e7e..88e8a8f433 100644 --- a/codex-rs/core/src/codex/compact.rs +++ b/codex-rs/core/src/codex/compact.rs @@ -153,7 +153,7 @@ async fn run_compact_task_inner( } let history_snapshot = { let state = sess.state.lock().await; - state.history.contents() + state.history_snapshot() }; let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default(); let user_messages = collect_user_messages(&history_snapshot); @@ -161,7 +161,7 @@ async fn run_compact_task_inner( let new_history = build_compacted_history(initial_context, &user_messages, &summary_text); { let mut state = sess.state.lock().await; - state.history.replace(new_history); + state.replace_history(new_history); } let rollout_item = RolloutItem::Compacted(CompactedItem { @@ -271,7 +271,7 @@ async fn drain_to_completed( match event { Ok(ResponseEvent::OutputItemDone(item)) => { let mut state = sess.state.lock().await; - state.history.record_items(std::slice::from_ref(&item)); + state.record_items(std::slice::from_ref(&item)); } Ok(ResponseEvent::Completed { .. }) => { return Ok(()); diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs index 8fc7cbb4fe..6b968769d2 100644 --- a/codex-rs/core/src/state/session.rs +++ b/codex-rs/core/src/state/session.rs @@ -4,12 +4,14 @@ use std::collections::HashMap; use std::collections::HashSet; use codex_protocol::models::ResponseInputItem; +use codex_protocol::models::ResponseItem; use tokio::sync::oneshot; use crate::codex::AgentTask; use crate::conversation_history::ConversationHistory; use crate::protocol::RateLimitSnapshot; use crate::protocol::ReviewDecision; +use crate::protocol::TokenUsage; use crate::protocol::TokenUsageInfo; /// Persistent, session-scoped state previously stored directly on `Session`. @@ -32,4 +34,88 @@ impl SessionState { ..Default::default() } } + + // History helpers + pub(crate) fn record_items(&mut self, items: I) + where + I: IntoIterator, + I::Item: std::ops::Deref, + { + self.history.record_items(items) + } + + pub(crate) fn history_snapshot(&self) -> Vec { + self.history.contents() + } + + pub(crate) fn replace_history(&mut self, items: Vec) { + self.history.replace(items); + } + + // Approved command helpers + pub(crate) fn add_approved_command(&mut self, cmd: Vec) { + self.approved_commands.insert(cmd); + } + + pub(crate) fn approved_commands_ref(&self) -> &HashSet> { + &self.approved_commands + } + + // Token/rate limit helpers + pub(crate) fn update_token_info_from_usage( + &mut self, + usage: &TokenUsage, + model_context_window: Option, + ) { + self.token_info = TokenUsageInfo::new_or_append( + &self.token_info, + &Some(usage.clone()), + model_context_window, + ); + } + + pub(crate) fn set_rate_limits(&mut self, snapshot: RateLimitSnapshot) { + self.latest_rate_limits = Some(snapshot); + } + + pub(crate) fn token_info_and_rate_limits( + &self, + ) -> (Option, Option) { + (self.token_info.clone(), self.latest_rate_limits.clone()) + } + + // Pending input/approval helpers + pub(crate) fn insert_pending_approval( + &mut self, + key: String, + tx: oneshot::Sender, + ) -> Option> { + self.pending_approvals.insert(key, tx) + } + + pub(crate) fn remove_pending_approval( + &mut self, + key: &str, + ) -> Option> { + self.pending_approvals.remove(key) + } + + pub(crate) fn clear_pending(&mut self) { + self.pending_approvals.clear(); + self.pending_input.clear(); + } + + pub(crate) fn push_pending_input(&mut self, input: ResponseInputItem) { + self.pending_input.push(input); + } + + pub(crate) fn take_pending_input(&mut self) -> Vec { + if self.pending_input.is_empty() { + Vec::with_capacity(0) + } else { + let mut ret = Vec::new(); + std::mem::swap(&mut ret, &mut self.pending_input); + ret + } + } }