Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions codex-rs/core/src/tools/handlers/apply_patch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ impl ToolHandler for ApplyPatchHandler {
)
}

fn is_mutating(&self, _invocation: &ToolInvocation) -> bool {
true
}

async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
Expand Down
13 changes: 13 additions & 0 deletions codex-rs/core/src/tools/handlers/shell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::codex::TurnContext;
use crate::exec::ExecParams;
use crate::exec_env::create_env;
use crate::function_tool::FunctionCallError;
use crate::is_safe_command::is_known_safe_command;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
Expand Down Expand Up @@ -52,6 +53,18 @@ impl ToolHandler for ShellHandler {
)
}

fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
match &invocation.payload {
ToolPayload::Function { arguments } => {
serde_json::from_str::<ShellToolCallParams>(arguments)
.map(|params| !is_known_safe_command(&params.command))
.unwrap_or(true)
}
ToolPayload::LocalShell { params } => !is_known_safe_command(&params.command),
_ => true, // unknown payloads => assume mutating
}
}

async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
Expand Down
19 changes: 16 additions & 3 deletions codex-rs/core/src/tools/handlers/unified_exec.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use std::path::PathBuf;

use async_trait::async_trait;
use serde::Deserialize;

use crate::function_tool::FunctionCallError;
use crate::is_safe_command::is_known_safe_command;
use crate::protocol::EventMsg;
use crate::protocol::ExecCommandOutputDeltaEvent;
use crate::protocol::ExecOutputStream;
Expand All @@ -20,6 +18,8 @@ use crate::unified_exec::UnifiedExecContext;
use crate::unified_exec::UnifiedExecResponse;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::unified_exec::WriteStdinRequest;
use async_trait::async_trait;
use serde::Deserialize;

pub struct UnifiedExecHandler;

Expand Down Expand Up @@ -74,6 +74,19 @@ impl ToolHandler for UnifiedExecHandler {
)
}

fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
let (ToolPayload::Function { arguments } | ToolPayload::UnifiedExec { arguments }) =
&invocation.payload
else {
return true;
};

let Ok(params) = serde_json::from_str::<ExecCommandArgs>(arguments) else {
return true;
};
!is_known_safe_command(&["bash".to_string(), "-lc".to_string(), params.cmd])
}

async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
Expand Down
5 changes: 0 additions & 5 deletions codex-rs/core/src/tools/parallel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use crate::tools::router::ToolCall;
use crate::tools::router::ToolRouter;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
use codex_utils_readiness::Readiness;

pub(crate) struct ToolCallRuntime {
router: Arc<ToolRouter>,
Expand Down Expand Up @@ -55,7 +54,6 @@ impl ToolCallRuntime {
let tracker = Arc::clone(&self.tracker);
let lock = Arc::clone(&self.parallel_execution);
let started = Instant::now();
let readiness = self.turn_context.tool_call_gate.clone();

let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
AbortOnDropHandle::new(tokio::spawn(async move {
Expand All @@ -65,9 +63,6 @@ impl ToolCallRuntime {
Ok(Self::aborted_response(&call, secs))
},
res = async {
tracing::trace!("waiting for tool gate");
readiness.wait_ready().await;
tracing::trace!("tool gate released");
let _guard = if supports_parallel {
Either::Left(lock.read().await)
} else {
Expand Down
17 changes: 13 additions & 4 deletions codex-rs/core/src/tools/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use codex_protocol::models::ResponseInputItem;
use tracing::warn;

use crate::client_common::tools::ToolSpec;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use async_trait::async_trait;
use codex_protocol::models::ResponseInputItem;
use codex_utils_readiness::Readiness;
use tracing::warn;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ToolKind {
Expand All @@ -30,6 +30,10 @@ pub trait ToolHandler: Send + Sync {
)
}

fn is_mutating(&self, _invocation: &ToolInvocation) -> bool {
false
}

async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError>;
}

Expand Down Expand Up @@ -106,6 +110,11 @@ impl ToolRegistry {
let output_cell = &output_cell;
let invocation = invocation;
async move {
if handler.is_mutating(&invocation) {
tracing::trace!("waiting for tool gate");
invocation.turn.tool_call_gate.wait_ready().await;
tracing::trace!("tool gate released");
}
match handler.handle(invocation).await {
Ok(output) => {
let preview = output.log_preview();
Expand Down
Loading