Skip to content

Commit 9319eec

Browse files
fix(langchain_v1): ToolRuntime default for args (#33606)
added some noqas, this is a quick patch to support a bug uncovered in the quickstart, will resolve fully depending on where we centralize ToolNode stuff.
1 parent a47386f commit 9319eec

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

libs/langchain_v1/langchain/tools/tool_node.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def my_tool(x: int) -> str:
4949
Generic,
5050
Literal,
5151
TypedDict,
52-
TypeVar,
5352
Union,
5453
cast,
5554
get_args,
@@ -84,15 +83,18 @@ def my_tool(x: int) -> str:
8483
from langgraph.store.base import BaseStore # noqa: TC002
8584
from langgraph.types import Command, Send, StreamWriter
8685
from pydantic import BaseModel, ValidationError
87-
from typing_extensions import Unpack
86+
from typing_extensions import TypeVar, Unpack
8887

8988
if TYPE_CHECKING:
9089
from collections.abc import Sequence
9190

9291
from langgraph.runtime import Runtime
9392

94-
StateT = TypeVar("StateT")
95-
ContextT = TypeVar("ContextT")
93+
# right now we use a dict as the default, can change this to AgentState, but depends
94+
# on if this lives in LangChain or LangGraph... ideally would have some typed
95+
# messages key
96+
StateT = TypeVar("StateT", default=dict)
97+
ContextT = TypeVar("ContextT", default=None)
9698

9799
INVALID_TOOL_NAME_ERROR_TEMPLATE = (
98100
"Error: {requested_tool} is not a valid tool, try one of [{available_tools}]."
@@ -626,7 +628,7 @@ def _func(
626628
injected_tool_calls = []
627629
input_types = [input_type] * len(tool_calls)
628630
for call, tool_runtime in zip(tool_calls, tool_runtimes, strict=False):
629-
injected_call = self._inject_tool_args(call, tool_runtime)
631+
injected_call = self._inject_tool_args(call, tool_runtime) # type: ignore[arg-type]
630632
injected_tool_calls.append(injected_call)
631633
with get_executor_for_config(config) as executor:
632634
outputs = list(
@@ -661,9 +663,9 @@ async def _afunc(
661663
injected_tool_calls = []
662664
coros = []
663665
for call, tool_runtime in zip(tool_calls, tool_runtimes, strict=False):
664-
injected_call = self._inject_tool_args(call, tool_runtime)
666+
injected_call = self._inject_tool_args(call, tool_runtime) # type: ignore[arg-type]
665667
injected_tool_calls.append(injected_call)
666-
coros.append(self._arun_one(injected_call, input_type, tool_runtime))
668+
coros.append(self._arun_one(injected_call, input_type, tool_runtime)) # type: ignore[arg-type]
667669
outputs = await asyncio.gather(*coros)
668670

669671
return self._combine_tool_outputs(outputs, input_type)

0 commit comments

Comments
 (0)