diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 3ea2ca79b5f..a8dcb7259a3 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -47,6 +47,7 @@ use crate::openai_tools::create_tools_json_for_responses_api; use crate::protocol::RateLimitSnapshot; use crate::protocol::RateLimitWindow; use crate::protocol::TokenUsage; +use crate::state::TaskKind; use crate::token_data::PlanType; use crate::util::backoff; use codex_otel::otel_event_manager::OtelEventManager; @@ -123,8 +124,16 @@ impl ModelClient { /// the provider config. Public callers always invoke `stream()` – the /// specialised helpers are private to avoid accidental misuse. pub async fn stream(&self, prompt: &Prompt) -> Result { + self.stream_with_task_kind(prompt, TaskKind::Regular).await + } + + pub(crate) async fn stream_with_task_kind( + &self, + prompt: &Prompt, + task_kind: TaskKind, + ) -> Result { match self.provider.wire_api { - WireApi::Responses => self.stream_responses(prompt).await, + WireApi::Responses => self.stream_responses(prompt, task_kind).await, WireApi::Chat => { // Create the raw streaming connection first. let response_stream = stream_chat_completions( @@ -165,7 +174,11 @@ impl ModelClient { } /// Implementation for the OpenAI *Responses* experimental API. - async fn stream_responses(&self, prompt: &Prompt) -> Result { + async fn stream_responses( + &self, + prompt: &Prompt, + task_kind: TaskKind, + ) -> Result { if let Some(path) = &*CODEX_RS_SSE_FIXTURE { // short circuit for tests warn!(path, "Streaming from fixture"); @@ -244,7 +257,7 @@ impl ModelClient { let max_attempts = self.provider.request_max_retries(); for attempt in 0..=max_attempts { match self - .attempt_stream_responses(attempt, &payload_json, &auth_manager) + .attempt_stream_responses(attempt, &payload_json, &auth_manager, task_kind) .await { Ok(stream) => { @@ -272,6 +285,7 @@ impl ModelClient { attempt: u64, payload_json: &Value, auth_manager: &Option>, + task_kind: TaskKind, ) -> std::result::Result { // Always fetch the latest auth in case a prior attempt refreshed the token. let auth = auth_manager.as_ref().and_then(|m| m.auth()); @@ -294,6 +308,7 @@ impl ModelClient { .header("conversation_id", self.conversation_id.to_string()) .header("session_id", self.conversation_id.to_string()) .header(reqwest::header::ACCEPT, "text/event-stream") + .header("Codex-Task-Type", task_kind.header_value()) .json(payload_json); if let Some(auth) = auth.as_ref() diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index fe352f0103f..22b34b1cd03 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -99,6 +99,7 @@ use crate::rollout::RolloutRecorderParams; use crate::shell; use crate::state::ActiveTurn; use crate::state::SessionServices; +use crate::state::TaskKind; use crate::tasks::CompactTask; use crate::tasks::RegularTask; use crate::tasks::ReviewTask; @@ -1634,6 +1635,7 @@ pub(crate) async fn run_task( turn_context: Arc, sub_id: String, input: Vec, + task_kind: TaskKind, ) -> Option { if input.is_empty() { return None; @@ -1717,6 +1719,7 @@ pub(crate) async fn run_task( Arc::clone(&turn_diff_tracker), sub_id.clone(), turn_input, + task_kind, ) .await { @@ -1942,6 +1945,7 @@ async fn run_turn( turn_diff_tracker: SharedTurnDiffTracker, sub_id: String, input: Vec, + task_kind: TaskKind, ) -> CodexResult { let mcp_tools = sess.services.mcp_connection_manager.list_all_tools(); let router = Arc::new(ToolRouter::from_config( @@ -1971,6 +1975,7 @@ async fn run_turn( Arc::clone(&turn_diff_tracker), &sub_id, &prompt, + task_kind, ) .await { @@ -2044,6 +2049,7 @@ async fn try_run_turn( turn_diff_tracker: SharedTurnDiffTracker, sub_id: &str, prompt: &Prompt, + task_kind: TaskKind, ) -> CodexResult { // call_ids that are part of this response. let completed_call_ids = prompt @@ -2109,7 +2115,11 @@ async fn try_run_turn( summary: turn_context.client.get_reasoning_summary(), }); sess.persist_rollout_items(&[rollout_item]).await; - let mut stream = turn_context.client.clone().stream(&prompt).await?; + let mut stream = turn_context + .client + .clone() + .stream_with_task_kind(prompt.as_ref(), task_kind) + .await?; let tool_runtime = ToolCallRuntime::new( Arc::clone(&router), diff --git a/codex-rs/core/src/codex/compact.rs b/codex-rs/core/src/codex/compact.rs index d43e3abcbbf..93bbfa79c6a 100644 --- a/codex-rs/core/src/codex/compact.rs +++ b/codex-rs/core/src/codex/compact.rs @@ -16,6 +16,7 @@ use crate::protocol::InputItem; use crate::protocol::InputMessageKind; use crate::protocol::TaskStartedEvent; use crate::protocol::TurnContextItem; +use crate::state::TaskKind; use crate::truncate::truncate_middle; use crate::util::backoff; use askama::Template; @@ -258,7 +259,11 @@ async fn drain_to_completed( sub_id: &str, prompt: &Prompt, ) -> CodexResult<()> { - let mut stream = turn_context.client.clone().stream(prompt).await?; + let mut stream = turn_context + .client + .clone() + .stream_with_task_kind(prompt, TaskKind::Compact) + .await?; loop { let maybe_event = stream.next().await; let Some(event) = maybe_event else { diff --git a/codex-rs/core/src/state/turn.rs b/codex-rs/core/src/state/turn.rs index f715d5481e5..89af13a1a50 100644 --- a/codex-rs/core/src/state/turn.rs +++ b/codex-rs/core/src/state/turn.rs @@ -34,6 +34,16 @@ pub(crate) enum TaskKind { Compact, } +impl TaskKind { + pub(crate) fn header_value(self) -> &'static str { + match self { + TaskKind::Regular => "standard", + TaskKind::Review => "review", + TaskKind::Compact => "compact", + } + } +} + #[derive(Clone)] pub(crate) struct RunningTask { pub(crate) handle: AbortHandle, @@ -113,3 +123,15 @@ impl ActiveTurn { } } } + +#[cfg(test)] +mod tests { + use super::TaskKind; + + #[test] + fn header_value_matches_expected_labels() { + assert_eq!(TaskKind::Regular.header_value(), "standard"); + assert_eq!(TaskKind::Review.header_value(), "review"); + assert_eq!(TaskKind::Compact.header_value(), "compact"); + } +} diff --git a/codex-rs/core/src/tasks/regular.rs b/codex-rs/core/src/tasks/regular.rs index 9d240997468..b3758d5fc60 100644 --- a/codex-rs/core/src/tasks/regular.rs +++ b/codex-rs/core/src/tasks/regular.rs @@ -27,6 +27,6 @@ impl SessionTask for RegularTask { input: Vec, ) -> Option { let sess = session.clone_session(); - run_task(sess, ctx, sub_id, input).await + run_task(sess, ctx, sub_id, input, TaskKind::Regular).await } } diff --git a/codex-rs/core/src/tasks/review.rs b/codex-rs/core/src/tasks/review.rs index 047a2f40e21..cec92432347 100644 --- a/codex-rs/core/src/tasks/review.rs +++ b/codex-rs/core/src/tasks/review.rs @@ -28,7 +28,7 @@ impl SessionTask for ReviewTask { input: Vec, ) -> Option { let sess = session.clone_session(); - run_task(sess, ctx, sub_id, input).await + run_task(sess, ctx, sub_id, input, TaskKind::Review).await } async fn abort(&self, session: Arc, sub_id: &str) { diff --git a/codex-rs/core/tests/responses_headers.rs b/codex-rs/core/tests/responses_headers.rs new file mode 100644 index 00000000000..19967b06eab --- /dev/null +++ b/codex-rs/core/tests/responses_headers.rs @@ -0,0 +1,102 @@ +use std::sync::Arc; + +use codex_app_server_protocol::AuthMode; +use codex_core::ContentItem; +use codex_core::ModelClient; +use codex_core::ModelProviderInfo; +use codex_core::Prompt; +use codex_core::ResponseEvent; +use codex_core::ResponseItem; +use codex_core::WireApi; +use codex_otel::otel_event_manager::OtelEventManager; +use codex_protocol::ConversationId; +use core_test_support::load_default_config_for_test; +use core_test_support::responses; +use futures::StreamExt; +use tempfile::TempDir; +use wiremock::matchers::header; + +#[tokio::test] +async fn responses_stream_includes_task_type_header() { + core_test_support::skip_if_no_network!(); + + let server = responses::start_mock_server().await; + let response_body = responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_completed("resp-1"), + ]); + + let request_recorder = responses::mount_sse_once_match( + &server, + header("Codex-Task-Type", "standard"), + response_body, + ) + .await; + + let provider = ModelProviderInfo { + name: "mock".into(), + base_url: Some(format!("{}/v1", server.uri())), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: Some(5_000), + requires_openai_auth: false, + }; + + let codex_home = TempDir::new().expect("failed to create TempDir"); + let mut config = load_default_config_for_test(&codex_home); + config.model_provider_id = provider.name.clone(); + config.model_provider = provider.clone(); + let effort = config.model_reasoning_effort; + let summary = config.model_reasoning_summary; + let config = Arc::new(config); + + let conversation_id = ConversationId::new(); + + let otel_event_manager = OtelEventManager::new( + conversation_id, + config.model.as_str(), + config.model_family.slug.as_str(), + None, + Some(AuthMode::ChatGPT), + false, + "test".to_string(), + ); + + let client = ModelClient::new( + Arc::clone(&config), + None, + otel_event_manager, + provider, + effort, + summary, + conversation_id, + ); + + let mut prompt = Prompt::default(); + prompt.input = vec![ResponseItem::Message { + id: None, + role: "user".into(), + content: vec![ContentItem::InputText { + text: "hello".into(), + }], + }]; + + let mut stream = client.stream(&prompt).await.expect("stream failed"); + while let Some(event) = stream.next().await { + if matches!(event, Ok(ResponseEvent::Completed { .. })) { + break; + } + } + + let request = request_recorder.single_request(); + assert_eq!( + request.header("Codex-Task-Type").as_deref(), + Some("standard") + ); +}