From 729b6a8e340eed27a572e380a96e71910d7565d3 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Mon, 23 Feb 2026 12:37:15 +0000 Subject: [PATCH 01/60] APPENG-4528 and APPENG-4529 --- src/vuln_analysis/functions/cve_agent.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 4f1e168d..753b2311 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os import asyncio from vuln_analysis.runtime_context import ctx_state import typing @@ -210,7 +210,13 @@ async def cve_agent(config: CVEAgentExecutorToolConfig, builder: Builder): async def _arun(state: AgentMorpheusEngineState) -> AgentMorpheusEngineState: trace_id.set(state.original_input.input.scan.id) ctx_state.set(state) - agent = await _create_agent(config, builder, state) + agent = None + if os.environ.get("ENABLE_GRAPH_AGENT","0") == "1": + logger.info("ENABLE_GRAPH_AGENT is set to 1. Executing CVE agent in graph mode.") + agent = await _create_agent(config, builder, state) + else: + logger.info("ENABLE_GRAPH_AGENT is set to 0. 
Executing CVE agent.") + agent = await _create_agent(config, builder, state) results = await asyncio.gather(*(_process_steps(agent, steps, semaphore) for steps in state.checklist_plans.values()), return_exceptions=True) results = _postprocess_results(results, config.replace_exceptions, config.replace_exceptions_value, From 97d7e2d6859d09548850c18968af6965aa97f733 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Mon, 23 Feb 2026 14:09:50 +0000 Subject: [PATCH 02/60] basic react graph --- .../functions/cve_react_graph.py | 118 ++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 src/vuln_analysis/functions/cve_react_graph.py diff --git a/src/vuln_analysis/functions/cve_react_graph.py b/src/vuln_analysis/functions/cve_react_graph.py new file mode 100644 index 00000000..49d3b1cb --- /dev/null +++ b/src/vuln_analysis/functions/cve_react_graph.py @@ -0,0 +1,118 @@ +import os +import uuid +from langgraph.graph import MessagesState, StateGraph,END,START +from langgraph.prebuilt import ToolNode, tools_condition +from langgraph.graph.message import add_messages +from langgraph.errors import GraphRecursionError +from langchain_openai import ChatOpenAI +from langchain_core.tools import tool +from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage, BaseMessage +from pydantic import BaseModel, Field +from langchain_core.tools import StructuredTool +from typing import Literal + +class FinalAnswer(BaseModel): + """Use this when you are done and have the answer.""" + action_type: Literal["final_answer"] = "final_answer" + response: str = Field(description="The final response to the user") + +class TripleAction(BaseModel): + action_type: Literal["triple"] = "triple" + num: float = Field(description="The number to triple") + +action_models = [TripleAction, FinalAnswer] + +# The union type forces the LLM to pick exactly one of these valid structures +class AgentResponse(BaseModel): + command: TripleAction | FinalAnswer + 
+tool_descriptions = "\n".join(["triple: param num: a number to triple \nreturns: the triple of the input number\n"]) + +SYSTEM_MESSAGE = f"""You are a reasoning agent. +You have access to: {tool_descriptions} + +Your logic flow must be: +1. **REASON**: What has been asked? What has been done? What is left to do? +2. **ACT**: Call a tool if more info or a calculation is needed. +3. **VERIFY**: (After a tool returns) Did this tool result satisfy the ENTIRETY of the user's request? + +CRITICAL: If the user asked for a multi-step task (e.g., "Find X AND then do Y"), do not trigger 'FinalAnswer' after finding X. You must loop back to perform Y. +""" + +AGENT_REASON ="agent_reason" +ACT = "act" +LAST = -1 +os.environ["OPENAI_API_KEY"] = "EMPTY" +base_url = os.environ["NVIDIA_API_BASE"] +model_name = os.environ["CVE_AGENT_EXECUTOR_MODEL_NAME"] +llm = ChatOpenAI(base_url=base_url, model=model_name ,temperature=0.1, max_completion_tokens=2000, top_p=0.01) + +def triple(num: float) -> float: + """ + :param num: a number to triple + :return: the number tripled -> multiplied by 3 + """ + return 3 * float(num) + +tools = [StructuredTool.from_function(triple)] +#tools = [triple] +tool_node = ToolNode(tools) + +llm = llm.with_structured_output(AgentResponse) + +def agent_reasoning_node(state: MessagesState) -> MessagesState: + response = llm.invoke( + [{"role": "system", "content": SYSTEM_MESSAGE}, *state["messages"]] + ) + command = response.command # Extract the specific action object + print(f'agent_reasoning_node:command: {command}') + # --- ROUTING LOGIC --- + + # Case 1: Final Answer + if isinstance(command, FinalAnswer): + return { + "messages": [AIMessage(content=command.response)] + } + + # Case 2: Tool Call (Search or Calculate) + else: + # We assume the 'action_type' matches the tool name exactly + tool_name = command.action_type + + # We convert the Pydantic object to a dict, explicitly excluding the type tag + # This leaves ONLY the arguments (e.g., {'query': '...'} 
or {'a': 1, 'b': 2}) + arguments = command.model_dump(exclude={"action_type"}) + + # Construct the message for ToolNode + tool_call_id = str(uuid.uuid4()) + msg = AIMessage( + content="", + tool_calls=[{ + "name": tool_name, + "args": arguments, # This is now GUARANTEED to have the right keys + "id": tool_call_id + }] + ) + return {"messages": [msg]} + +def should_continue(state: MessagesState)->str: + if not state["messages"][LAST].tool_calls: + return END + return ACT + +def _create_cve_react_graph(): + flow = StateGraph(MessagesState) + flow.add_node(AGENT_REASON, agent_reasoning_node) + flow.add_node(ACT, tool_node) + flow.add_edge(START, AGENT_REASON) + flow.add_conditional_edges(AGENT_REASON,should_continue,{END: END,ACT: ACT}) + flow.add_edge(ACT, AGENT_REASON) + app = flow.compile() + app.get_graph().draw_mermaid_png(output_file_path="flow.png") + return app + +if __name__ == "__main__": + print("Hello react graph, World!") + app =_create_cve_react_graph() + result = app.invoke({"messages": [HumanMessage(content="The current temperature is 12° Celsius. 
Triple it")]}) + print(result) \ No newline at end of file From 0398a343683a4c5a90822d33d995a9cadd384d64 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Tue, 24 Feb 2026 08:49:34 +0200 Subject: [PATCH 03/60] save --- src/vuln_analysis/functions/cve_react_graph.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/vuln_analysis/functions/cve_react_graph.py b/src/vuln_analysis/functions/cve_react_graph.py index 49d3b1cb..e65c6094 100644 --- a/src/vuln_analysis/functions/cve_react_graph.py +++ b/src/vuln_analysis/functions/cve_react_graph.py @@ -2,10 +2,8 @@ import uuid from langgraph.graph import MessagesState, StateGraph,END,START from langgraph.prebuilt import ToolNode, tools_condition -from langgraph.graph.message import add_messages from langgraph.errors import GraphRecursionError from langchain_openai import ChatOpenAI -from langchain_core.tools import tool from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage, BaseMessage from pydantic import BaseModel, Field from langchain_core.tools import StructuredTool @@ -113,6 +111,10 @@ def _create_cve_react_graph(): if __name__ == "__main__": print("Hello react graph, World!") - app =_create_cve_react_graph() - result = app.invoke({"messages": [HumanMessage(content="The current temperature is 12° Celsius. 
Triple it")]}) - print(result) \ No newline at end of file + try: + app =_create_cve_react_graph() + result = app.invoke({"messages": [HumanMessage(content="Does the application properly handle and validate input data to prevent malicious SQL injection, especially when using the affected PostgreSQL libpq functions?")]}) + print(result) + except GraphRecursionError as e: + print(f"GraphRecursionError: {e}") + print("GraphRecursionError: Stopping forcefully.") From e011d31d00ce5b4bcc7c629274f7d3bbe0435684 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Tue, 24 Feb 2026 12:30:53 +0000 Subject: [PATCH 04/60] make baseline --- src/vuln_analysis/functions/cve_agent.py | 26 +++-- .../functions/cve_react_graph.py | 22 ++++- .../functions/react_internals.py | 96 +++++++++++++++++++ 3 files changed, 131 insertions(+), 13 deletions(-) create mode 100644 src/vuln_analysis/functions/react_internals.py diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 753b2311..1569eec8 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -66,14 +66,9 @@ class CVEAgentExecutorToolConfig(FunctionBaseConfig, name="cve_agent_executor"): description="Whether to enable CVE Web Search tool or not.") verbose: bool = Field(default=False, description="Set to true for verbose output") - -async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder, - state: AgentMorpheusEngineState) -> AgentExecutor: +async def common_build_tools(config: CVEAgentExecutorToolConfig, builder: Builder,state: AgentMorpheusEngineState) -> tuple[list[typing.Any],list[str]]: from vuln_analysis.utils.prompting import build_tool_descriptions - tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN) - llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) - # Filter tools that are not available based on state tools = [ tool for tool 
in tools @@ -89,13 +84,28 @@ async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder, state.code_index_path is None)) ) ] - # Get tool names after filtering for dynamic guidance enabled_tool_names = [tool.name for tool in tools] # Build tool selection guidance with strategic context tool_descriptions = build_tool_descriptions(enabled_tool_names) + return tools, tool_descriptions + +async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Builder,state: AgentMorpheusEngineState): + + tools, tool_descriptions = await common_build_tools(config, builder,state) + llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) tool_guidance = "\n".join(tool_descriptions) + prompt_template_str = get_agent_prompt(config.prompt, config.prompt_examples) + +async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder, + state: AgentMorpheusEngineState) -> AgentExecutor: + + tools, tool_descriptions = await common_build_tools(config, builder, state) + + llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) + tool_guidance = "\n".join(tool_descriptions) + # Get prompt template prompt_template_str = get_agent_prompt(config.prompt, config.prompt_examples) @@ -213,7 +223,7 @@ async def _arun(state: AgentMorpheusEngineState) -> AgentMorpheusEngineState: agent = None if os.environ.get("ENABLE_GRAPH_AGENT","0") == "1": logger.info("ENABLE_GRAPH_AGENT is set to 1. Executing CVE agent in graph mode.") - agent = await _create_agent(config, builder, state) + agent = await _create_graph_agent(config, builder, state) else: logger.info("ENABLE_GRAPH_AGENT is set to 0. 
Executing CVE agent.") agent = await _create_agent(config, builder, state) diff --git a/src/vuln_analysis/functions/cve_react_graph.py b/src/vuln_analysis/functions/cve_react_graph.py index e65c6094..6825f8ee 100644 --- a/src/vuln_analysis/functions/cve_react_graph.py +++ b/src/vuln_analysis/functions/cve_react_graph.py @@ -43,7 +43,7 @@ class AgentResponse(BaseModel): os.environ["OPENAI_API_KEY"] = "EMPTY" base_url = os.environ["NVIDIA_API_BASE"] model_name = os.environ["CVE_AGENT_EXECUTOR_MODEL_NAME"] -llm = ChatOpenAI(base_url=base_url, model=model_name ,temperature=0.1, max_completion_tokens=2000, top_p=0.01) +llm_base = ChatOpenAI(base_url=base_url, model=model_name ,temperature=0.0, max_completion_tokens=2000, top_p=0.01) def triple(num: float) -> float: """ @@ -56,7 +56,7 @@ def triple(num: float) -> float: #tools = [triple] tool_node = ToolNode(tools) -llm = llm.with_structured_output(AgentResponse) +llm = llm_base.with_structured_output(AgentResponse) def agent_reasoning_node(state: MessagesState) -> MessagesState: response = llm.invoke( @@ -109,12 +109,24 @@ def _create_cve_react_graph(): app.get_graph().draw_mermaid_png(output_file_path="flow.png") return app + +from vuln_analysis.functions.react_internals import build_system_prompt,AgentState,Thought,Observation + if __name__ == "__main__": print("Hello react graph, World!") try: - app =_create_cve_react_graph() - result = app.invoke({"messages": [HumanMessage(content="Does the application properly handle and validate input data to prevent malicious SQL injection, especially when using the affected PostgreSQL libpq functions?")]}) - print(result) + #app =_create_cve_react_graph() + #result = app.invoke({"messages": [HumanMessage(content="Does the application properly handle and validate input data to prevent malicious SQL injection, especially when using the affected PostgreSQL libpq functions?")]}) + #print(result) + TOOL_GUIDANCE = """Code Keyword Search: Exact text matching for function names, class 
names, or imports\nCall Chain Analyzer: Checks if functions are reachable from application code\nFunction Caller Finder: Finds which functions call specific library functions\nUse 'Function Caller Finder' + 'Call Chain Analyzer' together to trace function reachability\nCVE Web Search: External vulnerability information lookup""" + TOOL_DESCRIPTIONS = """\nCode Keyword Search(*args, **kwargs) - Performs keyword search on container source code for exact text matches. Input should be a function name, class name, or code pattern. Use this first before semantic search tools for precise lookups.\nCVE Web Search(*args, **kwargs) - Searches the web for information about CVEs, vulnerabilities, libraries, and security advisories not available in the container.\nCall Chain Analyzer(*args, **kwargs) - Checks if a function from a package is reachable from application code through the call chain.\n Make sure the input format is matching exactly one of the following formats:\n \n Input format 1: 'package_name,function_name'.\n Example 1: 'urllib,parse'.\n \n Input format 2(java): 'maven_gav,class_name.function_name'.\n Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.\n \n Returns: (is_reachable: bool, call_hierarchy_path: list).\nFunction Caller Finder(*args, **kwargs) - Finds functions in a package that call a specific library function. GO ecosystem only.\n Input format: 'package_name,library.function(args_with_literals)'. \n Example: 'github.com/namespace/package_name,errors.New(\"text_literal'\")'.\n Returns: ['package,caller1', 'package,caller2'] or [].\nFunction Locator(*args, **kwargs) - Mandatory first step for code path analysis. 
Validates package names, locates functions using fuzzy matching, and provides ecosystem type (GO/Python/Java/JavaScript/C/C++).\n Make sure the input format is matching exactly one of the following formats:\n \n Input format 1: 'package_name,function_name' or 'package_name,class_name.method_name'.\n Example 1: 'libxml2,xmlParseDocument'.\n \n Input format 2(java): 'maven_gav,class_name.method_name'.\n Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.\n \n Returns: {'ecosystem': str, 'package_msg': str, 'result': [function_names]}.\n""" + SYSTEM_MESSAGE = build_system_prompt(TOOL_DESCRIPTIONS, TOOL_GUIDANCE) + llm_test = llm_base.with_structured_output(Thought) + response = llm_test.invoke([{"role": "system", "content": SYSTEM_MESSAGE}, {"role": "user", "content": "Does the application properly handle and validate input data to prevent malicious SQL injection, especially when using the affected PostgreSQL libpq functions?"}]) + print(response) + except GraphRecursionError as e: print(f"GraphRecursionError: {e}") print("GraphRecursionError: Stopping forcefully.") + + diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py new file mode 100644 index 00000000..72c683d4 --- /dev/null +++ b/src/vuln_analysis/functions/react_internals.py @@ -0,0 +1,96 @@ +from langchain_core.messages import HumanMessage +from langchain_core.output_parsers.openai_tools import ( + JsonOutputToolsParser, + PydanticToolsParser, +) +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_openai import ChatOpenAI + +from pydantic import BaseModel, Field +from typing import Any,Optional +from typing import Literal +#---- REACT Schemas ----# + +class ToolCall(BaseModel): + tool: str = Field(description="Tool name to execute (must match AVAILABLE_TOOLS exactly)") + input: dict[str, Any] = Field( + description='Arguments for the tool. 
Use {"query": "value"} - the value format depends on the tool (see AVAILABLE_TOOLS). NEVER use empty {}. Examples: Code Keyword Search {"query": "urllib.parse"}, Call Chain Analyzer {"query": "libpq,PQescapeLiteral"}, Function Locator {"query": "libpq,function_name"}.', + ) + reason: str = Field(description="Why this tool is needed") + +class Thought(BaseModel): + thought: str = Field( + description="Brief reasoning about next step (max 3-4 sentences)", + max_length=300, + ) + mode: Literal["act", "finish"] = Field(description="'act' to call tools (when more information is needed), 'finish' to return the final answer (when you have sufficient evidence)") + + actions: ToolCall | None = Field(default=None, description="When mode is 'act', the tool to execute") + + final_answer: str | None = Field( + default=None, + description="When mode is 'finish', concise answer (3-5 sentences) with key evidence", + max_length=500, + ) + +class Observation(BaseModel): + results: list[dict[str, Any]] = Field(description="Raw tool outputs") + + memory: str = Field(description="Compressed working memory summary") + +class AgentState(BaseModel): + #goal: str = Field(description="Investigation question") + #step: int = Field(description="Current step number") + #max_steps: int = Field(description="Maximum number of steps") + memory: str | None = Field(default=None, description="Compressed working memory summary") + thought: Thought | None = Field(description="Current thought") + observation: Observation | None = Field(description="Current observation") + +### --- End of REACT Schemas ----# +#---- REACT Prompt Templates ----# +AGENT_SYS_PROMPT = ( + "You are an expert security analyst investigating Common Vulnerabilities and " + "Exposures (CVE) in container images. Your role is to methodically answer " + "investigation questions using available tools to determine if vulnerabilities " + "are exploitable in the specific container context. 
You have access to the " + "container's source code, documentation, and dependency information through " + "specialized search and analysis tools." +) + +LANGGRAPH_SYSTEM_PROMPT_TEMPLATE = """{sys_prompt} + + +Answer the investigation question using the available tools. If the input is not a question, formulate it into a question first. A Tool Selection Strategy is provided to help you decide which tools to use. Focus on answering the question. Summarize key findings and evidence concisely in the final answer. + + + +{tools} + + + +{tool_selection_strategy} + + + +Follow this format exactly: +- thought: Brief reasoning about next step (max 3-4 sentences) +- mode: "act" to call tool (when more information is needed), "finish" to return the final answer (when you have sufficient evidence) +- actions: When mode is "act", a single object with "tool", "input", "reason". The "input" MUST be a dict with "query" key - format per tool in AVAILABLE_TOOLS. NEVER use empty input {{}}. +- final_answer: When mode is "finish", a concise answer (3-5 sentences) with key evidence. Do not repeat tool outputs verbatim. + + +CRITICAL: Keep thought under 100 words and final_answer under 150 words to stay within token limits. 
+""" + +def build_system_prompt( + tool_descriptions: str, + tool_guidance: str, + sys_prompt: str | None = None, +) -> str: + sys_prompt = sys_prompt or AGENT_SYS_PROMPT + return LANGGRAPH_SYSTEM_PROMPT_TEMPLATE.format( + sys_prompt=sys_prompt, + tools=tool_descriptions, + tool_selection_strategy=tool_guidance, + ) +### --- End of REACT Prompt Templates ----# From cc25c7c6dfabfffb583ffadde3a836aa61d46630 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Tue, 24 Feb 2026 13:15:48 +0000 Subject: [PATCH 05/60] input field was filled --- .../functions/cve_react_graph.py | 3 ++- .../functions/react_internals.py | 21 ++++++++++++------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/vuln_analysis/functions/cve_react_graph.py b/src/vuln_analysis/functions/cve_react_graph.py index 6825f8ee..508d4f4c 100644 --- a/src/vuln_analysis/functions/cve_react_graph.py +++ b/src/vuln_analysis/functions/cve_react_graph.py @@ -122,7 +122,8 @@ def _create_cve_react_graph(): TOOL_DESCRIPTIONS = """\nCode Keyword Search(*args, **kwargs) - Performs keyword search on container source code for exact text matches. Input should be a function name, class name, or code pattern. 
Use this first before semantic search tools for precise lookups.\nCVE Web Search(*args, **kwargs) - Searches the web for information about CVEs, vulnerabilities, libraries, and security advisories not available in the container.\nCall Chain Analyzer(*args, **kwargs) - Checks if a function from a package is reachable from application code through the call chain.\n Make sure the input format is matching exactly one of the following formats:\n \n Input format 1: 'package_name,function_name'.\n Example 1: 'urllib,parse'.\n \n Input format 2(java): 'maven_gav,class_name.function_name'.\n Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.\n \n Returns: (is_reachable: bool, call_hierarchy_path: list).\nFunction Caller Finder(*args, **kwargs) - Finds functions in a package that call a specific library function. GO ecosystem only.\n Input format: 'package_name,library.function(args_with_literals)'. \n Example: 'github.com/namespace/package_name,errors.New(\"text_literal'\")'.\n Returns: ['package,caller1', 'package,caller2'] or [].\nFunction Locator(*args, **kwargs) - Mandatory first step for code path analysis. 
Validates package names, locates functions using fuzzy matching, and provides ecosystem type (GO/Python/Java/JavaScript/C/C++).\n Make sure the input format is matching exactly one of the following formats:\n \n Input format 1: 'package_name,function_name' or 'package_name,class_name.method_name'.\n Example 1: 'libxml2,xmlParseDocument'.\n \n Input format 2(java): 'maven_gav,class_name.method_name'.\n Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.\n \n Returns: {'ecosystem': str, 'package_msg': str, 'result': [function_names]}.\n""" SYSTEM_MESSAGE = build_system_prompt(TOOL_DESCRIPTIONS, TOOL_GUIDANCE) llm_test = llm_base.with_structured_output(Thought) - response = llm_test.invoke([{"role": "system", "content": SYSTEM_MESSAGE}, {"role": "user", "content": "Does the application properly handle and validate input data to prevent malicious SQL injection, especially when using the affected PostgreSQL libpq functions?"}]) + TEST_QUESTION = "Does the codebase use libpq PQescapeLiteral?" # Shorter for testing + response = llm_test.invoke([{"role": "system", "content": SYSTEM_MESSAGE}, {"role": "user", "content": TEST_QUESTION}]) print(response) except GraphRecursionError as e: diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 72c683d4..6e4a0c67 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -13,10 +13,11 @@ class ToolCall(BaseModel): tool: str = Field(description="Tool name to execute (must match AVAILABLE_TOOLS exactly)") - input: dict[str, Any] = Field( - description='Arguments for the tool. Use {"query": "value"} - the value format depends on the tool (see AVAILABLE_TOOLS). NEVER use empty {}. 
Examples: Code Keyword Search {"query": "urllib.parse"}, Call Chain Analyzer {"query": "libpq,PQescapeLiteral"}, Function Locator {"query": "libpq,function_name"}.', - ) - reason: str = Field(description="Why this tool is needed") + input: str = Field(description="Arguments for the tool. Example: Code Keyword Search: PQescapeLiteral") + #input: dict[str, Any] = Field( + # description='{"query": "value"}. Never empty. E.g. Code Keyword Search: {"query": "PQescapeLiteral"}', + #) + #reason: str = Field(description="Why this tool is needed") class Thought(BaseModel): thought: str = Field( @@ -73,10 +74,14 @@ class AgentState(BaseModel): Follow this format exactly: -- thought: Brief reasoning about next step (max 3-4 sentences) -- mode: "act" to call tool (when more information is needed), "finish" to return the final answer (when you have sufficient evidence) -- actions: When mode is "act", a single object with "tool", "input", "reason". The "input" MUST be a dict with "query" key - format per tool in AVAILABLE_TOOLS. NEVER use empty input {{}}. -- final_answer: When mode is "finish", a concise answer (3-5 sentences) with key evidence. Do not repeat tool outputs verbatim. +- mode: Set to "act" if you need to use a tool, or "finish" if you have the final answer. +- actions: Required only when mode="act". Provide exactly one tool call. + - input: Arguments for the tool. Example: Code Keyword Search: PQescapeLiteral +- final_answer: Required only when mode="finish". Summarize your findings here. +- thought: Always provide your internal reasoning, regardless of the mode. +CRITICAL: +1. The "input" field inside "actions" MUST contain the correct keys for the tool.See AVAILABLE_TOOLS for the available their input keys. +2. Never provide an empty dictionary for input. CRITICAL: Keep thought under 100 words and final_answer under 150 words to stay within token limits. 
From b37fe455f3788262be9e0370eec4afd2893be526 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Tue, 24 Feb 2026 13:50:02 +0000 Subject: [PATCH 06/60] arugment filled --- .../functions/cve_react_graph.py | 2 +- .../functions/react_internals.py | 46 +++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/vuln_analysis/functions/cve_react_graph.py b/src/vuln_analysis/functions/cve_react_graph.py index 508d4f4c..a9488465 100644 --- a/src/vuln_analysis/functions/cve_react_graph.py +++ b/src/vuln_analysis/functions/cve_react_graph.py @@ -122,7 +122,7 @@ def _create_cve_react_graph(): TOOL_DESCRIPTIONS = """\nCode Keyword Search(*args, **kwargs) - Performs keyword search on container source code for exact text matches. Input should be a function name, class name, or code pattern. Use this first before semantic search tools for precise lookups.\nCVE Web Search(*args, **kwargs) - Searches the web for information about CVEs, vulnerabilities, libraries, and security advisories not available in the container.\nCall Chain Analyzer(*args, **kwargs) - Checks if a function from a package is reachable from application code through the call chain.\n Make sure the input format is matching exactly one of the following formats:\n \n Input format 1: 'package_name,function_name'.\n Example 1: 'urllib,parse'.\n \n Input format 2(java): 'maven_gav,class_name.function_name'.\n Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.\n \n Returns: (is_reachable: bool, call_hierarchy_path: list).\nFunction Caller Finder(*args, **kwargs) - Finds functions in a package that call a specific library function. GO ecosystem only.\n Input format: 'package_name,library.function(args_with_literals)'. \n Example: 'github.com/namespace/package_name,errors.New(\"text_literal'\")'.\n Returns: ['package,caller1', 'package,caller2'] or [].\nFunction Locator(*args, **kwargs) - Mandatory first step for code path analysis. 
Validates package names, locates functions using fuzzy matching, and provides ecosystem type (GO/Python/Java/JavaScript/C/C++).\n Make sure the input format is matching exactly one of the following formats:\n \n Input format 1: 'package_name,function_name' or 'package_name,class_name.method_name'.\n Example 1: 'libxml2,xmlParseDocument'.\n \n Input format 2(java): 'maven_gav,class_name.method_name'.\n Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.\n \n Returns: {'ecosystem': str, 'package_msg': str, 'result': [function_names]}.\n""" SYSTEM_MESSAGE = build_system_prompt(TOOL_DESCRIPTIONS, TOOL_GUIDANCE) llm_test = llm_base.with_structured_output(Thought) - TEST_QUESTION = "Does the codebase use libpq PQescapeLiteral?" # Shorter for testing + TEST_QUESTION = "Does the application properly handle and validate input data to prevent malicious SQL injection, especially when using the affected PostgreSQL libpq functions?" # Shorter for testing response = llm_test.invoke([{"role": "system", "content": SYSTEM_MESSAGE}, {"role": "user", "content": TEST_QUESTION}]) print(response) diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 6e4a0c67..a7a9914e 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -12,12 +12,12 @@ #---- REACT Schemas ----# class ToolCall(BaseModel): - tool: str = Field(description="Tool name to execute (must match AVAILABLE_TOOLS exactly)") - input: str = Field(description="Arguments for the tool. Example: Code Keyword Search: PQescapeLiteral") + tool: str = Field(description="Exact tool name from AVAILABLE_TOOLS") + tool_input: str = Field(description="The input for the tool. Example: Code Keyword Search: PQescapeLiteral") #input: dict[str, Any] = Field( - # description='{"query": "value"}. Never empty. E.g. 
Code Keyword Search: {"query": "PQescapeLiteral"}', + # description="MUST be a non-empty dict with keys matching the tool's requirements (e.g., {'query': 'val'} or {'package_name': 'pkg', 'function_name': 'fn'})." #) - #reason: str = Field(description="Why this tool is needed") + reason: str = Field(description="Briefly explain why this specific tool/input helps the investigation") class Thought(BaseModel): thought: str = Field( @@ -58,34 +58,34 @@ class AgentState(BaseModel): "specialized search and analysis tools." ) +# Update LANGGRAPH_SYSTEM_PROMPT_TEMPLATE in react_internals.py LANGGRAPH_SYSTEM_PROMPT_TEMPLATE = """{sys_prompt} - -Answer the investigation question using the available tools. If the input is not a question, formulate it into a question first. A Tool Selection Strategy is provided to help you decide which tools to use. Focus on answering the question. Summarize key findings and evidence concisely in the final answer. - - {tools} - + {tool_selection_strategy} - + + + +1. Output MUST be valid JSON. +2. 'thought' < 100 words. 'final_answer' < 150 words. +3. If mode="act", 'actions' is REQUIRED. If mode="finish", 'final_answer' is REQUIRED. +4. - input:MUST be a non-empty string with the format of the tool's requirements. + - -Follow this format exactly: -- mode: Set to "act" if you need to use a tool, or "finish" if you have the final answer. -- actions: Required only when mode="act". Provide exactly one tool call. - - input: Arguments for the tool. Example: Code Keyword Search: PQescapeLiteral -- final_answer: Required only when mode="finish". Summarize your findings here. -- thought: Always provide your internal reasoning, regardless of the mode. -CRITICAL: -1. The "input" field inside "actions" MUST contain the correct keys for the tool.See AVAILABLE_TOOLS for the available their input keys. -2. Never provide an empty dictionary for input. 
- +Return your response in this JSON format: +{{ + "thought": "...", + "mode": "act" | "finish", + "actions": {{ "tool": "...", "input": {{ "query": "..." }}, "reason": "..." }}, + "final_answer": "..." +}} -CRITICAL: Keep thought under 100 words and final_answer under 150 words to stay within token limits. -""" +RESPONSE: +{{""" def build_system_prompt( tool_descriptions: str, From 9ddea1a546fd6b14880338d431b99f829feef112 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Tue, 24 Feb 2026 17:37:30 +0200 Subject: [PATCH 07/60] save progress --- src/vuln_analysis/functions/cve_agent.py | 50 ++++++++++++- .../functions/cve_react_graph.py | 74 +++++++++++-------- .../functions/react_internals.py | 31 +++++--- 3 files changed, 111 insertions(+), 44 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 1569eec8..bf175e86 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -35,6 +35,10 @@ from vuln_analysis.utils.prompting import get_agent_prompt from exploit_iq_commons.logging.loggers_factory import LoggingFactory, trace_id +from vuln_analysis.functions.react_internals import build_system_prompt,AgentState,Thought +from langgraph.graph import StateGraph,END,START +from langgraph.prebuilt import ToolNode + logger = LoggingFactory.get_agent_logger(__name__) class CVEAgentExecutorToolConfig(FunctionBaseConfig, name="cve_agent_executor"): @@ -93,11 +97,53 @@ async def common_build_tools(config: CVEAgentExecutorToolConfig, builder: Builde async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Builder,state: AgentMorpheusEngineState): - tools, tool_descriptions = await common_build_tools(config, builder,state) + tools, tool_descriptions = await common_build_tools(config, builder, state) llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) + thought_llm = llm.with_structured_output(Thought) tool_guidance 
= "\n".join(tool_descriptions) - prompt_template_str = get_agent_prompt(config.prompt, config.prompt_examples) + descriptions = "\n".join(tool_descriptions) + system_prompt = build_system_prompt(descriptions, tool_guidance) + from langchain_core.messages import SystemMessage, AIMessage + tool_node = ToolNode(tools) + TOOL_NODE = "tool_node" + THOUGHT_NODE = "thought_node" + def thought_node(state: AgentState) -> AgentState: + messages = [SystemMessage(content=system_prompt)] + state["messages"] + # 2. Invoke LLM with structured output + # This returns a Thought object based on your pydantic class + response: Thought = thought_llm.invoke(messages) + + # 3. Create an AIMessage to keep the history in state["messages"] + # We store the 'thought' text as the content for the graph's history + ai_message = AIMessage(content=response.thought) + + # 4. Return the updates + # LangGraph merges this dict into the existing state + return { + "messages": [ai_message], + "thought": response, + "step": state.get("step", 0) + 1, # Increment the step counter here + "max_steps": 6, + "observation": None + } + def should_continue(state: AgentState) -> str: + if state.get("thought", None).mode == "finish": + return END + if state.get("step", 0) >= state.get("max_steps", 6): + return END + return TOOL_NODE + def create_graph(): + flow = StateGraph(AgentState) + flow.add_node(THOUGHT_NODE, thought_node) + flow.add_node(TOOL_NODE, tool_node) + flow.add_edge(START, THOUGHT_NODE) + flow.add_conditional_edges(THOUGHT_NODE ,should_continue,{END: END,TOOL_NODE: TOOL_NODE}) + flow.add_edge(TOOL_NODE, THOUGHT_NODE) + app = flow.compile() + app.get_graph().draw_mermaid_png(output_file_path="flow.png") + return app + return create_graph() async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder, state: AgentMorpheusEngineState) -> AgentExecutor: diff --git a/src/vuln_analysis/functions/cve_react_graph.py b/src/vuln_analysis/functions/cve_react_graph.py index a9488465..8cc5d5d5 
100644 --- a/src/vuln_analysis/functions/cve_react_graph.py +++ b/src/vuln_analysis/functions/cve_react_graph.py @@ -8,6 +8,47 @@ from pydantic import BaseModel, Field from langchain_core.tools import StructuredTool from typing import Literal +from vuln_analysis.functions.react_internals import build_system_prompt,AgentState,Thought,Observation + +ACT = "act" +LAST = -1 +os.environ["OPENAI_API_KEY"] = "EMPTY" +base_url = os.environ["NVIDIA_API_BASE"] +model_name = os.environ["CVE_AGENT_EXECUTOR_MODEL_NAME"] +llm_base = ChatOpenAI(base_url=base_url, model=model_name ,temperature=0.0, max_completion_tokens=2000, top_p=0.01) + + + + + +def _create_cve_react_graph(): + flow = StateGraph(AgentState) + flow.add_node(AGENT_REASON, thought_node) + flow.add_node(ACT, tool_node) + flow.add_edge(START, AGENT_REASON) + flow.add_conditional_edges(AGENT_REASON,should_continue,{END: END,ACT: ACT}) + flow.add_edge(ACT, AGENT_REASON) + app = flow.compile() + app.get_graph().draw_mermaid_png(output_file_path="flow.png") + return app + +if __name__ == "__main__": + print("Hello react graph, World!") + try: + #app =_create_cve_react_graph() + #result = app.invoke({"messages": [HumanMessage(content="Does the application properly handle and validate input data to prevent malicious SQL injection, especially when using the affected PostgreSQL libpq functions?")]}) + #print(result) + TOOL_GUIDANCE = """Code Keyword Search: Exact text matching for function names, class names, or imports\nCall Chain Analyzer: Checks if functions are reachable from application code\nFunction Caller Finder: Finds which functions call specific library functions\nUse 'Function Caller Finder' + 'Call Chain Analyzer' together to trace function reachability\nCVE Web Search: External vulnerability information lookup""" + TOOL_DESCRIPTIONS = """\nCode Keyword Search(*args, **kwargs) - Performs keyword search on container source code for exact text matches. Input should be a function name, class name, or code pattern. 
Use this first before semantic search tools for precise lookups.\nCVE Web Search(*args, **kwargs) - Searches the web for information about CVEs, vulnerabilities, libraries, and security advisories not available in the container.\nCall Chain Analyzer(*args, **kwargs) - Checks if a function from a package is reachable from application code through the call chain.\n Make sure the input format is matching exactly one of the following formats:\n \n Input format 1: 'package_name,function_name'.\n Example 1: 'urllib,parse'.\n \n Input format 2(java): 'maven_gav,class_name.function_name'.\n Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.\n \n Returns: (is_reachable: bool, call_hierarchy_path: list).\nFunction Caller Finder(*args, **kwargs) - Finds functions in a package that call a specific library function. GO ecosystem only.\n Input format: 'package_name,library.function(args_with_literals)'. \n Example: 'github.com/namespace/package_name,errors.New(\"text_literal'\")'.\n Returns: ['package,caller1', 'package,caller2'] or [].\nFunction Locator(*args, **kwargs) - Mandatory first step for code path analysis. 
Validates package names, locates functions using fuzzy matching, and provides ecosystem type (GO/Python/Java/JavaScript/C/C++).\n Make sure the input format is matching exactly one of the following formats:\n \n Input format 1: 'package_name,function_name' or 'package_name,class_name.method_name'.\n Example 1: 'libxml2,xmlParseDocument'.\n \n Input format 2(java): 'maven_gav,class_name.method_name'.\n Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.\n \n Returns: {'ecosystem': str, 'package_msg': str, 'result': [function_names]}.\n""" + SYSTEM_MESSAGE = build_system_prompt(TOOL_DESCRIPTIONS, TOOL_GUIDANCE) + llm_test = llm_base.with_structured_output(Thought) + TEST_QUESTION = "Does the application properly handle and validate input data to prevent malicious SQL injection, especially when using the affected PostgreSQL libpq functions?" # Shorter for testing + response = llm_test.invoke([{"role": "system", "content": SYSTEM_MESSAGE}, {"role": "user", "content": TEST_QUESTION}]) + print(response) + + except GraphRecursionError as e: + print(f"GraphRecursionError: {e}") + print("GraphRecursionError: Stopping forcefully.") class FinalAnswer(BaseModel): """Use this when you are done and have the answer.""" @@ -38,12 +79,6 @@ class AgentResponse(BaseModel): """ AGENT_REASON ="agent_reason" -ACT = "act" -LAST = -1 -os.environ["OPENAI_API_KEY"] = "EMPTY" -base_url = os.environ["NVIDIA_API_BASE"] -model_name = os.environ["CVE_AGENT_EXECUTOR_MODEL_NAME"] -llm_base = ChatOpenAI(base_url=base_url, model=model_name ,temperature=0.0, max_completion_tokens=2000, top_p=0.01) def triple(num: float) -> float: """ @@ -98,7 +133,7 @@ def should_continue(state: MessagesState)->str: return END return ACT -def _create_cve_react_graph(): +def _create_cve_react_graph_debug(): flow = StateGraph(MessagesState) flow.add_node(AGENT_REASON, agent_reasoning_node) flow.add_node(ACT, tool_node) @@ -107,27 +142,4 @@ def _create_cve_react_graph(): 
flow.add_edge(ACT, AGENT_REASON) app = flow.compile() app.get_graph().draw_mermaid_png(output_file_path="flow.png") - return app - - -from vuln_analysis.functions.react_internals import build_system_prompt,AgentState,Thought,Observation - -if __name__ == "__main__": - print("Hello react graph, World!") - try: - #app =_create_cve_react_graph() - #result = app.invoke({"messages": [HumanMessage(content="Does the application properly handle and validate input data to prevent malicious SQL injection, especially when using the affected PostgreSQL libpq functions?")]}) - #print(result) - TOOL_GUIDANCE = """Code Keyword Search: Exact text matching for function names, class names, or imports\nCall Chain Analyzer: Checks if functions are reachable from application code\nFunction Caller Finder: Finds which functions call specific library functions\nUse 'Function Caller Finder' + 'Call Chain Analyzer' together to trace function reachability\nCVE Web Search: External vulnerability information lookup""" - TOOL_DESCRIPTIONS = """\nCode Keyword Search(*args, **kwargs) - Performs keyword search on container source code for exact text matches. Input should be a function name, class name, or code pattern. 
Use this first before semantic search tools for precise lookups.\nCVE Web Search(*args, **kwargs) - Searches the web for information about CVEs, vulnerabilities, libraries, and security advisories not available in the container.\nCall Chain Analyzer(*args, **kwargs) - Checks if a function from a package is reachable from application code through the call chain.\n Make sure the input format is matching exactly one of the following formats:\n \n Input format 1: 'package_name,function_name'.\n Example 1: 'urllib,parse'.\n \n Input format 2(java): 'maven_gav,class_name.function_name'.\n Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.\n \n Returns: (is_reachable: bool, call_hierarchy_path: list).\nFunction Caller Finder(*args, **kwargs) - Finds functions in a package that call a specific library function. GO ecosystem only.\n Input format: 'package_name,library.function(args_with_literals)'. \n Example: 'github.com/namespace/package_name,errors.New(\"text_literal'\")'.\n Returns: ['package,caller1', 'package,caller2'] or [].\nFunction Locator(*args, **kwargs) - Mandatory first step for code path analysis. 
Validates package names, locates functions using fuzzy matching, and provides ecosystem type (GO/Python/Java/JavaScript/C/C++).\n Make sure the input format is matching exactly one of the following formats:\n \n Input format 1: 'package_name,function_name' or 'package_name,class_name.method_name'.\n Example 1: 'libxml2,xmlParseDocument'.\n \n Input format 2(java): 'maven_gav,class_name.method_name'.\n Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.\n \n Returns: {'ecosystem': str, 'package_msg': str, 'result': [function_names]}.\n""" - SYSTEM_MESSAGE = build_system_prompt(TOOL_DESCRIPTIONS, TOOL_GUIDANCE) - llm_test = llm_base.with_structured_output(Thought) - TEST_QUESTION = "Does the application properly handle and validate input data to prevent malicious SQL injection, especially when using the affected PostgreSQL libpq functions?" # Shorter for testing - response = llm_test.invoke([{"role": "system", "content": SYSTEM_MESSAGE}, {"role": "user", "content": TEST_QUESTION}]) - print(response) - - except GraphRecursionError as e: - print(f"GraphRecursionError: {e}") - print("GraphRecursionError: Stopping forcefully.") - - + return app \ No newline at end of file diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index a7a9914e..e79fbedf 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -7,8 +7,9 @@ from langchain_openai import ChatOpenAI from pydantic import BaseModel, Field -from typing import Any,Optional +from typing import Any from typing import Literal +from langgraph.graph import MessagesState #---- REACT Schemas ----# class ToolCall(BaseModel): @@ -39,13 +40,13 @@ class Observation(BaseModel): memory: str = Field(description="Compressed working memory summary") -class AgentState(BaseModel): - #goal: str = Field(description="Investigation question") - #step: int = 
Field(description="Current step number") - #max_steps: int = Field(description="Maximum number of steps") - memory: str | None = Field(default=None, description="Compressed working memory summary") - thought: Thought | None = Field(description="Current thought") - observation: Observation | None = Field(description="Current observation") +class AgentState(MessagesState): + #goal: str + step: int = 0 + max_steps: int = 6 + #memory: str | None = None + thought: Thought | None = None + observation: Observation | None = None ### --- End of REACT Schemas ----# #---- REACT Prompt Templates ----# @@ -69,6 +70,12 @@ class AgentState(BaseModel): {tool_selection_strategy} +{tool_instructions} + +RESPONSE: +{{""" + +AGENT_THOUGHT_INSTRUCTIONS = """ 1. Output MUST be valid JSON. 2. 'thought' < 100 words. 'final_answer' < 150 words. @@ -83,19 +90,21 @@ class AgentState(BaseModel): "actions": {{ "tool": "...", "input": {{ "query": "..." }}, "reason": "..." }}, "final_answer": "..." }} - -RESPONSE: -{{""" +""" def build_system_prompt( tool_descriptions: str, tool_guidance: str, + instructions: str = AGENT_THOUGHT_INSTRUCTIONS, sys_prompt: str | None = None, ) -> str: sys_prompt = sys_prompt or AGENT_SYS_PROMPT return LANGGRAPH_SYSTEM_PROMPT_TEMPLATE.format( sys_prompt=sys_prompt, tools=tool_descriptions, + tool_instructions=instructions, tool_selection_strategy=tool_guidance, ) ### --- End of REACT Prompt Templates ----# + + From 0ff74d3abfef8b4ade4c5b76f17641e70f7deabc Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Tue, 24 Feb 2026 19:31:57 +0200 Subject: [PATCH 08/60] debug tools check agent instances --- src/vuln_analysis/functions/cve_agent.py | 50 ++++++++++++++++++------ 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index bf175e86..82dd113f 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -38,7 +38,8 @@ from 
vuln_analysis.functions.react_internals import build_system_prompt,AgentState,Thought from langgraph.graph import StateGraph,END,START from langgraph.prebuilt import ToolNode - +from langchain_core.messages import SystemMessage, HumanMessage, AIMessage +import uuid logger = LoggingFactory.get_agent_logger(__name__) class CVEAgentExecutorToolConfig(FunctionBaseConfig, name="cve_agent_executor"): @@ -103,20 +104,38 @@ async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Build tool_guidance = "\n".join(tool_descriptions) descriptions = "\n".join(tool_descriptions) system_prompt = build_system_prompt(descriptions, tool_guidance) - from langchain_core.messages import SystemMessage, AIMessage + #from langchain_core.messages import SystemMessage, AIMessage tool_node = ToolNode(tools) TOOL_NODE = "tool_node" THOUGHT_NODE = "thought_node" - def thought_node(state: AgentState) -> AgentState: + async def thought_node(state: AgentState) -> AgentState: messages = [SystemMessage(content=system_prompt)] + state["messages"] # 2. Invoke LLM with structured output # This returns a Thought object based on your pydantic class response: Thought = thought_llm.invoke(messages) - # 3. 
Create an AIMessage to keep the history in state["messages"] - # We store the 'thought' text as the content for the graph's history - ai_message = AIMessage(content=response.thought) + #check if the response is a final answer + if response.mode == "finish": + ai_message = AIMessage(content=response.final_answer) + else: + #get the tool name from the actions + tool_name = response.actions.tool + # We convert the Pydantic object to a dict, explicitly excluding the type tag + # This leaves ONLY the arguments (e.g., {'query': '...'} or {'a': 1, 'b': 2}) + arguments = {"query": response.actions.tool_input} + + # Construct the message for ToolNode + tool_call_id = str(uuid.uuid4()) + ai_message = AIMessage( + content=response.thought, + tool_calls=[{ + "name": tool_name, + "args": arguments, + "id": tool_call_id + }] + ) + # 4. Return the updates # LangGraph merges this dict into the existing state return { @@ -126,14 +145,14 @@ def thought_node(state: AgentState) -> AgentState: "max_steps": 6, "observation": None } - def should_continue(state: AgentState) -> str: + async def should_continue(state: AgentState) -> str: if state.get("thought", None).mode == "finish": return END if state.get("step", 0) >= state.get("max_steps", 6): return END return TOOL_NODE - def create_graph(): + async def create_graph(): flow = StateGraph(AgentState) flow.add_node(THOUGHT_NODE, thought_node) flow.add_node(TOOL_NODE, tool_node) @@ -143,7 +162,7 @@ def create_graph(): app = flow.compile() app.get_graph().draw_mermaid_png(output_file_path="flow.png") return app - return create_graph() + return await create_graph() async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder, state: AgentMorpheusEngineState) -> AgentExecutor: @@ -187,11 +206,20 @@ async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder, async def _process_steps(agent, steps, semaphore): async def _process_step(step): + initial_state = {"input": step} + if not isinstance(agent, 
AgentExecutor): + initial_state = { + "messages": [HumanMessage(content=step)], + "step": 0, + "max_steps": 6, + "thought": None, + "observation": None, + } if semaphore: async with semaphore: - return await agent.ainvoke({"input": step}) + return await agent.ainvoke(initial_state) else: - return await agent.ainvoke({"input": step}) + return await agent.ainvoke(initial_state) return await asyncio.gather(*(_process_step(step) for step in steps), return_exceptions=True) From 71fdd45f32444e76de5160c735c62777e21f1ff4 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 25 Feb 2026 11:40:18 +0200 Subject: [PATCH 09/60] llm adds more detail params for functions, agent state with input and output params --- src/vuln_analysis/functions/cve_agent.py | 32 +++- .../functions/cve_react_graph.py | 145 ------------------ .../functions/react_internals.py | 60 +++++++- 3 files changed, 79 insertions(+), 158 deletions(-) delete mode 100644 src/vuln_analysis/functions/cve_react_graph.py diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 82dd113f..75c85d3e 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -35,7 +35,7 @@ from vuln_analysis.utils.prompting import get_agent_prompt from exploit_iq_commons.logging.loggers_factory import LoggingFactory, trace_id -from vuln_analysis.functions.react_internals import build_system_prompt,AgentState,Thought +from vuln_analysis.functions.react_internals import build_system_prompt,AgentState,Thought, _build_tool_arguments, FORCED_FINISH_PROMPT from langgraph.graph import StateGraph,END,START from langgraph.prebuilt import ToolNode from langchain_core.messages import SystemMessage, HumanMessage, AIMessage @@ -108,6 +108,7 @@ async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Build tool_node = ToolNode(tools) TOOL_NODE = "tool_node" THOUGHT_NODE = "thought_node" + FORCED_FINISH_NODE = "forced_finish_node" async def 
thought_node(state: AgentState) -> AgentState: messages = [SystemMessage(content=system_prompt)] + state["messages"] # 2. Invoke LLM with structured output @@ -116,14 +117,16 @@ async def thought_node(state: AgentState) -> AgentState: #check if the response is a final answer + final_answer = "waiting for the agent to respond" if response.mode == "finish": ai_message = AIMessage(content=response.final_answer) + final_answer = response.final_answer else: #get the tool name from the actions tool_name = response.actions.tool # We convert the Pydantic object to a dict, explicitly excluding the type tag # This leaves ONLY the arguments (e.g., {'query': '...'} or {'a': 1, 'b': 2}) - arguments = {"query": response.actions.tool_input} + arguments = _build_tool_arguments(response.actions) # Construct the message for ToolNode tool_call_id = str(uuid.uuid4()) @@ -143,22 +146,39 @@ async def thought_node(state: AgentState) -> AgentState: "thought": response, "step": state.get("step", 0) + 1, # Increment the step counter here "max_steps": 6, - "observation": None + "observation": None, + "output":final_answer } + async def should_continue(state: AgentState) -> str: if state.get("thought", None).mode == "finish": return END if state.get("step", 0) >= state.get("max_steps", 6): - return END + return FORCED_FINISH_NODE return TOOL_NODE + async def forced_finish_node(state: AgentState) -> AgentState: + messages = [SystemMessage(content=system_prompt)] + state["messages"] + messages.append(HumanMessage(content=FORCED_FINISH_PROMPT)) + response: Thought = thought_llm.invoke(messages) + final_answer = "waiting for the agent to respond" + if response.mode == "finish": + ai_message = AIMessage(content=response.final_answer) + final_answer = response.final_answer + else: + ai_message = AIMessage(content="Failed to generate a final answer within the maximum allowed steps.") + return {"messages": [ai_message],"thought": response,"step": state.get("step", 
0),"max_steps":state.get("max_steps", 6),"observation": state.get("observation", None),"output":final_answer} + + async def create_graph(): flow = StateGraph(AgentState) flow.add_node(THOUGHT_NODE, thought_node) flow.add_node(TOOL_NODE, tool_node) + flow.add_node(FORCED_FINISH_NODE, forced_finish_node) flow.add_edge(START, THOUGHT_NODE) - flow.add_conditional_edges(THOUGHT_NODE ,should_continue,{END: END,TOOL_NODE: TOOL_NODE}) + flow.add_conditional_edges(THOUGHT_NODE ,should_continue,{END: END,TOOL_NODE: TOOL_NODE,FORCED_FINISH_NODE: FORCED_FINISH_NODE}) flow.add_edge(TOOL_NODE, THOUGHT_NODE) + flow.add_edge(FORCED_FINISH_NODE, END) app = flow.compile() app.get_graph().draw_mermaid_png(output_file_path="flow.png") return app @@ -209,11 +229,13 @@ async def _process_step(step): initial_state = {"input": step} if not isinstance(agent, AgentExecutor): initial_state = { + "input": step, "messages": [HumanMessage(content=step)], "step": 0, "max_steps": 6, "thought": None, "observation": None, + "output":"waiting for the agent to respond" } if semaphore: async with semaphore: diff --git a/src/vuln_analysis/functions/cve_react_graph.py b/src/vuln_analysis/functions/cve_react_graph.py deleted file mode 100644 index 8cc5d5d5..00000000 --- a/src/vuln_analysis/functions/cve_react_graph.py +++ /dev/null @@ -1,145 +0,0 @@ -import os -import uuid -from langgraph.graph import MessagesState, StateGraph,END,START -from langgraph.prebuilt import ToolNode, tools_condition -from langgraph.errors import GraphRecursionError -from langchain_openai import ChatOpenAI -from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage, BaseMessage -from pydantic import BaseModel, Field -from langchain_core.tools import StructuredTool -from typing import Literal -from vuln_analysis.functions.react_internals import build_system_prompt,AgentState,Thought,Observation - -ACT = "act" -LAST = -1 -os.environ["OPENAI_API_KEY"] = "EMPTY" -base_url = os.environ["NVIDIA_API_BASE"] 
-model_name = os.environ["CVE_AGENT_EXECUTOR_MODEL_NAME"] -llm_base = ChatOpenAI(base_url=base_url, model=model_name ,temperature=0.0, max_completion_tokens=2000, top_p=0.01) - - - - - -def _create_cve_react_graph(): - flow = StateGraph(AgentState) - flow.add_node(AGENT_REASON, thought_node) - flow.add_node(ACT, tool_node) - flow.add_edge(START, AGENT_REASON) - flow.add_conditional_edges(AGENT_REASON,should_continue,{END: END,ACT: ACT}) - flow.add_edge(ACT, AGENT_REASON) - app = flow.compile() - app.get_graph().draw_mermaid_png(output_file_path="flow.png") - return app - -if __name__ == "__main__": - print("Hello react graph, World!") - try: - #app =_create_cve_react_graph() - #result = app.invoke({"messages": [HumanMessage(content="Does the application properly handle and validate input data to prevent malicious SQL injection, especially when using the affected PostgreSQL libpq functions?")]}) - #print(result) - TOOL_GUIDANCE = """Code Keyword Search: Exact text matching for function names, class names, or imports\nCall Chain Analyzer: Checks if functions are reachable from application code\nFunction Caller Finder: Finds which functions call specific library functions\nUse 'Function Caller Finder' + 'Call Chain Analyzer' together to trace function reachability\nCVE Web Search: External vulnerability information lookup""" - TOOL_DESCRIPTIONS = """\nCode Keyword Search(*args, **kwargs) - Performs keyword search on container source code for exact text matches. Input should be a function name, class name, or code pattern. 
Use this first before semantic search tools for precise lookups.\nCVE Web Search(*args, **kwargs) - Searches the web for information about CVEs, vulnerabilities, libraries, and security advisories not available in the container.\nCall Chain Analyzer(*args, **kwargs) - Checks if a function from a package is reachable from application code through the call chain.\n Make sure the input format is matching exactly one of the following formats:\n \n Input format 1: 'package_name,function_name'.\n Example 1: 'urllib,parse'.\n \n Input format 2(java): 'maven_gav,class_name.function_name'.\n Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.\n \n Returns: (is_reachable: bool, call_hierarchy_path: list).\nFunction Caller Finder(*args, **kwargs) - Finds functions in a package that call a specific library function. GO ecosystem only.\n Input format: 'package_name,library.function(args_with_literals)'. \n Example: 'github.com/namespace/package_name,errors.New(\"text_literal'\")'.\n Returns: ['package,caller1', 'package,caller2'] or [].\nFunction Locator(*args, **kwargs) - Mandatory first step for code path analysis. 
Validates package names, locates functions using fuzzy matching, and provides ecosystem type (GO/Python/Java/JavaScript/C/C++).\n Make sure the input format is matching exactly one of the following formats:\n \n Input format 1: 'package_name,function_name' or 'package_name,class_name.method_name'.\n Example 1: 'libxml2,xmlParseDocument'.\n \n Input format 2(java): 'maven_gav,class_name.method_name'.\n Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.\n \n Returns: {'ecosystem': str, 'package_msg': str, 'result': [function_names]}.\n""" - SYSTEM_MESSAGE = build_system_prompt(TOOL_DESCRIPTIONS, TOOL_GUIDANCE) - llm_test = llm_base.with_structured_output(Thought) - TEST_QUESTION = "Does the application properly handle and validate input data to prevent malicious SQL injection, especially when using the affected PostgreSQL libpq functions?" # Shorter for testing - response = llm_test.invoke([{"role": "system", "content": SYSTEM_MESSAGE}, {"role": "user", "content": TEST_QUESTION}]) - print(response) - - except GraphRecursionError as e: - print(f"GraphRecursionError: {e}") - print("GraphRecursionError: Stopping forcefully.") - -class FinalAnswer(BaseModel): - """Use this when you are done and have the answer.""" - action_type: Literal["final_answer"] = "final_answer" - response: str = Field(description="The final response to the user") - -class TripleAction(BaseModel): - action_type: Literal["triple"] = "triple" - num: float = Field(description="The number to triple") - -action_models = [TripleAction, FinalAnswer] - -# The union type forces the LLM to pick exactly one of these valid structures -class AgentResponse(BaseModel): - command: TripleAction | FinalAnswer - -tool_descriptions = "\n".join(["triple: param num: a number to triple \nreturns: the triple of the input number\n"]) - -SYSTEM_MESSAGE = f"""You are a reasoning agent. -You have access to: {tool_descriptions} - -Your logic flow must be: -1. 
**REASON**: What has been asked? What has been done? What is left to do? -2. **ACT**: Call a tool if more info or a calculation is needed. -3. **VERIFY**: (After a tool returns) Did this tool result satisfy the ENTIRETY of the user's request? - -CRITICAL: If the user asked for a multi-step task (e.g., "Find X AND then do Y"), do not trigger 'FinalAnswer' after finding X. You must loop back to perform Y. -""" - -AGENT_REASON ="agent_reason" - -def triple(num: float) -> float: - """ - :param num: a number to triple - :return: the number tripled -> multiplied by 3 - """ - return 3 * float(num) - -tools = [StructuredTool.from_function(triple)] -#tools = [triple] -tool_node = ToolNode(tools) - -llm = llm_base.with_structured_output(AgentResponse) - -def agent_reasoning_node(state: MessagesState) -> MessagesState: - response = llm.invoke( - [{"role": "system", "content": SYSTEM_MESSAGE}, *state["messages"]] - ) - command = response.command # Extract the specific action object - print(f'agent_reasoning_node:command: {command}') - # --- ROUTING LOGIC --- - - # Case 1: Final Answer - if isinstance(command, FinalAnswer): - return { - "messages": [AIMessage(content=command.response)] - } - - # Case 2: Tool Call (Search or Calculate) - else: - # We assume the 'action_type' matches the tool name exactly - tool_name = command.action_type - - # We convert the Pydantic object to a dict, explicitly excluding the type tag - # This leaves ONLY the arguments (e.g., {'query': '...'} or {'a': 1, 'b': 2}) - arguments = command.model_dump(exclude={"action_type"}) - - # Construct the message for ToolNode - tool_call_id = str(uuid.uuid4()) - msg = AIMessage( - content="", - tool_calls=[{ - "name": tool_name, - "args": arguments, # This is now GUARANTEED to have the right keys - "id": tool_call_id - }] - ) - return {"messages": [msg]} - -def should_continue(state: MessagesState)->str: - if not state["messages"][LAST].tool_calls: - return END - return ACT - -def 
_create_cve_react_graph_debug(): - flow = StateGraph(MessagesState) - flow.add_node(AGENT_REASON, agent_reasoning_node) - flow.add_node(ACT, tool_node) - flow.add_edge(START, AGENT_REASON) - flow.add_conditional_edges(AGENT_REASON,should_continue,{END: END,ACT: ACT}) - flow.add_edge(ACT, AGENT_REASON) - app = flow.compile() - app.get_graph().draw_mermaid_png(output_file_path="flow.png") - return app \ No newline at end of file diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index e79fbedf..5f1764e7 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -14,10 +14,25 @@ class ToolCall(BaseModel): tool: str = Field(description="Exact tool name from AVAILABLE_TOOLS") - tool_input: str = Field(description="The input for the tool. Example: Code Keyword Search: PQescapeLiteral") - #input: dict[str, Any] = Field( - # description="MUST be a non-empty dict with keys matching the tool's requirements (e.g., {'query': 'val'} or {'package_name': 'pkg', 'function_name': 'fn'})." - #) + #tool_input: str = Field(description="The input for the tool. Example: Code Keyword Search: PQescapeLiteral") + package_name: str | None = Field( + default=None, + description="Package/module name. REQUIRED when using Function Locator, Function Caller Finder, or Call Chain Analyzer. E.g. libpq, urllib, github.com/org/pkg" + ) + function_name: str | None = Field( + default=None, + description="Function or method name with optional args. REQUIRED with package_name for code path tools. E.g. PQescapeLiteral(), parse(), errors.New(\"x\")" + ) + # For search tools (Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search) + query: str | None = Field( + default=None, + description="Search query. 
Use for search tools when package_name/function_name don't apply" + ) + # Fallback: if LLM uses tool_input for simple query-only tools + tool_input: str | None = Field( + default=None, + description="Legacy/fallback input. Prefer package_name+function_name or query." + ) reason: str = Field(description="Briefly explain why this specific tool/input helps the investigation") class Thought(BaseModel): @@ -41,12 +56,13 @@ class Observation(BaseModel): memory: str = Field(description="Compressed working memory summary") class AgentState(MessagesState): - #goal: str + input: str = "" step: int = 0 max_steps: int = 6 #memory: str | None = None thought: Thought | None = None observation: Observation | None = None + output: str = "" ### --- End of REACT Schemas ----# #---- REACT Prompt Templates ----# @@ -80,18 +96,36 @@ class AgentState(MessagesState): 1. Output MUST be valid JSON. 2. 'thought' < 100 words. 'final_answer' < 150 words. 3. If mode="act", 'actions' is REQUIRED. If mode="finish", 'final_answer' is REQUIRED. -4. - input:MUST be a non-empty string with the format of the tool's requirements. +4. For 'Function Locator', 'Function Caller Finder', 'Call Chain Analyzer': + - MUST set package_name AND function_name (e.g. "libpq", "PQescapeLiteral"). + - Do NOT use query for these tools. +5. For 'Code Keyword Search', 'Code Semantic Search', 'Docs Semantic Search', 'CVE Web Search': + - Use query with a single search term or phrase. Return your response in this JSON format: {{ "thought": "...", "mode": "act" | "finish", - "actions": {{ "tool": "...", "input": {{ "query": "..." }}, "reason": "..." }}, + "actions": {{ "tool": "...", + "package_name": "..." | null, + "function_name": "..." | null, + "query": "..." | null, + ,"reason": "..." }}, "final_answer": "..." }} """ +FORCED_FINISH_PROMPT = """ + +You are at or past your maximum allowed steps. +You MUST use mode="finish" and provide a final_answer NOW. +Do NOT call any more tools. 
Summarize the evidence you have gathered so far into a concise +final_answer (3-5 sentences) that directly addresses the investigation question. + +""" + +### --- End of REACT Prompt Templates ----# def build_system_prompt( tool_descriptions: str, tool_guidance: str, @@ -105,6 +139,16 @@ def build_system_prompt( tool_instructions=instructions, tool_selection_strategy=tool_guidance, ) -### --- End of REACT Prompt Templates ----# + +def _build_tool_arguments(actions: ToolCall)->dict[str, Any]: + pkg_tools = {"Function Locator", "Function Caller Finder", "Call Chain Analyzer"} + if actions.tool in pkg_tools and actions.package_name and actions.function_name: + return {"query": f"{actions.package_name},{actions.function_name}"} + if actions.query: + return {"query": actions.query} + if actions.tool_input: + return {"query": actions.tool_input} # fallback + raise ValueError(f"Tool {actions.tool} requires package_name+function_name or query/tool_input") + From 69a0b313327532dfd69cbc7647d2c98dd39299b0 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 25 Feb 2026 13:25:06 +0200 Subject: [PATCH 10/60] solve the verification characters issue --- src/vuln_analysis/functions/react_internals.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 5f1764e7..6f14daf7 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -38,7 +38,7 @@ class ToolCall(BaseModel): class Thought(BaseModel): thought: str = Field( description="Brief reasoning about next step (max 3-4 sentences)", - max_length=300, + max_length=3000, ) mode: Literal["act", "finish"] = Field(description="'act' to call tools (when more information is needed), 'finish' to return the final answer (when you have sufficient evidence)") @@ -47,7 +47,7 @@ class Thought(BaseModel): final_answer: str | None = Field( default=None, 
description="When mode is 'finish', concise answer (3-5 sentences) with key evidence", - max_length=500, + max_length=3000, ) class Observation(BaseModel): @@ -115,13 +115,18 @@ class AgentState(MessagesState): "final_answer": "..." }} """ -FORCED_FINISH_PROMPT = """ +FORCED_FINISH_PROMPT = f""" You are at or past your maximum allowed steps. You MUST use mode="finish" and provide a final_answer NOW. Do NOT call any more tools. Summarize the evidence you have gathered so far into a concise final_answer (3-5 sentences) that directly addresses the investigation question. + 'final_answer' < 300 words. + 'thought' < 300 words. + +RESPONSE: +{{ """ From 37886af82cdb980df2d6bd9a713cca01a8590539 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 25 Feb 2026 13:53:04 +0200 Subject: [PATCH 11/60] set the max_steps to max_iterations in config --- src/vuln_analysis/functions/cve_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 75c85d3e..cde51de3 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -145,7 +145,7 @@ async def thought_node(state: AgentState) -> AgentState: "messages": [ai_message], "thought": response, "step": state.get("step", 0) + 1, # Increment the step counter here - "max_steps": 6, + "max_steps": config.max_iterations, "observation": None, "output":final_answer } @@ -153,7 +153,7 @@ async def thought_node(state: AgentState) -> AgentState: async def should_continue(state: AgentState) -> str: if state.get("thought", None).mode == "finish": return END - if state.get("step", 0) >= state.get("max_steps", 6): + if state.get("step", 0) >= state.get("max_steps", config.max_iterations): return FORCED_FINISH_NODE return TOOL_NODE @@ -167,7 +167,7 @@ async def forced_finish_node(state: AgentState) -> AgentState: final_answer = response.final_answer else: ai_message = AIMessage(content="Failed to 
generate a final answer within the maximum allowed steps.") - return {"messages": [ai_message],"thought": response,"step": state.get("step", 0),"max_steps":state.get("max_steps", 6),"observation": state.get("observation", None),"output":final_answer} + return {"messages": [ai_message],"thought": response,"step": state.get("step", 0),"max_steps":state.get("max_steps", config.max_iterations),"observation": state.get("observation", None),"output":final_answer} async def create_graph(): From 30fe5d8445290141069416c2ed11a44c9f268dec Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 25 Feb 2026 15:05:37 +0200 Subject: [PATCH 12/60] add preprocess node --- src/vuln_analysis/functions/cve_agent.py | 16 +++++++++++++++- src/vuln_analysis/functions/react_internals.py | 1 + 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index cde51de3..1b71083e 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -109,6 +109,18 @@ async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Build TOOL_NODE = "tool_node" THOUGHT_NODE = "thought_node" FORCED_FINISH_NODE = "forced_finish_node" + PRE_PROCESS_NODE = "pre_process_node" + + async def pre_process_node(state: AgentState) -> AgentState: + workflow_state = ctx_state.get() + ecosystem = workflow_state.original_input.input.image.ecosystem.value + for cve_intel in workflow_state.cve_intel: + if cve_intel.rhsa is not None: + print(cve_intel.rhsa.upstream_fix) + print(cve_intel.rhsa.statement) + + return {"ecosystem": ecosystem} + async def thought_node(state: AgentState) -> AgentState: messages = [SystemMessage(content=system_prompt)] + state["messages"] # 2. 
Invoke LLM with structured output @@ -175,7 +187,9 @@ async def create_graph(): flow.add_node(THOUGHT_NODE, thought_node) flow.add_node(TOOL_NODE, tool_node) flow.add_node(FORCED_FINISH_NODE, forced_finish_node) - flow.add_edge(START, THOUGHT_NODE) + flow.add_node(PRE_PROCESS_NODE, pre_process_node) + flow.add_edge(START, PRE_PROCESS_NODE) + flow.add_edge(PRE_PROCESS_NODE, THOUGHT_NODE) flow.add_conditional_edges(THOUGHT_NODE ,should_continue,{END: END,TOOL_NODE: TOOL_NODE,FORCED_FINISH_NODE: FORCED_FINISH_NODE}) flow.add_edge(TOOL_NODE, THOUGHT_NODE) flow.add_edge(FORCED_FINISH_NODE, END) diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 6f14daf7..2ea12d2d 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -63,6 +63,7 @@ class AgentState(MessagesState): thought: Thought | None = None observation: Observation | None = None output: str = "" + ecosystem: str | None = None ### --- End of REACT Schemas ----# #---- REACT Prompt Templates ----# From 851a93fd52457282503b1b4c95e7589d265352cc Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 25 Feb 2026 15:39:28 +0200 Subject: [PATCH 13/60] save last changes --- src/vuln_analysis/functions/cve_agent.py | 33 +++++++++++++++++------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 1b71083e..008a4a26 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -41,6 +41,7 @@ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage import uuid logger = LoggingFactory.get_agent_logger(__name__) +from vuln_analysis.utils.prompting import build_tool_descriptions class CVEAgentExecutorToolConfig(FunctionBaseConfig, name="cve_agent_executor"): """ @@ -71,8 +72,8 @@ class CVEAgentExecutorToolConfig(FunctionBaseConfig, 
name="cve_agent_executor"): description="Whether to enable CVE Web Search tool or not.") verbose: bool = Field(default=False, description="Set to true for verbose output") -async def common_build_tools(config: CVEAgentExecutorToolConfig, builder: Builder,state: AgentMorpheusEngineState) -> tuple[list[typing.Any],list[str]]: - from vuln_analysis.utils.prompting import build_tool_descriptions +async def common_build_tools(config: CVEAgentExecutorToolConfig, builder: Builder,state: AgentMorpheusEngineState) -> tuple[list[typing.Any],list[str],list[str]]: + tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN) # Filter tools that are not available based on state tools = [ @@ -91,18 +92,18 @@ async def common_build_tools(config: CVEAgentExecutorToolConfig, builder: Builde ] # Get tool names after filtering for dynamic guidance enabled_tool_names = [tool.name for tool in tools] - + tool_descriptions_list = [t.name + ": " + t.description for t in tools] # Build tool selection guidance with strategic context tool_descriptions = build_tool_descriptions(enabled_tool_names) - return tools, tool_descriptions + return tools, tool_descriptions, tool_descriptions_list async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Builder,state: AgentMorpheusEngineState): - tools, tool_descriptions = await common_build_tools(config, builder, state) + tools, tool_guidance_list, tool_descriptions_list = await common_build_tools(config, builder, state) llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) thought_llm = llm.with_structured_output(Thought) - tool_guidance = "\n".join(tool_descriptions) - descriptions = "\n".join(tool_descriptions) + tool_guidance = "\n".join(tool_guidance_list) + descriptions = "\n".join(tool_descriptions_list) system_prompt = build_system_prompt(descriptions, tool_guidance) #from langchain_core.messages import SystemMessage, AIMessage tool_node = 
ToolNode(tools) @@ -111,6 +112,18 @@ async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Build FORCED_FINISH_NODE = "forced_finish_node" PRE_PROCESS_NODE = "pre_process_node" + def _filter_function_caller_finder(): + filtered_tools = [ + t for t in tools + if t.name != ToolNames.FUNCTION_CALLER_FINDER + ] + list_of_tool_names = [t.name for t in filtered_tools] + list_of_tool_descriptions = [t.name + ": " + t.description for t in filtered_tools] + tool_guidance_list_local = build_tool_descriptions(list_of_tool_names) + tool_guidance_local = "\n".join(tool_guidance_list_local) + descriptions_local = "\n".join(list_of_tool_descriptions) + return tool_guidance_local, descriptions_local + async def pre_process_node(state: AgentState) -> AgentState: workflow_state = ctx_state.get() ecosystem = workflow_state.original_input.input.image.ecosystem.value @@ -118,7 +131,9 @@ async def pre_process_node(state: AgentState) -> AgentState: if cve_intel.rhsa is not None: print(cve_intel.rhsa.upstream_fix) print(cve_intel.rhsa.statement) - + tool_guidance_local, descriptions_local = _filter_function_caller_finder() + runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) + print(runtime_prompt) return {"ecosystem": ecosystem} async def thought_node(state: AgentState) -> AgentState: @@ -200,7 +215,7 @@ async def create_graph(): async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder, state: AgentMorpheusEngineState) -> AgentExecutor: - tools, tool_descriptions = await common_build_tools(config, builder, state) + tools, tool_descriptions,_ = await common_build_tools(config, builder, state) llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) tool_guidance = "\n".join(tool_descriptions) From d00670e1ac408122f43270785da2f643f3d2c743 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 25 Feb 2026 14:33:30 +0000 Subject: [PATCH 14/60] last changes 3 --- 
src/vuln_analysis/functions/cve_agent.py | 4 ++-- src/vuln_analysis/functions/react_internals.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 008a4a26..669b02d2 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -133,8 +133,8 @@ async def pre_process_node(state: AgentState) -> AgentState: print(cve_intel.rhsa.statement) tool_guidance_local, descriptions_local = _filter_function_caller_finder() runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) - print(runtime_prompt) - return {"ecosystem": ecosystem} + #print(runtime_prompt) + return {"ecosystem": ecosystem,"runtime_prompt": runtime_prompt} async def thought_node(state: AgentState) -> AgentState: messages = [SystemMessage(content=system_prompt)] + state["messages"] diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 2ea12d2d..a15029d0 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -64,6 +64,7 @@ class AgentState(MessagesState): observation: Observation | None = None output: str = "" ecosystem: str | None = None + runtime_prompt: str | None = None ### --- End of REACT Schemas ----# #---- REACT Prompt Templates ----# From 69b1f479dbd4345f8d6229129faad8b0710333a7 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Thu, 26 Feb 2026 07:44:14 +0000 Subject: [PATCH 15/60] claude code review --- src/vuln_analysis/functions/cve_agent.py | 110 +++++++++++------- .../functions/react_internals.py | 14 +-- 2 files changed, 68 insertions(+), 56 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 669b02d2..b05ca1c8 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -35,13 +35,14 @@ from 
vuln_analysis.utils.prompting import get_agent_prompt from exploit_iq_commons.logging.loggers_factory import LoggingFactory, trace_id -from vuln_analysis.functions.react_internals import build_system_prompt,AgentState,Thought, _build_tool_arguments, FORCED_FINISH_PROMPT -from langgraph.graph import StateGraph,END,START +from vuln_analysis.functions.react_internals import build_system_prompt, AgentState, Thought, _build_tool_arguments, FORCED_FINISH_PROMPT +from vuln_analysis.utils.prompting import build_tool_descriptions +from langgraph.graph import StateGraph, END, START from langgraph.prebuilt import ToolNode from langchain_core.messages import SystemMessage, HumanMessage, AIMessage import uuid + logger = LoggingFactory.get_agent_logger(__name__) -from vuln_analysis.utils.prompting import build_tool_descriptions class CVEAgentExecutorToolConfig(FunctionBaseConfig, name="cve_agent_executor"): """ @@ -72,7 +73,7 @@ class CVEAgentExecutorToolConfig(FunctionBaseConfig, name="cve_agent_executor"): description="Whether to enable CVE Web Search tool or not.") verbose: bool = Field(default=False, description="Set to true for verbose output") -async def common_build_tools(config: CVEAgentExecutorToolConfig, builder: Builder,state: AgentMorpheusEngineState) -> tuple[list[typing.Any],list[str],list[str]]: +async def common_build_tools(config: CVEAgentExecutorToolConfig, builder: Builder, state: AgentMorpheusEngineState) -> tuple[list[typing.Any], list[str], list[str]]: tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN) # Filter tools that are not available based on state @@ -97,16 +98,15 @@ async def common_build_tools(config: CVEAgentExecutorToolConfig, builder: Builde tool_descriptions = build_tool_descriptions(enabled_tool_names) return tools, tool_descriptions, tool_descriptions_list -async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Builder,state: AgentMorpheusEngineState): +async def 
_create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Builder, state: AgentMorpheusEngineState): tools, tool_guidance_list, tool_descriptions_list = await common_build_tools(config, builder, state) llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) thought_llm = llm.with_structured_output(Thought) tool_guidance = "\n".join(tool_guidance_list) descriptions = "\n".join(tool_descriptions_list) - system_prompt = build_system_prompt(descriptions, tool_guidance) - #from langchain_core.messages import SystemMessage, AIMessage - tool_node = ToolNode(tools) + default_system_prompt = build_system_prompt(descriptions, tool_guidance) + tool_node = ToolNode(tools, handle_tool_errors=True) TOOL_NODE = "tool_node" THOUGHT_NODE = "thought_node" FORCED_FINISH_NODE = "forced_finish_node" @@ -123,78 +123,93 @@ def _filter_function_caller_finder(): tool_guidance_local = "\n".join(tool_guidance_list_local) descriptions_local = "\n".join(list_of_tool_descriptions) return tool_guidance_local, descriptions_local - + async def pre_process_node(state: AgentState) -> AgentState: workflow_state = ctx_state.get() ecosystem = workflow_state.original_input.input.image.ecosystem.value for cve_intel in workflow_state.cve_intel: if cve_intel.rhsa is not None: - print(cve_intel.rhsa.upstream_fix) - print(cve_intel.rhsa.statement) - tool_guidance_local, descriptions_local = _filter_function_caller_finder() + logger.debug("RHSA upstream_fix: %s", cve_intel.rhsa.upstream_fix) + logger.debug("RHSA statement: %s", cve_intel.rhsa.statement) + tool_guidance_local, descriptions_local = _filter_function_caller_finder() runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) - #print(runtime_prompt) - return {"ecosystem": ecosystem,"runtime_prompt": runtime_prompt} + return {"ecosystem": ecosystem, "runtime_prompt": runtime_prompt} async def thought_node(state: AgentState) -> AgentState: - messages = 
[SystemMessage(content=system_prompt)] + state["messages"] - # 2. Invoke LLM with structured output - # This returns a Thought object based on your pydantic class - response: Thought = thought_llm.invoke(messages) - - - #check if the response is a final answer + # Use runtime_prompt from pre_process if available, otherwise fall back to default + active_prompt = state.get("runtime_prompt") or default_system_prompt + messages = [SystemMessage(content=active_prompt)] + state["messages"] + response: Thought = await thought_llm.ainvoke(messages) + final_answer = "waiting for the agent to respond" if response.mode == "finish": ai_message = AIMessage(content=response.final_answer) final_answer = response.final_answer + elif response.actions is None: + logger.warning("LLM returned mode='act' but actions is None, forcing finish") + ai_message = AIMessage(content=response.thought or "No actions provided, finishing.") + response = Thought( + thought=response.thought or "No actions provided", + mode="finish", + actions=None, + final_answer=response.thought or "Insufficient evidence to provide a definitive answer." + ) + final_answer = response.final_answer else: - #get the tool name from the actions tool_name = response.actions.tool - # We convert the Pydantic object to a dict, explicitly excluding the type tag - # This leaves ONLY the arguments (e.g., {'query': '...'} or {'a': 1, 'b': 2}) arguments = _build_tool_arguments(response.actions) - - # Construct the message for ToolNode tool_call_id = str(uuid.uuid4()) ai_message = AIMessage( - content=response.thought, + content=response.thought, tool_calls=[{ "name": tool_name, - "args": arguments, + "args": arguments, "id": tool_call_id }] ) - # 4. 
Return the updates - # LangGraph merges this dict into the existing state return { "messages": [ai_message], "thought": response, - "step": state.get("step", 0) + 1, # Increment the step counter here + "step": state.get("step", 0) + 1, "max_steps": config.max_iterations, "observation": None, - "output":final_answer + "output": final_answer } async def should_continue(state: AgentState) -> str: - if state.get("thought", None).mode == "finish": + thought = state.get("thought", None) + if thought is not None and thought.mode == "finish": return END if state.get("step", 0) >= state.get("max_steps", config.max_iterations): return FORCED_FINISH_NODE return TOOL_NODE async def forced_finish_node(state: AgentState) -> AgentState: - messages = [SystemMessage(content=system_prompt)] + state["messages"] + active_prompt = state.get("runtime_prompt") or default_system_prompt + messages = [SystemMessage(content=active_prompt)] + state["messages"] messages.append(HumanMessage(content=FORCED_FINISH_PROMPT)) - response: Thought = thought_llm.invoke(messages) - final_answer = "waiting for the agent to respond" - if response.mode == "finish": + response: Thought = await thought_llm.ainvoke(messages) + if response.mode == "finish" and response.final_answer: ai_message = AIMessage(content=response.final_answer) final_answer = response.final_answer else: - ai_message = AIMessage(content="Failed to generate a final answer within the maximum allowed steps.") - return {"messages": [ai_message],"thought": response,"step": state.get("step", 0),"max_steps":state.get("max_steps", config.max_iterations),"observation": state.get("observation", None),"output":final_answer} + final_answer = "Failed to generate a final answer within the maximum allowed steps." 
+ ai_message = AIMessage(content=final_answer) + response = Thought( + thought=response.thought or "Max steps exceeded", + mode="finish", + actions=None, + final_answer=final_answer + ) + return { + "messages": [ai_message], + "thought": response, + "step": state.get("step", 0), + "max_steps": state.get("max_steps", config.max_iterations), + "observation": state.get("observation", None), + "output": final_answer + } async def create_graph(): @@ -205,11 +220,16 @@ async def create_graph(): flow.add_node(PRE_PROCESS_NODE, pre_process_node) flow.add_edge(START, PRE_PROCESS_NODE) flow.add_edge(PRE_PROCESS_NODE, THOUGHT_NODE) - flow.add_conditional_edges(THOUGHT_NODE ,should_continue,{END: END,TOOL_NODE: TOOL_NODE,FORCED_FINISH_NODE: FORCED_FINISH_NODE}) + flow.add_conditional_edges( + THOUGHT_NODE, + should_continue, + {END: END, TOOL_NODE: TOOL_NODE, FORCED_FINISH_NODE: FORCED_FINISH_NODE} + ) flow.add_edge(TOOL_NODE, THOUGHT_NODE) flow.add_edge(FORCED_FINISH_NODE, END) app = flow.compile() - app.get_graph().draw_mermaid_png(output_file_path="flow.png") + if config.verbose: + app.get_graph().draw_mermaid_png(output_file_path="flow.png") return app return await create_graph() async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder, @@ -253,7 +273,7 @@ async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder, return agent_executor -async def _process_steps(agent, steps, semaphore): +async def _process_steps(agent, steps, semaphore, max_iterations: int = 10): async def _process_step(step): initial_state = {"input": step} if not isinstance(agent, AgentExecutor): @@ -261,10 +281,10 @@ async def _process_step(step): "input": step, "messages": [HumanMessage(content=step)], "step": 0, - "max_steps": 6, + "max_steps": max_iterations, "thought": None, "observation": None, - "output":"waiting for the agent to respond" + "output": "waiting for the agent to respond" } if semaphore: async with semaphore: @@ -352,7 +372,7 @@ async def 
_arun(state: AgentMorpheusEngineState) -> AgentMorpheusEngineState: else: logger.info("ENABLE_GRAPH_AGENT is set to 0. Executing CVE agent.") agent = await _create_agent(config, builder, state) - results = await asyncio.gather(*(_process_steps(agent, steps, semaphore) + results = await asyncio.gather(*(_process_steps(agent, steps, semaphore, config.max_iterations) for steps in state.checklist_plans.values()), return_exceptions=True) results = _postprocess_results(results, config.replace_exceptions, config.replace_exceptions_value, list(state.checklist_plans.values())) diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index a15029d0..cbff5d08 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -1,11 +1,3 @@ -from langchain_core.messages import HumanMessage -from langchain_core.output_parsers.openai_tools import ( - JsonOutputToolsParser, - PydanticToolsParser, -) -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_openai import ChatOpenAI - from pydantic import BaseModel, Field from typing import Any from typing import Literal @@ -58,7 +50,7 @@ class Observation(BaseModel): class AgentState(MessagesState): input: str = "" step: int = 0 - max_steps: int = 6 + max_steps: int = 10 #memory: str | None = None thought: Thought | None = None observation: Observation | None = None @@ -113,11 +105,11 @@ class AgentState(MessagesState): "package_name": "..." | null, "function_name": "..." | null, "query": "..." | null, - ,"reason": "..." }}, + "reason": "..." }}, "final_answer": "..." }} """ -FORCED_FINISH_PROMPT = f""" +FORCED_FINISH_PROMPT = """ You are at or past your maximum allowed steps. You MUST use mode="finish" and provide a final_answer NOW. 
From 58943c677434ecde0f678f6d3f4bbb41aeada8a2 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Fri, 27 Feb 2026 09:00:26 +0200 Subject: [PATCH 16/60] save unfinish work --- src/vuln_analysis/functions/cve_agent.py | 106 ++++++- .../functions/react_internals.py | 8 +- src/vuln_analysis/utils/prompt_factory.py | 258 ++++++++++++++++++ 3 files changed, 366 insertions(+), 6 deletions(-) create mode 100644 src/vuln_analysis/utils/prompt_factory.py diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index b05ca1c8..58ed6e38 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -14,6 +14,7 @@ # limitations under the License. import os import asyncio +import json from vuln_analysis.runtime_context import ctx_state import typing from aiq.builder.builder import Builder @@ -35,7 +36,7 @@ from vuln_analysis.utils.prompting import get_agent_prompt from exploit_iq_commons.logging.loggers_factory import LoggingFactory, trace_id -from vuln_analysis.functions.react_internals import build_system_prompt, AgentState, Thought, _build_tool_arguments, FORCED_FINISH_PROMPT +from vuln_analysis.functions.react_internals import build_system_prompt, AgentState, Thought,Observation, _build_tool_arguments, FORCED_FINISH_PROMPT, OBSERVATION_NODE_PROMPT from vuln_analysis.utils.prompting import build_tool_descriptions from langgraph.graph import StateGraph, END, START from langgraph.prebuilt import ToolNode @@ -103,6 +104,7 @@ async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Build tools, tool_guidance_list, tool_descriptions_list = await common_build_tools(config, builder, state) llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) thought_llm = llm.with_structured_output(Thought) + observation_llm = llm.with_structured_output(Observation) tool_guidance = "\n".join(tool_guidance_list) descriptions = "\n".join(tool_descriptions_list) 
default_system_prompt = build_system_prompt(descriptions, tool_guidance) @@ -111,7 +113,7 @@ async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Build THOUGHT_NODE = "thought_node" FORCED_FINISH_NODE = "forced_finish_node" PRE_PROCESS_NODE = "pre_process_node" - + OBSERVATION_NODE = "observation_node" def _filter_function_caller_finder(): filtered_tools = [ t for t in tools @@ -131,14 +133,47 @@ async def pre_process_node(state: AgentState) -> AgentState: if cve_intel.rhsa is not None: logger.debug("RHSA upstream_fix: %s", cve_intel.rhsa.upstream_fix) logger.debug("RHSA statement: %s", cve_intel.rhsa.statement) + target_packages = list(set(p.package_name for p in cve_intel.rhsa.package_state)) # + + # Build a generic guidance string + scope_guidance = f"The security advisory mentions the following affected packages: {', '.join(target_packages)}." # + + # Add important metadata facts from the 'details' or 'statement' fields + critical_context = [ + f"CVE Statement: {cve_intel.statement[:500]}...", # Highlights 'psql' and encoding risks + scope_guidance, + "TASK: Investigate usage and reachability within ALL mentioned packages." 
+ ] + + tool_guidance_local, descriptions_local = _filter_function_caller_finder() runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) - return {"ecosystem": ecosystem, "runtime_prompt": runtime_prompt} + return {"ecosystem": ecosystem, "runtime_prompt": runtime_prompt,"observation": Observation(memory=critical_context, results=[])} async def thought_node(state: AgentState) -> AgentState: # Use runtime_prompt from pre_process if available, otherwise fall back to default active_prompt = state.get("runtime_prompt") or default_system_prompt messages = [SystemMessage(content=active_prompt)] + state["messages"] + # Extract Observation data + obs = state.get("observation",None) + if obs is not None: + memory_list = obs.memory if obs and obs.memory else ["No prior knowledge."] + recent_findings = obs.results if obs and obs.results else ["No recent findings."] + + # Format lists for the prompt + memory_context = "\n".join(f"- {m}" for m in memory_list) + findings_context = "\n".join(f"- {f}" for f in recent_findings) + + context_block = f""" + + {memory_context} + + + + {findings_context} + + """ + messages.append(AIMessage(content=context_block)) response: Thought = await thought_llm.ainvoke(messages) final_answer = "waiting for the agent to respond" @@ -211,6 +246,67 @@ async def forced_finish_node(state: AgentState) -> AgentState: "output": final_answer } + async def observation_node(state: AgentState) -> AgentState: + # Get the tool output (last message) and the agent's intent (second to last message) + tool_message = state["messages"][-1] + + # Extract context from the Thought object stored in state + last_thought_text = state["thought"].thought if state.get("thought") else "No previous thought." + tool_used = state["thought"].actions.tool if state.get("thought") and state["thought"].actions else "Unknown" + + previous_memory = state.get("observation").memory if state.get("observation") else "Initial state: No data gathered yet." 
+ + # Construct the prompt for the structured output + # We ask the LLM to fill the 'memory' and 'results' fields of the Observation model + prompt = f""" + You are a security memory module for a CVE Analyst. Your task is to update the agent's 'memory' and 'results' lists based on tool findings. + + + {state.get('input')} + + + + {previous_memory} + + + + - Last Thought: "{last_thought_text}" + - Tool Used: "{tool_used}" + + + + {tool_message.content} + + + TASK: + Produce a structured Observation object (JSON). + + INSTRUCTIONS: + 1. memory (List of Strings): + # Updated Rule 1 + - VERIFY: Only record facts explicitly present in 'NEW_TOOL_OUTPUT'. + - ABSENCE: If no application code was found calling the library function, you MUST state "No usage found in application source" and do NOT claim the application uses the function. + - APPEND: Add new technical facts (filenames, version numbers, call chains). + - EVALUATE: If the tool result is empty or False, explicitly state that the search yielded no results for that specific path. + - DEDUPLICATE: Do not repeat facts already in PREVIOUS_MEMORY_LIST. + - CONCISE: Use plain, factual sentences. No conversational filler. + + 2. results (List of Strings): + - Provide 3-5 high-density technical "bullet points" from the LATEST tool output only. + + 3. FOCUS: Keep only information that helps determine exploitability for the specific CVE. + + RESPONSE: + {{""" + + # 2. 
Invoke the structured LLM + # This will return an instance of your Observation class + new_observation: Observation = await observation_llm.ainvoke([SystemMessage(content=prompt)]) + + return { + "observation": new_observation, + "step": state.get("step", 0), # Keep tracking steps + } async def create_graph(): flow = StateGraph(AgentState) @@ -218,6 +314,7 @@ async def create_graph(): flow.add_node(TOOL_NODE, tool_node) flow.add_node(FORCED_FINISH_NODE, forced_finish_node) flow.add_node(PRE_PROCESS_NODE, pre_process_node) + flow.add_node(OBSERVATION_NODE, observation_node) flow.add_edge(START, PRE_PROCESS_NODE) flow.add_edge(PRE_PROCESS_NODE, THOUGHT_NODE) flow.add_conditional_edges( @@ -225,7 +322,8 @@ async def create_graph(): should_continue, {END: END, TOOL_NODE: TOOL_NODE, FORCED_FINISH_NODE: FORCED_FINISH_NODE} ) - flow.add_edge(TOOL_NODE, THOUGHT_NODE) + flow.add_edge(TOOL_NODE, OBSERVATION_NODE) + flow.add_edge(OBSERVATION_NODE, THOUGHT_NODE) flow.add_edge(FORCED_FINISH_NODE, END) app = flow.compile() if config.verbose: diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index cbff5d08..8ddfeb44 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -43,9 +43,11 @@ class Thought(BaseModel): ) class Observation(BaseModel): - results: list[dict[str, Any]] = Field(description="Raw tool outputs") + results: list[str] = Field(description="Bullet points of the most important new technical facts found") - memory: str = Field(description="Compressed working memory summary") + memory: list[str] = Field( + description="A list of cumulative factual findings. Each item is a single, discrete technical fact." 
+ ) class AgentState(MessagesState): input: str = "" @@ -123,6 +125,8 @@ class AgentState(MessagesState): {{ """ +OBSERVATION_NODE_PROMPT = """ +""" ### --- End of REACT Prompt Templates ----# def build_system_prompt( diff --git a/src/vuln_analysis/utils/prompt_factory.py b/src/vuln_analysis/utils/prompt_factory.py new file mode 100644 index 00000000..7825a844 --- /dev/null +++ b/src/vuln_analysis/utils/prompt_factory.py @@ -0,0 +1,258 @@ +from vuln_analysis.utils.prompting import AGENT_SYS_PROMPT +from vuln_analysis.tools.tool_names import ToolNames + +# Structure: (tool_name, language) -> language-specific input format and call instructions +TOOL_ECOSYSTEM_REGISTRY: dict[tuple[str, str], str] = { + # Code Semantic Search + (ToolNames.CODE_SEMANTIC_SEARCH, "python"): ( + "Input: natural language query. Use Python identifiers: e.g. 'urllib.parse url handling', " + "'PIL.Image image processing', 'asyncio usage'." + ), + (ToolNames.CODE_SEMANTIC_SEARCH, "go"): ( + "Input: natural language query. Use Go identifiers: e.g. 'encoding/json unmarshaling', " + "'net/http handler', 'context.Context propagation'." + ), + (ToolNames.CODE_SEMANTIC_SEARCH, "java"): ( + "Input: natural language query. Use Java identifiers: e.g. 'ObjectInputStream deserialization', " + "'javax.servlet request handling', 'java.util collections'." + ), + (ToolNames.CODE_SEMANTIC_SEARCH, "javascript"): ( + "Input: natural language query. Use JS identifiers: e.g. 'require() imports', " + "'express middleware', 'JSON.parse usage'." + ), + (ToolNames.CODE_SEMANTIC_SEARCH, "c"): ( + "Input: natural language query. Use C/C++ identifiers: e.g. 'EVP_EncryptInit encryption', " + "'strcpy buffer handling', 'malloc memory allocation'." + ), + # Docs Semantic Search + (ToolNames.DOCS_SEMANTIC_SEARCH, "python"): ( + "Input: natural language question. Phrase for Python docs: e.g. 'How does this app handle image uploads?'" + ), + (ToolNames.DOCS_SEMANTIC_SEARCH, "go"): ( + "Input: natural language question. 
Phrase for Go docs: e.g. 'How is HTTP routing configured?'" + ), + (ToolNames.DOCS_SEMANTIC_SEARCH, "java"): ( + "Input: natural language question. Phrase for Java docs: e.g. 'Application architecture and dependencies'" + ), + (ToolNames.DOCS_SEMANTIC_SEARCH, "javascript"): ( + "Input: natural language question. Phrase for JS docs: e.g. 'How does the API handle requests?'" + ), + (ToolNames.DOCS_SEMANTIC_SEARCH, "c"): ( + "Input: natural language question. Phrase for C/C++ docs: e.g. 'Cryptographic library usage'" + ), + # Code Keyword Search + (ToolNames.CODE_KEYWORD_SEARCH, "python"): ( + "Input: exact text. Python patterns: 'urllib.parse', 'from X import Y', 'import asyncio'." + ), + (ToolNames.CODE_KEYWORD_SEARCH, "go"): ( + "Input: exact text. Go patterns: 'import \"', 'pkg.FunctionName', 'encoding/json'." + ), + (ToolNames.CODE_KEYWORD_SEARCH, "java"): ( + "Input: exact text. Java patterns: 'import com.', 'ClassName.method', 'javax.servlet'." + ), + (ToolNames.CODE_KEYWORD_SEARCH, "javascript"): ( + "Input: exact text. JS patterns: 'require(', 'import { ', 'process.env'." + ), + (ToolNames.CODE_KEYWORD_SEARCH, "c"): ( + "Input: exact text. C/C++ patterns: 'EVP_EncryptInit_ex2', 'OSSL_PARAM', function names." + ), + # Function Locator + (ToolNames.FUNCTION_LOCATOR, "python"): ( + "Input: 'package_name,function_name'. Example: 'urllib,parse' or 'PIL.Image,open'. " + "Call first before Call Chain Analyzer." + ), + (ToolNames.FUNCTION_LOCATOR, "go"): ( + "Input: 'package_path,FunctionName'. Example: 'github.com/pkg/errors,New'. " + "Call first before Function Caller Finder or Call Chain Analyzer." + ), + (ToolNames.FUNCTION_LOCATOR, "java"): ( + "Input format 2(java): 'maven_gav,class_name.method_name'." + "Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'." + ), + (ToolNames.FUNCTION_LOCATOR, "javascript"): ( + "Input: 'package_name,function_name'. Example: 'express,Router' or npm scope paths. 
" + "Call first before Call Chain Analyzer." + ), + (ToolNames.FUNCTION_LOCATOR, "c"): ( + "Input: 'library_name,function_name'. Example: 'libxml2,xmlParseDocument' or 'openssl,EVP_EncryptInit_ex2'. " + "Call first before Call Chain Analyzer." + ), + # Call Chain Analyzer + (ToolNames.CALL_CHAIN_ANALYZER, "python"): ( + "Input: use validated 'package,function' from Function Locator. Example: 'urllib,parse'." + ), + (ToolNames.CALL_CHAIN_ANALYZER, "go"): ( + "Input format 1: 'package_name,function_name' Example: 'github.com/pkg/errors,New'." + ), + (ToolNames.CALL_CHAIN_ANALYZER, "java"): ( + "Input: use validated maven_gav,class.method from Function Locator." + ), + (ToolNames.CALL_CHAIN_ANALYZER, "javascript"): ( + "Input: use validated package,function from Function Locator." + ), + (ToolNames.CALL_CHAIN_ANALYZER, "c"): ( + "Input format 1: 'package_name,function_name'. Example: 'openssl,EVP_EncryptInit_ex2'." + ), + # Function Caller Finder (Go only) + (ToolNames.FUNCTION_CALLER_FINDER, "go"): ( + "Input: 'package_name,lib.function(args)'. Example: 'github.com/myapp,errors.New(\"text\")'. " + "Go-only. Finds callers of std/lib functions." 
+ ), + # CVE Web Search (language-agnostic) + (ToolNames.CVE_WEB_SEARCH, "python"): "Input: vulnerability search terms.", + (ToolNames.CVE_WEB_SEARCH, "go"): "Input: vulnerability search terms.", + (ToolNames.CVE_WEB_SEARCH, "java"): "Input: vulnerability search terms.", + (ToolNames.CVE_WEB_SEARCH, "javascript"): "Input: vulnerability search terms.", + (ToolNames.CVE_WEB_SEARCH, "c"): "Input: vulnerability search terms.", + # Container Analysis Data (language-agnostic) + (ToolNames.CONTAINER_ANALYSIS_DATA, "python"): "Input: query string for pre-analyzed container findings.", + (ToolNames.CONTAINER_ANALYSIS_DATA, "go"): "Input: query string for pre-analyzed container findings.", + (ToolNames.CONTAINER_ANALYSIS_DATA, "java"): "Input: query string for pre-analyzed container findings.", + (ToolNames.CONTAINER_ANALYSIS_DATA, "javascript"): "Input: query string for pre-analyzed container findings.", + (ToolNames.CONTAINER_ANALYSIS_DATA, "c"): "Input: query string for pre-analyzed container findings.", +} + +# Optional: Add language-specific usage hints +FEW_SHOT_EXAMPLES: dict[str, str] = { + "python": "Example: Function Locator input 'urllib,parse'; Code Keyword Search for 'urllib.parse'.", + "go": "Example: Function Caller Finder 'pkg,errors.New(\"x\")'; Function Locator 'github.com/pkg,New'.", + "java": "Example: Function Locator 'group:artifact:ver,Class.method'; Code Keyword Search 'import com.'.", + "javascript": "Example: Code Keyword Search 'require('; Function Locator 'package_name,exported'.", + "c": "Example: Function Locator 'openssl,EVP_EncryptInit_ex2'; Code Keyword Search 'EVP_aes_'.", +} + +# Per-language strategy for when to use which tools +TOOL_SELECTION_STRATEGY: dict[str, str] = { + "python": ( + "Use Code Keyword Search first for exact import/function lookups (e.g. urllib.parse, PIL.Image). " + "For reachability: call Function Locator first with package,function (e.g. urllib,parse), then Call Chain Analyzer with validated names. 
" + "Use Docs Semantic Search for architecture questions; Code Semantic Search for 'how is X used' patterns." + ), + "go": ( + "For std/lib function reachability: call Function Locator first to validate package paths (e.g. github.com/pkg/errors,New), " + "then use Function Caller Finder to find which app functions call the library method, then Call Chain Analyzer with validated names. " + "Use Code Keyword Search for import paths; Docs/Code Semantic Search for handler and middleware patterns." + ), + "java": ( + "Use Function Locator first with maven GAV format (group:artifact:version,ClassName.methodName). " + "For reachability, use Call Chain Analyzer with validated names from Function Locator. " + "Use Code Keyword Search for import com. and javax. patterns; Docs Semantic Search for Spring/servlet architecture." + ), + "javascript": ( + "Use Code Keyword Search first for require(, import {, and package patterns. " + "For reachability: Function Locator with package_name,exported, then Call Chain Analyzer. " + "Use Docs/Code Semantic Search for middleware, API, and npm package usage patterns." + ), + "c": ( + "Code Keyword Search: exact text matching for function names, class names, or imports. " + "For reachability: use Function Locator first with library_name,function_name (e.g. openssl,EVP_EncryptInit_ex2), then Call Chain Analyzer with validated names " + "together to trace function reachability. " + "CVE Web Search: external vulnerability information lookup." + ), +} + +TOOL_GENERAL_DESCRIPTIONS: dict[str, str] = { + ToolNames.CODE_SEMANTIC_SEARCH: "Searches container source code using semantic search. " + "Finds how functions, libraries, or components are used in the codebase. " + "Answers questions about code implementation and dependencies.", + ToolNames.DOCS_SEMANTIC_SEARCH: "Searches container documentation using semantic search. 
" + "Answers questions about application purpose, architecture, and features.", + ToolNames.CODE_KEYWORD_SEARCH: "Performs keyword search on container source code for exact text matches. " + "Input should be a function name, class name, or code pattern. " + "Use this first before semantic search tools for precise lookups.", + ToolNames.FUNCTION_LOCATOR: "Mandatory first step for code path analysis. Validates package names, locates functions using fuzzy matching, and provides ecosystem type. " + "Make sure the input format matches exactly one of the following formats:", + + ToolNames.CALL_CHAIN_ANALYZER: "Checks if a function from a package is reachable from application code through the call chain. " + "Make sure the input format matches exactly one of the following formats:", + ToolNames.FUNCTION_CALLER_FINDER: "Finds functions in a package that call a specific library function.", + ToolNames.CVE_WEB_SEARCH: "Searches the web for information about CVEs, vulnerabilities, libraries, " + "and security advisories not available in the container.", + ToolNames.CONTAINER_ANALYSIS_DATA: "Retrieves pre-analyzed container image source code analysis results for the current CVE. " + "Returns a list of findings with 'response' (summary) and 'intermediate_steps' (reasoning) fields. 
" + "No input required - uses context automatically.", +} + + +class PromptFactory: + def __init__( + self, + template: str, + registry: dict, + examples: dict, + general_descriptions: dict[str, str], + tool_selection_strategy: dict[str, str], + ): + self.template = template + self.registry = registry + self.examples = examples + self.general_descriptions = general_descriptions + self.tool_selection_strategy = tool_selection_strategy + + def build_prompt( + self, + sys_prompt: str, + enabled_tools: list[str], + language: str, + tool_instructions: str | None = None, + ) -> str: + # Build AVAILABLE_TOOLS in format: tool_name(*args, **kwargs) - description + available_tools_lines = [] + for tool_name in enabled_tools: + desc = self.general_descriptions.get(tool_name, "Standard analysis tool.") + available_tools_lines.append(f"{tool_name}(*args, **kwargs) - {desc}") + tools_section = "\n".join(available_tools_lines) + + if tool_instructions is None: + # Fallback: build from language-specific registry + tool_instructions = "" + + #lang_examples = self.examples.get(language, "") + #disable examples for now + lang_examples = "" + tool_selection_strategy = self.tool_selection_strategy.get( + language, f"Optimize for {language} ecosystem patterns." + ) + + return self.template.format( + sys_prompt=sys_prompt, + tools=tools_section, + tool_selection_strategy=tool_selection_strategy, + tool_instructions=f"{tool_instructions}\n\n{lang_examples}" + ) + +# --- Usage --- +LANGGRAPH_SYSTEM_PROMPT_TEMPLATE = """{sys_prompt} + +Answer the investigation question using the available tools. If the input is not a question, formulate it into a question first. A Tool Selection Strategy is provided to help you decide which tools to use. Focus on answering the question. Include your intermediate reasoning in the final answer. 
+ + +{tools} + + + +{tool_selection_strategy} + + +{tool_instructions} + +RESPONSE: +{{""" + +factory = PromptFactory( + LANGGRAPH_SYSTEM_PROMPT_TEMPLATE, + TOOL_ECOSYSTEM_REGISTRY, + FEW_SHOT_EXAMPLES, + TOOL_GENERAL_DESCRIPTIONS, + TOOL_SELECTION_STRATEGY, +) + +# Example: pass tool_instructions to use caller-provided content; omit to use registry-based fallback +python_prompt = factory.build_prompt( + sys_prompt=AGENT_SYS_PROMPT, + enabled_tools=[ToolNames.FUNCTION_LOCATOR, ToolNames.CVE_WEB_SEARCH, ToolNames.CALL_CHAIN_ANALYZER], + language="python", + tool_instructions=None, # or pass custom str +) + +#print(python_prompt) \ No newline at end of file From 3feba6e5a8e577673c5b2942efbabfed1217ba18 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Fri, 27 Feb 2026 12:31:32 +0200 Subject: [PATCH 17/60] AI claude phase 1 --- src/vuln_analysis/functions/cve_agent.py | 179 +++++++++--------- .../functions/react_internals.py | 64 +++---- 2 files changed, 108 insertions(+), 135 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 58ed6e38..79ab39c7 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -38,6 +38,7 @@ from vuln_analysis.functions.react_internals import build_system_prompt, AgentState, Thought,Observation, _build_tool_arguments, FORCED_FINISH_PROMPT, OBSERVATION_NODE_PROMPT from vuln_analysis.utils.prompting import build_tool_descriptions +from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_ECOSYSTEM_REGISTRY, FEW_SHOT_EXAMPLES from langgraph.graph import StateGraph, END, START from langgraph.prebuilt import ToolNode from langchain_core.messages import SystemMessage, HumanMessage, AIMessage @@ -114,66 +115,92 @@ async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Build FORCED_FINISH_NODE = "forced_finish_node" PRE_PROCESS_NODE = "pre_process_node" OBSERVATION_NODE = "observation_node" - def 
_filter_function_caller_finder(): + def _build_tool_guidance_for_ecosystem(ecosystem: str, available_tools: list) -> tuple[str, str]: + """Build tool guidance using language-specific strategies when available.""" filtered_tools = [ - t for t in tools - if t.name != ToolNames.FUNCTION_CALLER_FINDER + t for t in available_tools + if t.name != ToolNames.FUNCTION_CALLER_FINDER or ecosystem == "go" ] list_of_tool_names = [t.name for t in filtered_tools] list_of_tool_descriptions = [t.name + ": " + t.description for t in filtered_tools] - tool_guidance_list_local = build_tool_descriptions(list_of_tool_names) - tool_guidance_local = "\n".join(tool_guidance_list_local) + + lang = ecosystem.lower() if ecosystem else "" + if lang in TOOL_SELECTION_STRATEGY: + tool_guidance_local = TOOL_SELECTION_STRATEGY[lang] + hint = FEW_SHOT_EXAMPLES.get(lang, "") + if hint: + tool_guidance_local += f"\nHint: {hint}" + else: + tool_guidance_list_local = build_tool_descriptions(list_of_tool_names) + tool_guidance_local = "\n".join(tool_guidance_list_local) + descriptions_local = "\n".join(list_of_tool_descriptions) return tool_guidance_local, descriptions_local + def _build_critical_context(cve_intel_list) -> list[str]: + """Extract key facts from all available intel sources into a compact context.""" + critical_context = [] + for cve_intel in cve_intel_list: + if cve_intel.nvd is not None: + if cve_intel.nvd.cve_description: + critical_context.append(f"CVE Description: {cve_intel.nvd.cve_description[:400]}") + if cve_intel.nvd.cwe_name: + critical_context.append(f"CWE: {cve_intel.nvd.cwe_name}") + + if cve_intel.ghsa is not None: + if cve_intel.ghsa.vulnerabilities: + for v in cve_intel.ghsa.vulnerabilities[:3]: + if hasattr(v, 'vulnerable_functions') and v.vulnerable_functions: + critical_context.append(f"Vulnerable functions (GHSA): {', '.join(v.vulnerable_functions)}") + if hasattr(v, 'package') and v.package: + pkg_info = f"Affected package: {v.package}" if isinstance(v.package, str) 
else "" + if pkg_info: + critical_context.append(pkg_info) + if cve_intel.ghsa.description and not any("CVE Description" in c for c in critical_context): + critical_context.append(f"CVE Description: {cve_intel.ghsa.description[:400]}") + + if cve_intel.rhsa is not None: + if cve_intel.rhsa.statement: + critical_context.append(f"RHSA Statement: {cve_intel.rhsa.statement[:300]}") + if cve_intel.rhsa.package_state: + pkgs = list(set(p.package_name for p in cve_intel.rhsa.package_state if p.package_name)) + if pkgs: + critical_context.append(f"Affected packages: {', '.join(pkgs)}") + + if cve_intel.ubuntu is not None: + if cve_intel.ubuntu.ubuntu_description: + critical_context.append(f"Ubuntu note: {cve_intel.ubuntu.ubuntu_description[:200]}") + + if cve_intel.plugin_data: + for pd in cve_intel.plugin_data[:2]: + critical_context.append(f"{pd.label}: {pd.description[:200]}") + + if not critical_context: + critical_context = ["No CVE intel available. Investigate using tools."] + return critical_context + async def pre_process_node(state: AgentState) -> AgentState: workflow_state = ctx_state.get() ecosystem = workflow_state.original_input.input.image.ecosystem.value - for cve_intel in workflow_state.cve_intel: - if cve_intel.rhsa is not None: - logger.debug("RHSA upstream_fix: %s", cve_intel.rhsa.upstream_fix) - logger.debug("RHSA statement: %s", cve_intel.rhsa.statement) - target_packages = list(set(p.package_name for p in cve_intel.rhsa.package_state)) # - - # Build a generic guidance string - scope_guidance = f"The security advisory mentions the following affected packages: {', '.join(target_packages)}." # - - # Add important metadata facts from the 'details' or 'statement' fields - critical_context = [ - f"CVE Statement: {cve_intel.statement[:500]}...", # Highlights 'psql' and encoding risks - scope_guidance, - "TASK: Investigate usage and reachability within ALL mentioned packages." 
- ] - - - tool_guidance_local, descriptions_local = _filter_function_caller_finder() + + critical_context = _build_critical_context(workflow_state.cve_intel) + critical_context.append("TASK: Investigate usage and reachability of vulnerable components.") + + tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) - return {"ecosystem": ecosystem, "runtime_prompt": runtime_prompt,"observation": Observation(memory=critical_context, results=[])} + return {"ecosystem": ecosystem, "runtime_prompt": runtime_prompt, "observation": Observation(memory=critical_context, results=[])} async def thought_node(state: AgentState) -> AgentState: - # Use runtime_prompt from pre_process if available, otherwise fall back to default active_prompt = state.get("runtime_prompt") or default_system_prompt messages = [SystemMessage(content=active_prompt)] + state["messages"] - # Extract Observation data - obs = state.get("observation",None) + obs = state.get("observation", None) if obs is not None: - memory_list = obs.memory if obs and obs.memory else ["No prior knowledge."] - recent_findings = obs.results if obs and obs.results else ["No recent findings."] - - # Format lists for the prompt + memory_list = obs.memory if obs.memory else ["No prior knowledge."] + recent_findings = obs.results if obs.results else ["No recent findings."] memory_context = "\n".join(f"- {m}" for m in memory_list) findings_context = "\n".join(f"- {f}" for f in recent_findings) - - context_block = f""" - - {memory_context} - - - - {findings_context} - - """ - messages.append(AIMessage(content=context_block)) + context_block = f"KNOWLEDGE:\n{memory_context}\nLATEST FINDINGS:\n{findings_context}" + messages.append(SystemMessage(content=context_block)) response: Thought = await thought_llm.ainvoke(messages) final_answer = "waiting for the agent to respond" @@ -247,60 +274,26 @@ async def 
forced_finish_node(state: AgentState) -> AgentState: } async def observation_node(state: AgentState) -> AgentState: - # Get the tool output (last message) and the agent's intent (second to last message) tool_message = state["messages"][-1] - - # Extract context from the Thought object stored in state last_thought_text = state["thought"].thought if state.get("thought") else "No previous thought." tool_used = state["thought"].actions.tool if state.get("thought") and state["thought"].actions else "Unknown" - - previous_memory = state.get("observation").memory if state.get("observation") else "Initial state: No data gathered yet." - - # Construct the prompt for the structured output - # We ask the LLM to fill the 'memory' and 'results' fields of the Observation model - prompt = f""" - You are a security memory module for a CVE Analyst. Your task is to update the agent's 'memory' and 'results' lists based on tool findings. - - - {state.get('input')} - - - - {previous_memory} - - - - - Last Thought: "{last_thought_text}" - - Tool Used: "{tool_used}" - - - - {tool_message.content} - - - TASK: - Produce a structured Observation object (JSON). - - INSTRUCTIONS: - 1. memory (List of Strings): - # Updated Rule 1 - - VERIFY: Only record facts explicitly present in 'NEW_TOOL_OUTPUT'. - - ABSENCE: If no application code was found calling the library function, you MUST state "No usage found in application source" and do NOT claim the application uses the function. - - APPEND: Add new technical facts (filenames, version numbers, call chains). - - EVALUATE: If the tool result is empty or False, explicitly state that the search yielded no results for that specific path. - - DEDUPLICATE: Do not repeat facts already in PREVIOUS_MEMORY_LIST. - - CONCISE: Use plain, factual sentences. No conversational filler. - - 2. results (List of Strings): - - Provide 3-5 high-density technical "bullet points" from the LATEST tool output only. - - 3. 
FOCUS: Keep only information that helps determine exploitability for the specific CVE. - - RESPONSE: - {{""" - - # 2. Invoke the structured LLM - # This will return an instance of your Observation class + previous_memory = state.get("observation").memory if state.get("observation") else ["No data gathered yet."] + + prompt = f"""Update the investigation memory based on new tool output. +GOAL: {state.get('input')} +PREVIOUS MEMORY: {previous_memory} +TOOL USED: {tool_used} +THOUGHT: {last_thought_text} +NEW OUTPUT: +{tool_message.content} + +RULES: +- memory: Append new facts from OUTPUT. Record absence if nothing found. No duplicates. Factual only. +- results: 3-5 key technical facts from this OUTPUT only. +- Keep only CVE-exploitability-relevant information. +RESPONSE: +{{""" + new_observation: Observation = await observation_llm.ainvoke([SystemMessage(content=prompt)]) return { diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 8ddfeb44..067db363 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -63,12 +63,16 @@ class AgentState(MessagesState): ### --- End of REACT Schemas ----# #---- REACT Prompt Templates ----# AGENT_SYS_PROMPT = ( - "You are an expert security analyst investigating Common Vulnerabilities and " - "Exposures (CVE) in container images. Your role is to methodically answer " - "investigation questions using available tools to determine if vulnerabilities " - "are exploitable in the specific container context. You have access to the " - "container's source code, documentation, and dependency information through " - "specialized search and analysis tools." + "You are a security analyst investigating CVE exploitability in container images.\n" + "METHODOLOGY:\n" + "1. IDENTIFY the vulnerable component/function from the CVE description.\n" + "2. SEARCH for its presence in the container's source code.\n" + "3. 
TRACE whether the vulnerable code is reachable from application entry points.\n" + "4. ASSESS whether exploitation conditions are met (configuration, exposure, inputs).\n" + "RULES:\n" + "- Base conclusions ONLY on tool results, not assumptions.\n" + "- If a search returns no results, that is evidence the code is absent.\n" + "- Do NOT claim a function is used unless a tool confirmed it." ) # Update LANGGRAPH_SYSTEM_PROMPT_TEMPLATE in react_internals.py @@ -87,43 +91,19 @@ class AgentState(MessagesState): RESPONSE: {{""" -AGENT_THOUGHT_INSTRUCTIONS = """ - -1. Output MUST be valid JSON. -2. 'thought' < 100 words. 'final_answer' < 150 words. -3. If mode="act", 'actions' is REQUIRED. If mode="finish", 'final_answer' is REQUIRED. -4. For 'Function Locator', 'Function Caller Finder', 'Call Chain Analyzer': - - MUST set package_name AND function_name (e.g. "libpq", "PQescapeLiteral"). - - Do NOT use query for these tools. -5. For 'Code Keyword Search', 'Code Semantic Search', 'Docs Semantic Search', 'CVE Web Search': - - Use query with a single search term or phrase. - - -Return your response in this JSON format: -{{ - "thought": "...", - "mode": "act" | "finish", - "actions": {{ "tool": "...", - "package_name": "..." | null, - "function_name": "..." | null, - "query": "..." | null, - "reason": "..." }}, - "final_answer": "..." -}} -""" -FORCED_FINISH_PROMPT = """ - -You are at or past your maximum allowed steps. -You MUST use mode="finish" and provide a final_answer NOW. -Do NOT call any more tools. Summarize the evidence you have gathered so far into a concise -final_answer (3-5 sentences) that directly addresses the investigation question. - 'final_answer' < 300 words. - 'thought' < 300 words. - - +AGENT_THOUGHT_INSTRUCTIONS = """ +1. Output valid JSON only. thought < 100 words. final_answer < 150 words. +2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer. +3. 
Function Locator, Function Caller Finder, Call Chain Analyzer: MUST set package_name AND function_name. Do NOT use query. +4. Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search: use query field. + + +{{"thought": "Check if urllib.parse is imported in the codebase", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "urllib.parse", "tool_input": null, "reason": "Search for direct import of vulnerable module"}}, "final_answer": null}} +""" +FORCED_FINISH_PROMPT = """Maximum steps reached. You MUST set mode="finish" and provide final_answer NOW. +Do NOT call any more tools. Summarize your evidence in 3-5 sentences. RESPONSE: -{{ -""" +{{""" OBSERVATION_NODE_PROMPT = """ """ From 3cbccb18f45f6c42672476687a12a001cf87ad41 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Fri, 27 Feb 2026 13:34:19 +0200 Subject: [PATCH 18/60] AI claude prompt iteration improvments --- src/vuln_analysis/functions/cve_agent.py | 19 ++++++++---- .../functions/react_internals.py | 30 +++++++++++++------ 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 79ab39c7..95c2bf3c 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -165,7 +165,8 @@ def _build_critical_context(cve_intel_list) -> list[str]: if cve_intel.rhsa.package_state: pkgs = list(set(p.package_name for p in cve_intel.rhsa.package_state if p.package_name)) if pkgs: - critical_context.append(f"Affected packages: {', '.join(pkgs)}") + numbered = ", ".join(f"{i+1}) {p}" for i, p in enumerate(pkgs)) + critical_context.append(f"INVESTIGATE EACH package: {numbered}. 
Check reachability in ALL of them before concluding.") if cve_intel.ubuntu is not None: if cve_intel.ubuntu.ubuntu_description: @@ -184,7 +185,7 @@ async def pre_process_node(state: AgentState) -> AgentState: ecosystem = workflow_state.original_input.input.image.ecosystem.value critical_context = _build_critical_context(workflow_state.cve_intel) - critical_context.append("TASK: Investigate usage and reachability of vulnerable components.") + critical_context.append("TASK: Investigate usage and reachability in ALL listed packages. Do not stop after checking only one.") tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) @@ -289,6 +290,8 @@ async def observation_node(state: AgentState) -> AgentState: RULES: - memory: Append new facts from OUTPUT. Record absence if nothing found. No duplicates. Factual only. +- If reachability was NEGATIVE for a package, add to memory: "NOT reachable via [package]. Must check remaining packages." +- If reachability was POSITIVE, add: "REACHABLE via [package] - sufficient evidence." - results: 3-5 key technical facts from this OUTPUT only. - Keep only CVE-exploitability-relevant information. RESPONSE: @@ -456,6 +459,12 @@ async def cve_agent(config: CVEAgentExecutorToolConfig, builder: Builder): async def _arun(state: AgentMorpheusEngineState) -> AgentMorpheusEngineState: trace_id.set(state.original_input.input.scan.id) ctx_state.set(state) + + checklist_plans = state.checklist_plans + if os.environ.get("DEBUG_SINGLE_QUESTION", "0") == "1": + logger.info("DEBUG_SINGLE_QUESTION is set. Limiting to first question per CVE.") + checklist_plans = {k: v[:1] for k, v in checklist_plans.items()} + agent = None if os.environ.get("ENABLE_GRAPH_AGENT","0") == "1": logger.info("ENABLE_GRAPH_AGENT is set to 1. 
Executing CVE agent in graph mode.") @@ -464,10 +473,10 @@ async def _arun(state: AgentMorpheusEngineState) -> AgentMorpheusEngineState: logger.info("ENABLE_GRAPH_AGENT is set to 0. Executing CVE agent.") agent = await _create_agent(config, builder, state) results = await asyncio.gather(*(_process_steps(agent, steps, semaphore, config.max_iterations) - for steps in state.checklist_plans.values()), return_exceptions=True) + for steps in checklist_plans.values()), return_exceptions=True) results = _postprocess_results(results, config.replace_exceptions, config.replace_exceptions_value, - list(state.checklist_plans.values())) - state.checklist_results = dict(zip(state.checklist_plans.keys(), results)) + list(checklist_plans.values())) + state.checklist_results = dict(zip(checklist_plans.keys(), results)) return state yield FunctionInfo.from_fn( diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 067db363..f8163394 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -64,15 +64,23 @@ class AgentState(MessagesState): #---- REACT Prompt Templates ----# AGENT_SYS_PROMPT = ( "You are a security analyst investigating CVE exploitability in container images.\n" - "METHODOLOGY:\n" + "MANDATORY STEPS (follow in order, do NOT skip any):\n" "1. IDENTIFY the vulnerable component/function from the CVE description.\n" - "2. SEARCH for its presence in the container's source code.\n" - "3. TRACE whether the vulnerable code is reachable from application entry points.\n" - "4. ASSESS whether exploitation conditions are met (configuration, exposure, inputs).\n" - "RULES:\n" + "2. SEARCH for its presence using Code Keyword Search.\n" + "3. TRACE reachability using Function Locator and/or Call Chain Analyzer. " + "Keyword search alone is NOT sufficient -- you must trace the call chain.\n" + "4. 
If MULTIPLE packages are listed, repeat steps 2-3 for EACH package.\n" + "5. ASSESS: only after completing reachability checks, determine exploitability.\n" + "STOPPING RULES:\n" + "- POSITIVE reachability (function IS reachable): you MAY conclude exploitable and finish.\n" + "- NEGATIVE reachability (function NOT reachable in a package): you MUST continue " + "and check the NEXT package. A negative result in one package does not prove non-exploitability.\n" + "- You may only conclude 'not exploitable' after ALL packages have been checked and ALL returned negative.\n" + "GENERAL RULES:\n" "- Base conclusions ONLY on tool results, not assumptions.\n" "- If a search returns no results, that is evidence the code is absent.\n" - "- Do NOT claim a function is used unless a tool confirmed it." + "- Do NOT claim a function is used unless a tool confirmed it.\n" + "- Do NOT set mode='finish' until you have used at least one reachability tool." ) # Update LANGGRAPH_SYSTEM_PROMPT_TEMPLATE in react_internals.py @@ -96,10 +104,14 @@ class AgentState(MessagesState): 2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer. 3. Function Locator, Function Caller Finder, Call Chain Analyzer: MUST set package_name AND function_name. Do NOT use query. 4. Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search: use query field. +5. If reachability is POSITIVE (reachable), you may finish. If NEGATIVE, you MUST check remaining packages before finishing. 
- -{{"thought": "Check if urllib.parse is imported in the codebase", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "urllib.parse", "tool_input": null, "reason": "Search for direct import of vulnerable module"}}, "final_answer": null}} -""" + +{{"thought": "Search for the vulnerable function in the codebase first", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "PQescapeLiteral", "tool_input": null, "reason": "Check if vulnerable function is present"}}, "final_answer": null}} + + +{{"thought": "Found the function. Now I must check if it is reachable from application code", "mode": "act", "actions": {{"tool": "Function Locator", "package_name": "libpq", "function_name": "PQescapeLiteral", "query": null, "tool_input": null, "reason": "Validate function location before tracing call chain"}}, "final_answer": null}} +""" FORCED_FINISH_PROMPT = """Maximum steps reached. You MUST set mode="finish" and provide final_answer NOW. Do NOT call any more tools. Summarize your evidence in 3-5 sentences. 
RESPONSE: From da1112c0e47e589af59139b1c707db3d4cc3c794 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Fri, 27 Feb 2026 16:56:54 +0200 Subject: [PATCH 19/60] Improve context --- src/vuln_analysis/functions/cve_agent.py | 34 +++++++++++++++++------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 95c2bf3c..29af2388 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -150,12 +150,20 @@ def _build_critical_context(cve_intel_list) -> list[str]: if cve_intel.ghsa is not None: if cve_intel.ghsa.vulnerabilities: for v in cve_intel.ghsa.vulnerabilities[:3]: - if hasattr(v, 'vulnerable_functions') and v.vulnerable_functions: - critical_context.append(f"Vulnerable functions (GHSA): {', '.join(v.vulnerable_functions)}") - if hasattr(v, 'package') and v.package: - pkg_info = f"Affected package: {v.package}" if isinstance(v.package, str) else "" - if pkg_info: - critical_context.append(pkg_info) + vuln = v if isinstance(v, dict) else (v.__dict__ if hasattr(v, '__dict__') else {}) + vf = vuln.get('vulnerable_functions', []) if isinstance(vuln, dict) else getattr(v, 'vulnerable_functions', []) + pkg = vuln.get('package', None) if isinstance(vuln, dict) else getattr(v, 'package', None) + + if vf: + critical_context.append(f"Vulnerable functions (GHSA): {', '.join(vf)}") + if pkg: + if isinstance(pkg, dict): + pkg_name = pkg.get("name", "") + pkg_eco = pkg.get("ecosystem", "") + if pkg_name: + critical_context.append(f"Vulnerable module ({pkg_eco}): {pkg_name}") + elif isinstance(pkg, str): + critical_context.append(f"Affected package: {pkg}") if cve_intel.ghsa.description and not any("CVE Description" in c for c in critical_context): critical_context.append(f"CVE Description: {cve_intel.ghsa.description[:400]}") @@ -164,9 +172,14 @@ def _build_critical_context(cve_intel_list) -> list[str]: critical_context.append(f"RHSA 
Statement: {cve_intel.rhsa.statement[:300]}") if cve_intel.rhsa.package_state: pkgs = list(set(p.package_name for p in cve_intel.rhsa.package_state if p.package_name)) - if pkgs: + if len(pkgs) > 5: + critical_context.append( + f"Affected across {len(pkgs)} Red Hat products (sample: {', '.join(pkgs[:5])}). " + "Focus investigation on the vulnerable library/module, not individual products." + ) + elif pkgs: numbered = ", ".join(f"{i+1}) {p}" for i, p in enumerate(pkgs)) - critical_context.append(f"INVESTIGATE EACH package: {numbered}. Check reachability in ALL of them before concluding.") + critical_context.append(f"INVESTIGATE EACH package: {numbered}.") if cve_intel.ubuntu is not None: if cve_intel.ubuntu.ubuntu_description: @@ -185,7 +198,10 @@ async def pre_process_node(state: AgentState) -> AgentState: ecosystem = workflow_state.original_input.input.image.ecosystem.value critical_context = _build_critical_context(workflow_state.cve_intel) - critical_context.append("TASK: Investigate usage and reachability in ALL listed packages. Do not stop after checking only one.") + critical_context.append( + "TASK: Investigate usage and reachability of the vulnerable function/module in the container. " + "Use the vulnerable module name from GHSA as primary investigation target." 
+ ) tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) From 55df7ed502f55e7438c37c216711970c1e8fdbac Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sat, 28 Feb 2026 07:25:42 +0200 Subject: [PATCH 20/60] clean java cache --- .tekton/on-pull-request.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.tekton/on-pull-request.yaml b/.tekton/on-pull-request.yaml index 6315138b..bb1c745a 100644 --- a/.tekton/on-pull-request.yaml +++ b/.tekton/on-pull-request.yaml @@ -241,7 +241,8 @@ spec: # This is handled in the Makefile's lint-pr target and should be reverted after migration. make lint-pr TARGET_BRANCH=$TARGET_BRANCH_NAME - + #clean the java cache + rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* print_banner "RUNNING UNIT TESTS" make test-unit PYTEST_OPTS="--log-cli-level=DEBUG" From 030bb369eb19dca01adde029a04f2e922e28e41b Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sat, 28 Feb 2026 07:47:27 +0200 Subject: [PATCH 21/60] disable the java clean cache command --- .tekton/on-pull-request.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.tekton/on-pull-request.yaml b/.tekton/on-pull-request.yaml index bb1c745a..07bc25c8 100644 --- a/.tekton/on-pull-request.yaml +++ b/.tekton/on-pull-request.yaml @@ -242,7 +242,7 @@ spec: make lint-pr TARGET_BRANCH=$TARGET_BRANCH_NAME #clean the java cache - rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* + #rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* print_banner "RUNNING UNIT TESTS" make test-unit PYTEST_OPTS="--log-cli-level=DEBUG" From c2dd659d0a9663755e04f536f3964fa8eac4d8e7 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sun, 1 Mar 2026 16:23:23 +0200 Subject: [PATCH 22/60] Save last changes --- ci/scripts/analyze_traces.py | 567 ++++++++++++++++++ ci/scripts/collect_traces_local.py | 317 
++++++++++ src/vuln_analysis/functions/cve_agent.py | 109 +++- .../functions/react_internals.py | 1 + 4 files changed, 979 insertions(+), 15 deletions(-) create mode 100644 ci/scripts/analyze_traces.py create mode 100644 ci/scripts/collect_traces_local.py diff --git a/ci/scripts/analyze_traces.py b/ci/scripts/analyze_traces.py new file mode 100644 index 00000000..3d38624b --- /dev/null +++ b/ci/scripts/analyze_traces.py @@ -0,0 +1,567 @@ +""" +Trace analyzer for LLM reasoning quality. + +Reads per-trace JSON files produced by collect_traces_local.py, +reconstructs agent investigation flows, and flags quality issues. + +Usage: + python ci/scripts/analyze_traces.py --input traces_output/ + python ci/scripts/analyze_traces.py --input traces_output/ --json + python ci/scripts/analyze_traces.py --input traces_output/single_trace.json +""" +import argparse +import json +import sys +from collections import Counter +from dataclasses import dataclass, field, asdict +from enum import Enum +from pathlib import Path + + +REACHABILITY_TOOLS = {"Function Locator", "Call Chain Analyzer", "Function Caller Finder"} +SEARCH_ONLY_TOOLS = {"Code Keyword Search", "Code Semantic Search", "Docs Semantic Search"} +TOKEN_BUDGET_THRESHOLD = 6000 # 75% of 8k window +TOKEN_BUDGET_WARNING = 5800 # early warning before hitting threshold +CONTEXT_WINDOW_LIMIT = 8192 +MAX_COMPLETION_TOKENS = 2000 + + +class Severity(str, Enum): + HIGH = "HIGH" + MEDIUM = "MEDIUM" + LOW = "LOW" + + +@dataclass +class QualityFlag: + name: str + severity: Severity + detail: str + + +@dataclass +class InvestigationStep: + step_num: int + span_kind: str # LLM or TOOL + name: str + input_preview: str + output_preview: str + token_prompt: int = 0 + token_completion: int = 0 + token_total: int = 0 + llm_role: str = "" # "thought" or "observation" for LLM spans + thought_mode: str = "" # "act" or "finish" for thought spans + + +@dataclass +class InvestigationFlow: + """One agent investigation thread (one checklist 
question).""" + trace_id: str + job_id: str + parent_span_id: str + steps: list[InvestigationStep] = field(default_factory=list) + flags: list[QualityFlag] = field(default_factory=list) + tools_used: list[str] = field(default_factory=list) + llm_call_count: int = 0 + tool_call_count: int = 0 + token_progression: list[int] = field(default_factory=list) + cve_id: str = "" + ecosystem: str = "" + question: str = "" + + +def _preview(text: str, max_len: int = 120) -> str: + if not text: + return "" + clean = text.replace("\n", " ").strip() + return clean[:max_len] + "..." if len(clean) > max_len else clean + + +def _extract_cve_from_input(input_value: str) -> tuple[str, str]: + """Try to extract CVE ID and ecosystem from an agent executor's LLM input.""" + cve_id = "" + ecosystem = "" + try: + messages = json.loads(input_value) + for msg in messages: + content = msg.get("content", "") if isinstance(msg, dict) else "" + if "CVE-" in content: + import re + match = re.search(r"(CVE-\d{4}-\d+)", content) + if match: + cve_id = match.group(1) + if "ecosystem" in content.lower(): + for eco in ("go", "python", "java", "javascript", "c"): + if f'"{eco}"' in content.lower() or f"'{eco}'" in content.lower(): + ecosystem = eco + break + except (json.JSONDecodeError, TypeError): + pass + return cve_id, ecosystem + + +def _extract_question(input_value: str) -> str: + """Extract the human/checklist question from the first LLM call's input.""" + try: + messages = json.loads(input_value) + for msg in messages: + if isinstance(msg, dict) and msg.get("type") == "human": + return _preview(msg.get("content", ""), 200) + except (json.JSONDecodeError, TypeError): + pass + return "" + + +# --------------------------------------------------------------------------- +# Parsing +# --------------------------------------------------------------------------- + +def load_trace_files(input_path: str) -> list[dict]: + """Load trace documents from a file or directory.""" + p = Path(input_path) + 
traces = [] + if p.is_file(): + with open(p, "r", encoding="utf-8") as f: + traces.append(json.load(f)) + elif p.is_dir(): + for fp in sorted(p.glob("*.json")): + with open(fp, "r", encoding="utf-8") as f: + try: + traces.append(json.load(f)) + except json.JSONDecodeError as e: + print(f"Warning: skipping {fp}: {e}", file=sys.stderr) + else: + print(f"Error: {input_path} is not a file or directory", file=sys.stderr) + sys.exit(1) + return traces + + +def _classify_llm_span(span: dict) -> str: + """Classify an LLM span as 'thought' or 'observation' based on output format.""" + output = span.get("output_value", "") + stripped = output.lstrip() + if stripped.startswith('{"thought"') or stripped.startswith('{"thought'): + return "thought" + return "observation" + + +def _extract_thought_mode(output: str) -> str: + """Extract the mode ('act' or 'finish') from a thought-node JSON output.""" + try: + parsed = json.loads(output) + return parsed.get("mode", "") + except (json.JSONDecodeError, TypeError): + pass + import re + m = re.search(r'"mode"\s*:\s*"(act|finish)"', output) + return m.group(1) if m else "" + + +def _extract_goal_question(input_value: str) -> str: + """Extract the GOAL question from an observation-node input.""" + try: + messages = json.loads(input_value) + for msg in messages: + content = msg.get("content", "") if isinstance(msg, dict) else "" + if "GOAL:" in content: + goal_line = content.split("GOAL:")[1].split("\n")[0].strip() + return _preview(goal_line, 200) + except (json.JSONDecodeError, TypeError): + pass + return "" + + +def _split_by_question(spans: list[dict]) -> dict[str, list[dict]]: + """ + Split interleaved agent spans into per-question groups. + + Concurrent checklist questions share a parent_span_id. We separate them + by extracting the question from thought-node (human message) and + observation-node (GOAL field) LLM spans. TOOL spans are assigned to the + question whose thought span most recently preceded them. 
+ """ + spans_sorted = sorted(spans, key=lambda s: int(s.get("start_time") or 0)) + + questions_seen: list[str] = [] + for span in spans_sorted: + if span.get("span_kind") == "LLM" and _classify_llm_span(span) == "thought": + q = _extract_question(span.get("input_value", "")) + if q and q not in questions_seen: + questions_seen.append(q) + + if len(questions_seen) <= 1: + return {} + + groups: dict[str, list[dict]] = {q: [] for q in questions_seen} + + last_thought_question: str | None = None + + for span in spans_sorted: + kind = span.get("span_kind", "") + + if kind == "LLM": + role = _classify_llm_span(span) + if role == "thought": + q = _extract_question(span.get("input_value", "")) + if q in groups: + last_thought_question = q + groups[q].append(span) + elif role == "observation": + goal_q = _extract_goal_question(span.get("input_value", "")) + matched = None + if goal_q: + for known_q in questions_seen: + if goal_q[:60] in known_q or known_q[:60] in goal_q: + matched = known_q + break + if matched: + groups[matched].append(span) + elif last_thought_question: + groups[last_thought_question].append(span) + + elif kind == "TOOL": + if last_thought_question: + groups[last_thought_question].append(span) + + return groups + + +def _build_flow_from_spans( + group_spans: list[dict], + trace_id: str, + job_id: str, + parent_id: str, +) -> InvestigationFlow: + """Build a single InvestigationFlow from an ordered list of spans.""" + group_spans.sort(key=lambda s: int(s.get("start_time") or 0)) + + flow = InvestigationFlow( + trace_id=trace_id, + job_id=job_id, + parent_span_id=parent_id, + ) + + for i, span in enumerate(group_spans): + kind = span.get("span_kind", "") + role = "" + mode = "" + if kind == "LLM": + role = _classify_llm_span(span) + if role == "thought": + mode = _extract_thought_mode(span.get("output_value", "")) + + step = InvestigationStep( + step_num=i + 1, + span_kind=kind, + name=span.get("name", ""), + input_preview=_preview(span.get("input_value", 
"")), + output_preview=_preview(span.get("output_value", "")), + token_prompt=span.get("token_prompt", 0), + token_completion=span.get("token_completion", 0), + token_total=span.get("token_total", 0), + llm_role=role, + thought_mode=mode, + ) + flow.steps.append(step) + + if kind == "LLM": + flow.llm_call_count += 1 + if step.token_prompt > 0: + flow.token_progression.append(step.token_prompt) + if i == 0: + cve_id, eco = _extract_cve_from_input(span.get("input_value", "")) + flow.cve_id = cve_id + flow.ecosystem = eco + flow.question = _extract_question(span.get("input_value", "")) + elif kind == "TOOL": + flow.tool_call_count += 1 + flow.tools_used.append(span.get("name", "")) + + return flow + + +def build_investigation_flows(trace_doc: dict) -> list[InvestigationFlow]: + """ + From a single trace document, reconstruct per-question investigation flows. + + When concurrent checklist questions share one parent_span_id (LangGraph), + we split by extracting the human question from each thought-node span. + Falls back to parent_span_id grouping when questions aren't interleaved. 
+ """ + trace_id = trace_doc.get("trace_id", "") + job_id = trace_doc.get("job_id", "") + spans = trace_doc.get("spans", []) + + agent_spans = [ + s for s in spans + if s.get("function_name") == "cve_agent_executor" + and s.get("span_kind") in ("LLM", "TOOL") + ] + + parent_groups: dict[str, list[dict]] = {} + for span in agent_spans: + pid = span.get("parent_span_id", "unknown") + parent_groups.setdefault(pid, []).append(span) + + flows = [] + for parent_id, group_spans in parent_groups.items(): + question_groups = _split_by_question(group_spans) + + if question_groups: + for q_text, q_spans in question_groups.items(): + if q_spans: + flow = _build_flow_from_spans(q_spans, trace_id, job_id, parent_id) + if not flow.question: + flow.question = q_text + flows.append(flow) + else: + flow = _build_flow_from_spans(group_spans, trace_id, job_id, parent_id) + flows.append(flow) + + return flows + + +# --------------------------------------------------------------------------- +# Quality checks +# --------------------------------------------------------------------------- + +def run_quality_checks(flow: InvestigationFlow) -> list[QualityFlag]: + flags = [] + + tool_set = set(flow.tools_used) + + if not tool_set: + flags.append(QualityFlag( + name="NO_TOOL_CALLS", + severity=Severity.HIGH, + detail="Agent concluded without any tool calls.", + )) + elif not tool_set & REACHABILITY_TOOLS: + flags.append(QualityFlag( + name="NO_REACHABILITY_TOOL", + severity=Severity.HIGH, + detail=f"No reachability tool used. 
Tools: {', '.join(tool_set)}", + )) + + if tool_set and tool_set <= SEARCH_ONLY_TOOLS: + flags.append(QualityFlag( + name="KEYWORD_SEARCH_ONLY", + severity=Severity.HIGH, + detail="Only search tools used, never traced call chain.", + )) + + if flow.llm_call_count <= 2 and flow.tool_call_count > 0: + flags.append(QualityFlag( + name="PREMATURE_FINISH", + severity=Severity.MEDIUM, + detail=f"Agent finished in {flow.llm_call_count} LLM iterations.", + )) + + budget_flagged = False + for tok in flow.token_progression: + if tok > TOKEN_BUDGET_THRESHOLD: + flags.append(QualityFlag( + name="TOKEN_BUDGET_HIGH", + severity=Severity.MEDIUM, + detail=f"LLM call used {tok} prompt tokens (>{TOKEN_BUDGET_THRESHOLD}).", + )) + budget_flagged = True + break + if tok > TOKEN_BUDGET_WARNING: + flags.append(QualityFlag( + name="TOKEN_BUDGET_APPROACHING", + severity=Severity.LOW, + detail=f"LLM call used {tok} prompt tokens, approaching {TOKEN_BUDGET_THRESHOLD} limit.", + )) + budget_flagged = True + break + + if "Function Caller Finder" in tool_set and flow.ecosystem and flow.ecosystem != "go": + flags.append(QualityFlag( + name="WRONG_TOOL_ECOSYSTEM", + severity=Severity.HIGH, + detail=f"Function Caller Finder used on {flow.ecosystem} (Go-only tool).", + )) + + thought_steps = [s for s in flow.steps if s.llm_role == "thought"] + if thought_steps: + last_thought = thought_steps[-1] + ended_with_finish = last_thought.thought_mode == "finish" + peak_tokens = max(flow.token_progression) if flow.token_progression else 0 + available_prompt_space = CONTEXT_WINDOW_LIMIT - MAX_COMPLETION_TOKENS + + if not ended_with_finish and last_thought.thought_mode == "act": + if peak_tokens > available_prompt_space * 0.65: + flags.append(QualityFlag( + name="CONTEXT_WINDOW_EXCEEDED", + severity=Severity.HIGH, + detail=( + f"Flow ended mid-investigation (last thought: mode='act', peak tokens: {peak_tokens}, " + f"available prompt space: {available_prompt_space}). 
" + f"Next LLM call likely exceeded {CONTEXT_WINDOW_LIMIT} context window." + ), + )) + else: + flags.append(QualityFlag( + name="INCOMPLETE_FLOW", + severity=Severity.HIGH, + detail=( + f"Flow ended without mode='finish' (last thought: mode='act', " + f"peak tokens: {peak_tokens}). Possible unrecorded exception." + ), + )) + + return flags + + +# --------------------------------------------------------------------------- +# Report +# --------------------------------------------------------------------------- + +def print_flow_report(flow: InvestigationFlow): + header = f"=== Trace: {flow.trace_id} | Parent: {flow.parent_span_id[:12]}... ===" + print(header) + if flow.cve_id or flow.ecosystem: + print(f"CVE: {flow.cve_id or 'unknown'} | Ecosystem: {flow.ecosystem or 'unknown'}") + if flow.question: + print(f"Question: {flow.question}") + print(f"Steps: {flow.llm_call_count} LLM calls, {flow.tool_call_count} tool calls") + + if flow.token_progression: + tokens_str = " -> ".join(str(t) for t in flow.token_progression) + print(f"Token usage: {tokens_str}") + + if flow.tools_used: + tool_counts = Counter(flow.tools_used) + tools_str = ", ".join(f"{name} (x{count})" for name, count in tool_counts.items()) + print(f"Tools used: {tools_str}") + reachability_used = set(flow.tools_used) & REACHABILITY_TOOLS + if reachability_used: + print(f"Reachability tools: YES ({', '.join(reachability_used)})") + else: + print("Reachability tools: NO") + else: + print("Tools used: NONE") + + if flow.flags: + flag_strs = [f"[{f.severity.value}] {f.name}: {f.detail}" for f in flow.flags] + print(f"Flags: {len(flow.flags)}") + for fs in flag_strs: + print(f" {fs}") + else: + print("Flags: NONE") + + print() + for step in flow.steps: + if step.span_kind == "LLM": + role_label = f" [{step.llm_role}]" if step.llm_role else "" + mode_label = f" mode={step.thought_mode}" if step.thought_mode else "" + print(f"--- Step {step.step_num}: LLM{role_label}{mode_label} ---") + if step.token_prompt: + 
print(f" Prompt tokens: {step.token_prompt} | Completion: {step.token_completion}") + print(f" Output: \"{step.output_preview}\"") + elif step.span_kind == "TOOL": + print(f"--- Step {step.step_num}: TOOL ({step.name}) ---") + print(f" Input: {step.input_preview}") + print(f" Output: \"{step.output_preview}\"") + + print() + + +def print_summary(all_flows: list[InvestigationFlow]): + total = len(all_flows) + flagged = [f for f in all_flows if f.flags] + passed = total - len(flagged) + + print("=" * 60) + print("SUMMARY") + print("=" * 60) + print(f"Total investigation flows: {total}") + print(f"Passed (no flags): {passed}") + print(f"Flagged: {len(flagged)}") + + if flagged: + all_flag_counts: Counter = Counter() + for flow in flagged: + for flag in flow.flags: + all_flag_counts[flag.name] += 1 + for flag_name, count in all_flag_counts.most_common(): + print(f" - {flag_name}: {count}") + + if all_flows: + avg_llm = sum(f.llm_call_count for f in all_flows) / total + avg_tool = sum(f.tool_call_count for f in all_flows) / total + all_tokens = [t for f in all_flows for t in f.token_progression] + avg_tokens = sum(all_tokens) / len(all_tokens) if all_tokens else 0 + print(f"\nAvg LLM calls per flow: {avg_llm:.1f}") + print(f"Avg tool calls per flow: {avg_tool:.1f}") + print(f"Avg prompt tokens per LLM call: {avg_tokens:.0f}") + + print() + + +def build_json_report(all_flows: list[InvestigationFlow]) -> dict: + flows_data = [] + for flow in all_flows: + flows_data.append({ + "trace_id": flow.trace_id, + "job_id": flow.job_id, + "cve_id": flow.cve_id, + "ecosystem": flow.ecosystem, + "question": flow.question, + "llm_call_count": flow.llm_call_count, + "tool_call_count": flow.tool_call_count, + "token_progression": flow.token_progression, + "tools_used": flow.tools_used, + "flags": [asdict(f) for f in flow.flags], + "steps": [asdict(s) for s in flow.steps], + }) + + total = len(all_flows) + flagged_count = sum(1 for f in all_flows if f.flags) + return { + "total_flows": 
total, + "passed": total - flagged_count, + "flagged": flagged_count, + "flows": flows_data, + } + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser(description="Analyze agent trace files for LLM reasoning quality.") + parser.add_argument("--input", required=True, + help="Path to traces_output/ directory or a single trace JSON file") + parser.add_argument("--json", action="store_true", + help="Output structured JSON report instead of console text") + parser.add_argument("--summary-only", action="store_true", + help="Print only the summary, skip per-flow details") + args = parser.parse_args() + + trace_docs = load_trace_files(args.input) + if not trace_docs: + print("No trace files found.", file=sys.stderr) + sys.exit(1) + + all_flows: list[InvestigationFlow] = [] + for doc in trace_docs: + flows = build_investigation_flows(doc) + for flow in flows: + flow.flags = run_quality_checks(flow) + all_flows.extend(flows) + + if args.json: + report = build_json_report(all_flows) + print(json.dumps(report, indent=2)) + else: + if not args.summary_only: + for flow in all_flows: + print_flow_report(flow) + print_summary(all_flows) + + +if __name__ == "__main__": + main() diff --git a/ci/scripts/collect_traces_local.py b/ci/scripts/collect_traces_local.py new file mode 100644 index 00000000..c4036a43 --- /dev/null +++ b/ci/scripts/collect_traces_local.py @@ -0,0 +1,317 @@ +""" +Local trace collector -- copy of collect_and_dispatch_traces.py with per-trace +flattened JSON file output for offline analysis. + +Writes one JSON file per completed trace to traces_output/{job_id}_{trace_id}.json. +Original Tempo/endpoint dispatch is preserved. 
+""" +from kafka import KafkaConsumer +import argparse +import json +import os +from HttpProvider import HttpProvider, AuthType +from datetime import datetime +from typing import Any +from pydantic import BaseModel, Field, constr, RootModel +import signal +import time +from kafka.errors import NoBrokersAvailable + +TRACE_VERSION = 1 + +DEFAULT_URL = "http://localhost:8080" +DEFAULT_AUTH_TYPE = "none" +DEFAULT_TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" +DEFAULT_VERIFY_PATH = "/app/certs/service-ca.crt" +DEFAULT_ENABLE_VERIFY = False +TEMPO_URL = "http://localhost:8080" + +KAFKA_TOPIC = 'otel-traces' +KAFKA_BROKER = 'localhost:9092' +TRACES_OUTPUT_DIR = 'traces_output' + + +class LocalDateTime(RootModel[datetime]): + root: datetime = Field(..., examples=['2022-03-10T12:15:50']) + + +class Trace(BaseModel): + format_version: int | None = None + job_id: constr(min_length=1) = Field( + ..., description='The job run id in which the trace and spans were instrumented' + ) + execution_start_timestamp: LocalDateTime | None = None + trace_id: constr(min_length=1) = Field( + ..., + description='Trace id grouping all spans of sub-tasks in one agent stage', + ) + span_id: constr(min_length=1) = Field( + ..., description='The id corresponding to one sub task inside one agent stage.' 
+ ) + span_payload: dict[str, Any] | None = Field( + None, + description='Span payload with metadata and instrumentation data', + ) + + +class TracesContainer(RootModel[list[Trace]]): + pass + + +def _otlp_str(val) -> str: + """Extract plain string from an OTLP attribute value like {"stringValue": "..."}.""" + if isinstance(val, dict): + return val.get("stringValue", "") + return str(val) if val else "" + + +def _otlp_int(val) -> int: + """Extract integer from an OTLP attribute value like {"intValue": "123"}.""" + if isinstance(val, dict): + raw = val.get("intValue", 0) + try: + return int(raw) + except (ValueError, TypeError): + return 0 + try: + return int(val) + except (ValueError, TypeError): + return 0 + + +def get_metadata(span: dict) -> dict: + return { + 'trace_id': span.get('traceId'), + 'span_id': span.get('spanId'), + 'parent_span_id': span.get('parentSpanId'), + 'name': span.get('name'), + 'start_time': span.get('startTimeUnixNano'), + 'end_time': span.get('endTimeUnixNano'), + 'status': span.get('status'), + } + + +def find_job_id(attributes: dict) -> str: + try: + input_value = attributes.get("input.value") + str_json = input_value.get("stringValue") + scan = json.loads(str_json) + if "scan" in scan: + return scan["scan"]["id"] + if "input" in scan: + return scan["input"]["scan"]["id"] + else: + print("Warning: Scan not found in attributes") + return "" + except Exception as e: + print(f"Error finding job id (scan id): {e}") + return "" + + +def flatten_span(metadata: dict, attrs: dict) -> dict: + """Convert a raw OTLP span (metadata + attributes) into a flat dict.""" + return { + "span_id": metadata.get("span_id", ""), + "parent_span_id": metadata.get("parent_span_id", ""), + "name": metadata.get("name", ""), + "start_time": metadata.get("start_time", ""), + "end_time": metadata.get("end_time", ""), + "span_kind": _otlp_str(attrs.get("nat.span.kind")), + "function_name": _otlp_str(attrs.get("nat.function.name")), + "input_value": 
_otlp_str(attrs.get("input.value")), + "output_value": _otlp_str(attrs.get("output.value")), + "token_prompt": _otlp_int(attrs.get("llm.token_count.prompt")), + "token_completion": _otlp_int(attrs.get("llm.token_count.completion")), + "token_total": _otlp_int(attrs.get("llm.token_count.total")), + } + + +def write_trace_file(trace_id: str, job_id: str, flat_spans: list[dict], output_dir: str): + """Write a completed trace as a single JSON file, spans sorted by start_time.""" + os.makedirs(output_dir, exist_ok=True) + flat_spans.sort(key=lambda s: int(s.get("start_time") or 0)) + trace_doc = { + "trace_id": trace_id, + "job_id": job_id, + "completed_at": datetime.utcnow().isoformat(), + "spans": flat_spans, + } + filename = f"{job_id}_{trace_id}.json" + filepath = os.path.join(output_dir, filename) + with open(filepath, "w", encoding="utf-8") as f: + json.dump(trace_doc, f, indent=2) + print(f"Wrote trace file: {filepath} ({len(flat_spans)} spans)") + + +def process_message(http_provider, tempo_provider, traces_table, flat_spans_table, + map_trace_id_to_job_id, data, output_dir): + # Forward original trace to Tempo + try: + for rs in data.get('resourceSpans', []): + attrs = rs.get('resource', {}).get('attributes', []) + if not any(a['key'] == 'service.name' for a in attrs): + attrs.append({"key": "service.name", "value": {"stringValue": "cve-agent"}}) + for ss in rs.get('scopeSpans', []): + for span in ss.get('spans', []): + if span.get('name') == "": + span['parentSpanId'] = "" + #tempo_status = tempo_provider.send_post(json.dumps(data)) + except Exception as e: + print(f"Failed to forward trace to Tempo: {e}") + + for rs in data.get('resourceSpans', []): + for ss in rs.get('scopeSpans', []): + for span in ss.get('spans', []): + metadata = get_metadata(span) + attrs = { + a['key']: a.get('value', {}) + for a in span.get('attributes', []) + } + trace_id = metadata.get('trace_id') + if trace_id not in traces_table: + print("add trace_id to traces_table: ", trace_id) 
+ traces_table[trace_id] = [] + flat_spans_table[trace_id] = [] + if trace_id in map_trace_id_to_job_id: + job_id = map_trace_id_to_job_id[trace_id] + else: + job_id = find_job_id(attrs) + if len(job_id) > 0: + trace_collection = traces_table.get(trace_id, []) + if trace_id not in map_trace_id_to_job_id: + map_trace_id_to_job_id[trace_id] = job_id + print(f"add trace_id to map_trace_id_to_job_id: {trace_id} --> {job_id}") + + dict_payload = {"metadata": metadata, "attributes": attrs} + trace_obj = Trace( + format_version=TRACE_VERSION, + job_id=job_id, + execution_start_timestamp=LocalDateTime( + datetime.fromtimestamp(int(metadata.get('start_time', "0")) / 1e9) + ), + trace_id=metadata.get('trace_id'), + span_id=metadata.get('span_id'), + span_payload=dict_payload, + ) + trace_collection.append(trace_obj) + traces_table[trace_id] = trace_collection + + flat_spans_table.setdefault(trace_id, []).append( + flatten_span(metadata, attrs) + ) + + span_name = metadata.get('name') + if span_name == "": + print("finish scan workflow - sending to endpoint + writing trace file") + traces = traces_table.get(trace_id, []) + traces_container = TracesContainer(traces) + #status = http_provider.send_post(traces_container.model_dump_json()) + status = 200 + print(f"send traces(job_id: {job_id}, trace_id: {trace_id}) to endpoint status: {status}") + + write_trace_file( + trace_id, job_id, + flat_spans_table.get(trace_id, []), + output_dir, + ) + + del traces_table[trace_id] + del flat_spans_table[trace_id] + + +def run_worker( + url: str = DEFAULT_URL, + auth_type: str = DEFAULT_AUTH_TYPE, + token_path: str = DEFAULT_TOKEN_PATH, + verify_path: str = DEFAULT_VERIFY_PATH, + enable_verify: bool = DEFAULT_ENABLE_VERIFY, + output_dir: str = TRACES_OUTPUT_DIR, +): + stop_triggered = False + + def signal_handler(sig, frame): + nonlocal stop_triggered + stop_triggered = True + + signal.signal(signal.SIGTERM, signal_handler) + + print(f"Waiting for Kafka at {KAFKA_BROKER}...") + consumer = 
None + while consumer is None: + try: + consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + auto_offset_reset='earliest', + group_id='cve-file-writer-group', + value_deserializer=lambda m: json.loads(m.decode('utf-8')), + request_timeout_ms=20000, + session_timeout_ms=10000, + ) + print("Successfully connected to Redpanda!") + except NoBrokersAvailable: + print("Kafka not ready yet... retrying in 2 seconds.") + time.sleep(2) + + print(f"Worker Active. Writing per-trace files to: {output_dir}/") + os.makedirs(output_dir, exist_ok=True) + traces_table = {} + flat_spans_table = {} + map_trace_id_to_job_id = {} + + endpoint_url = url + "/api/v1/traces" + EndpointAuthType = AuthType.BEARER + if auth_type != DEFAULT_AUTH_TYPE: + EndpointAuthType = AuthType.NONE + http_provider = HttpProvider( + url=TEMPO_URL, + auth_type=AuthType.NONE, + enable_verify=False, + ) + tempo_provider = HttpProvider( + url=TEMPO_URL, + auth_type=AuthType.NONE, + enable_verify=False, + ) + + try: + while not stop_triggered: + messages = consumer.poll(timeout_ms=1000) + for tp, records in messages.items(): + for message in records: + data = message.value + process_message( + http_provider, tempo_provider, + traces_table, flat_spans_table, + map_trace_id_to_job_id, data, + output_dir, + ) + finally: + consumer.close() + print("\nStopping worker...") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Collect traces from Kafka, dispatch to endpoint, and write per-trace JSON files for analysis." 
+ ) + parser.add_argument("--url", default=DEFAULT_URL, help="Dispatch endpoint URL") + parser.add_argument("--auth-type", default=DEFAULT_AUTH_TYPE, help="Auth type (Bearer or None)") + parser.add_argument("--token-path", default=DEFAULT_TOKEN_PATH, help="Path to auth token file") + parser.add_argument("--verify-path", default=DEFAULT_VERIFY_PATH, help="Path to TLS CA cert") + parser.add_argument("--enable-verify", action="store_true", default=DEFAULT_ENABLE_VERIFY) + parser.add_argument("--output-dir", default=TRACES_OUTPUT_DIR, + help="Directory for per-trace JSON files (default: traces_output)") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + run_worker( + url=args.url, + auth_type=args.auth_type, + token_path=args.token_path, + verify_path=args.verify_path, + enable_verify=args.enable_verify, + output_dir=args.output_dir, + ) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 29af2388..ff2eebed 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -41,10 +41,13 @@ from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_ECOSYSTEM_REGISTRY, FEW_SHOT_EXAMPLES from langgraph.graph import StateGraph, END, START from langgraph.prebuilt import ToolNode -from langchain_core.messages import SystemMessage, HumanMessage, AIMessage +from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, RemoveMessage import uuid +import tiktoken +from nat.builder.context import Context logger = LoggingFactory.get_agent_logger(__name__) +AGENT_TRACER = Context.get() class CVEAgentExecutorToolConfig(FunctionBaseConfig, name="cve_agent_executor"): """ @@ -74,6 +77,10 @@ class CVEAgentExecutorToolConfig(FunctionBaseConfig, name="cve_agent_executor"): cve_web_search_enabled: bool = Field(default=True, description="Whether to enable CVE Web Search tool or not.") verbose: bool = Field(default=False, description="Set 
to true for verbose output") + context_window_token_limit: int = Field( + default=5000, + description="Estimated token threshold for pruning old messages in observation node." + ) async def common_build_tools(config: CVEAgentExecutorToolConfig, builder: Builder, state: AgentMorpheusEngineState) -> tuple[list[typing.Any], list[str], list[str]]: @@ -115,6 +122,29 @@ async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Build FORCED_FINISH_NODE = "forced_finish_node" PRE_PROCESS_NODE = "pre_process_node" OBSERVATION_NODE = "observation_node" + + _tiktoken_enc = tiktoken.get_encoding("cl100k_base") + + def _count_tokens(text: str) -> int: + """Count tokens using tiktoken cl100k_base encoding (~90-95% accurate for Llama 3.1).""" + try: + return len(_tiktoken_enc.encode(text)) + except Exception: + return len(text) // 4 + + def _estimate_tokens(runtime_prompt: str, messages: list, observation: Observation | None) -> int: + """Estimate the token count thought_node will send to the LLM.""" + parts = [runtime_prompt] + for msg in messages: + if hasattr(msg, "content") and isinstance(msg.content, str): + parts.append(msg.content) + if observation is not None: + for item in (observation.memory or []): + parts.append(item) + for item in (observation.results or []): + parts.append(item) + return _count_tokens("\n".join(parts)) + def _build_tool_guidance_for_ecosystem(ecosystem: str, available_tools: list) -> tuple[str, str]: """Build tool guidance using language-specific strategies when available.""" filtered_tools = [ @@ -196,16 +226,17 @@ def _build_critical_context(cve_intel_list) -> list[str]: async def pre_process_node(state: AgentState) -> AgentState: workflow_state = ctx_state.get() ecosystem = workflow_state.original_input.input.image.ecosystem.value + with AGENT_TRACER.push_active_function("node.pre_process", input_data=ecosystem): - critical_context = _build_critical_context(workflow_state.cve_intel) - critical_context.append( - "TASK: Investigate 
usage and reachability of the vulnerable function/module in the container. " - "Use the vulnerable module name from GHSA as primary investigation target." - ) + critical_context = _build_critical_context(workflow_state.cve_intel) + critical_context.append( + "TASK: Investigate usage and reachability of the vulnerable function/module in the container. " + "Use the vulnerable module name from GHSA as primary investigation target." + ) - tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) - runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) - return {"ecosystem": ecosystem, "runtime_prompt": runtime_prompt, "observation": Observation(memory=critical_context, results=[])} + tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) + runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) + return {"ecosystem": ecosystem, "runtime_prompt": runtime_prompt, "observation": Observation(memory=critical_context, results=[])} async def thought_node(state: AgentState) -> AgentState: active_prompt = state.get("runtime_prompt") or default_system_prompt @@ -268,6 +299,10 @@ async def forced_finish_node(state: AgentState) -> AgentState: active_prompt = state.get("runtime_prompt") or default_system_prompt messages = [SystemMessage(content=active_prompt)] + state["messages"] messages.append(HumanMessage(content=FORCED_FINISH_PROMPT)) + obs = state.get("observation", None) + if obs is not None and obs.memory: + memory_context = "\n".join(f"- {m}" for m in obs.memory) + messages.append(SystemMessage(content=f"KNOWLEDGE:\n{memory_context}")) response: Thought = await thought_llm.ainvoke(messages) if response.mode == "finish" and response.final_answer: ai_message = AIMessage(content=response.final_answer) @@ -294,20 +329,33 @@ async def observation_node(state: AgentState) -> AgentState: tool_message = state["messages"][-1] last_thought_text = 
state["thought"].thought if state.get("thought") else "No previous thought." tool_used = state["thought"].actions.tool if state.get("thought") and state["thought"].actions else "Unknown" + tool_input_detail = "" + if state.get("thought") and state["thought"].actions: + actions = state["thought"].actions + if actions.package_name and actions.function_name: + tool_input_detail = f"{actions.package_name},{actions.function_name}" + elif actions.query: + tool_input_detail = actions.query + elif actions.tool_input: + tool_input_detail = actions.tool_input previous_memory = state.get("observation").memory if state.get("observation") else ["No data gathered yet."] prompt = f"""Update the investigation memory based on new tool output. GOAL: {state.get('input')} PREVIOUS MEMORY: {previous_memory} TOOL USED: {tool_used} +TOOL INPUT: {tool_input_detail} THOUGHT: {last_thought_text} NEW OUTPUT: {tool_message.content} RULES: - memory: Append new facts from OUTPUT. Record absence if nothing found. No duplicates. Factual only. +- If NEW OUTPUT is empty, contains an error, or indicates tool failure, record in memory: "FAILED: {tool_used} [{tool_input_detail}] - [reason]". Do NOT infer or fabricate positive findings from failed tool output. +- results from a failed tool call must only state the failure, not speculate about what the tool might have found. - If reachability was NEGATIVE for a package, add to memory: "NOT reachable via [package]. Must check remaining packages." - If reachability was POSITIVE, add: "REACHABLE via [package] - sufficient evidence." +- Always record in memory: "CALLED: [tool] with [input] -> [outcome]" so future steps know what was already tried. - results: 3-5 key technical facts from this OUTPUT only. - Keep only CVE-exploitability-relevant information. 
RESPONSE: @@ -315,9 +363,27 @@ async def observation_node(state: AgentState) -> AgentState: new_observation: Observation = await observation_llm.ainvoke([SystemMessage(content=prompt)]) + messages = state["messages"] + active_prompt = state.get("runtime_prompt") or default_system_prompt + estimated = _estimate_tokens(active_prompt, messages, new_observation) + prune_messages = [] + + if estimated > config.context_window_token_limit and len(messages) > 3: + prunable = messages[1:-2] + for msg in prunable: + prune_messages.append(RemoveMessage(id=msg.id)) + estimated -= _count_tokens(msg.content) if hasattr(msg, "content") and isinstance(msg.content, str) else 0 + if estimated <= config.context_window_token_limit: + break + logger.info( + "Context pruning: removed %d messages, estimated tokens now ~%d (limit %d)", + len(prune_messages), estimated, config.context_window_token_limit, + ) + return { + "messages": prune_messages, "observation": new_observation, - "step": state.get("step", 0), # Keep tracking steps + "step": state.get("step", 0), } async def create_graph(): @@ -384,8 +450,17 @@ async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder, async def _process_steps(agent, steps, semaphore, max_iterations: int = 10): + + async def _process_step(step): + async def call_agent(initial_state,config=None): + if config: + return await agent.ainvoke(initial_state,config=config) + else: + return await agent.ainvoke(initial_state) + initial_state = {"input": step} + config = None if not isinstance(agent, AgentExecutor): initial_state = { "input": step, @@ -396,11 +471,15 @@ async def _process_step(step): "observation": None, "output": "waiting for the agent to respond" } - if semaphore: - async with semaphore: - return await agent.ainvoke(initial_state) - else: - return await agent.ainvoke(initial_state) + config = { + "recursion_limit": 50 + } + with AGENT_TRACER.push_active_function("checklist_question", input_data=step[:80]): + if semaphore: + 
async with semaphore: + return await call_agent(initial_state, config) + else: + return await call_agent(initial_state, config) return await asyncio.gather(*(_process_step(step) for step in steps), return_exceptions=True) diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index f8163394..e7b3194d 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -105,6 +105,7 @@ class AgentState(MessagesState): 3. Function Locator, Function Caller Finder, Call Chain Analyzer: MUST set package_name AND function_name. Do NOT use query. 4. Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search: use query field. 5. If reachability is POSITIVE (reachable), you may finish. If NEGATIVE, you MUST check remaining packages before finishing. +6. Do NOT call the same tool with the same input twice. Check KNOWLEDGE for "CALLED:" entries. If already tried, use a DIFFERENT tool or different input. 
{{"thought": "Search for the vulnerable function in the codebase first", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "PQescapeLiteral", "tool_input": null, "reason": "Check if vulnerable function is present"}}, "final_answer": null}} From e54104a1327d139733bd352a1798e36c84ce1904 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sun, 1 Mar 2026 17:32:12 +0200 Subject: [PATCH 23/60] update parsing --- ci/scripts/analyze_traces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/scripts/analyze_traces.py b/ci/scripts/analyze_traces.py index 3d38624b..9d076de4 100644 --- a/ci/scripts/analyze_traces.py +++ b/ci/scripts/analyze_traces.py @@ -294,7 +294,7 @@ def build_investigation_flows(trace_doc: dict) -> list[InvestigationFlow]: agent_spans = [ s for s in spans - if s.get("function_name") == "cve_agent_executor" + if s.get("function_name") in ("cve_agent_executor", "checklist_question") and s.get("span_kind") in ("LLM", "TOOL") ] From 8cd05e103456c4f0e3ae05671fdbdfcb3a2e8ca8 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sun, 1 Mar 2026 18:41:54 +0200 Subject: [PATCH 24/60] temp hardcore use graph --- src/vuln_analysis/functions/cve_agent.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index ff2eebed..841ac6ed 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -559,14 +559,14 @@ async def _arun(state: AgentMorpheusEngineState) -> AgentMorpheusEngineState: if os.environ.get("DEBUG_SINGLE_QUESTION", "0") == "1": logger.info("DEBUG_SINGLE_QUESTION is set. Limiting to first question per CVE.") checklist_plans = {k: v[:1] for k, v in checklist_plans.items()} - - agent = None - if os.environ.get("ENABLE_GRAPH_AGENT","0") == "1": - logger.info("ENABLE_GRAPH_AGENT is set to 1. 
Executing CVE agent in graph mode.") - agent = await _create_graph_agent(config, builder, state) - else: - logger.info("ENABLE_GRAPH_AGENT is set to 0. Executing CVE agent.") - agent = await _create_agent(config, builder, state) + agent = await _create_graph_agent(config, builder, state) + #agent = None + #if os.environ.get("ENABLE_GRAPH_AGENT","0") == "1": + # logger.info("ENABLE_GRAPH_AGENT is set to 1. Executing CVE agent in graph mode.") + # agent = await _create_graph_agent(config, builder, state) + #else: + # logger.info("ENABLE_GRAPH_AGENT is set to 0. Executing CVE agent.") + # agent = await _create_agent(config, builder, state) results = await asyncio.gather(*(_process_steps(agent, steps, semaphore, config.max_iterations) for steps in checklist_plans.values()), return_exceptions=True) results = _postprocess_results(results, config.replace_exceptions, config.replace_exceptions_value, From 1b418b71d67e7ceeafe75a8e94f8d92a940eae32 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Mon, 2 Mar 2026 08:47:55 +0000 Subject: [PATCH 25/60] Improve Prompt logic and behaviour --- src/vuln_analysis/functions/cve_agent.py | 80 ++++++++----------- .../functions/react_internals.py | 43 ++++++++-- src/vuln_analysis/utils/prompt_factory.py | 9 +-- 3 files changed, 74 insertions(+), 58 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 841ac6ed..dbe2f181 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -226,7 +226,7 @@ def _build_critical_context(cve_intel_list) -> list[str]: async def pre_process_node(state: AgentState) -> AgentState: workflow_state = ctx_state.get() ecosystem = workflow_state.original_input.input.image.ecosystem.value - with AGENT_TRACER.push_active_function("node.pre_process", input_data=ecosystem): + with AGENT_TRACER.push_active_function("pre_process node", input_data=f"ecosystem:{ecosystem}") as span: critical_context = 
_build_critical_context(workflow_state.cve_intel) critical_context.append( @@ -236,6 +236,7 @@ async def pre_process_node(state: AgentState) -> AgentState: tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) + span.set_output({"critical_context": critical_context}) return {"ecosystem": ecosystem, "runtime_prompt": runtime_prompt, "observation": Observation(memory=critical_context, results=[])} async def thought_node(state: AgentState) -> AgentState: @@ -339,52 +340,41 @@ async def observation_node(state: AgentState) -> AgentState: elif actions.tool_input: tool_input_detail = actions.tool_input previous_memory = state.get("observation").memory if state.get("observation") else ["No data gathered yet."] - - prompt = f"""Update the investigation memory based on new tool output. -GOAL: {state.get('input')} -PREVIOUS MEMORY: {previous_memory} -TOOL USED: {tool_used} -TOOL INPUT: {tool_input_detail} -THOUGHT: {last_thought_text} -NEW OUTPUT: -{tool_message.content} - -RULES: -- memory: Append new facts from OUTPUT. Record absence if nothing found. No duplicates. Factual only. -- If NEW OUTPUT is empty, contains an error, or indicates tool failure, record in memory: "FAILED: {tool_used} [{tool_input_detail}] - [reason]". Do NOT infer or fabricate positive findings from failed tool output. -- results from a failed tool call must only state the failure, not speculate about what the tool might have found. -- If reachability was NEGATIVE for a package, add to memory: "NOT reachable via [package]. Must check remaining packages." -- If reachability was POSITIVE, add: "REACHABLE via [package] - sufficient evidence." -- Always record in memory: "CALLED: [tool] with [input] -> [outcome]" so future steps know what was already tried. -- results: 3-5 key technical facts from this OUTPUT only. -- Keep only CVE-exploitability-relevant information. 
-RESPONSE: -{{""" - - new_observation: Observation = await observation_llm.ainvoke([SystemMessage(content=prompt)]) - - messages = state["messages"] - active_prompt = state.get("runtime_prompt") or default_system_prompt - estimated = _estimate_tokens(active_prompt, messages, new_observation) - prune_messages = [] - - if estimated > config.context_window_token_limit and len(messages) > 3: - prunable = messages[1:-2] - for msg in prunable: - prune_messages.append(RemoveMessage(id=msg.id)) - estimated -= _count_tokens(msg.content) if hasattr(msg, "content") and isinstance(msg.content, str) else 0 - if estimated <= config.context_window_token_limit: - break - logger.info( - "Context pruning: removed %d messages, estimated tokens now ~%d (limit %d)", - len(prune_messages), estimated, config.context_window_token_limit, + with AGENT_TRACER.push_active_function("observation node", input_data=f"tool used:{tool_used}") as span: + prompt = OBSERVATION_NODE_PROMPT.format( + goal=state.get('input'), + previous_memory=previous_memory, + tool_used=tool_used, + tool_input_detail=tool_input_detail, + last_thought_text=last_thought_text, + tool_output=tool_message.content, ) - return { - "messages": prune_messages, - "observation": new_observation, - "step": state.get("step", 0), - } + new_observation: Observation = await observation_llm.ainvoke([SystemMessage(content=prompt)]) + + messages = state["messages"] + active_prompt = state.get("runtime_prompt") or default_system_prompt + estimated = _estimate_tokens(active_prompt, messages, new_observation) + prune_messages = [] + orig_estimated = estimated + + if estimated > config.context_window_token_limit and len(messages) > 3: + prunable = messages[1:-2] + for msg in prunable: + prune_messages.append(RemoveMessage(id=msg.id)) + estimated -= _count_tokens(msg.content) if hasattr(msg, "content") and isinstance(msg.content, str) else 0 + if estimated <= config.context_window_token_limit: + break + logger.info( + "Context pruning: 
removed %d messages, estimated tokens now ~%d (limit %d)", + len(prune_messages), estimated, config.context_window_token_limit, + ) + span.set_output({"orig_estimated": orig_estimated, "estimated": estimated}) + return { + "messages": prune_messages, + "observation": new_observation, + "step": state.get("step", 0), + } async def create_graph(): flow = StateGraph(AgentState) diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index e7b3194d..32bda817 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -67,8 +67,9 @@ class AgentState(MessagesState): "MANDATORY STEPS (follow in order, do NOT skip any):\n" "1. IDENTIFY the vulnerable component/function from the CVE description.\n" "2. SEARCH for its presence using Code Keyword Search.\n" - "3. TRACE reachability using Function Locator and/or Call Chain Analyzer. " - "Keyword search alone is NOT sufficient -- you must trace the call chain.\n" + "3. TRACE reachability using Call Chain Analyzer. " + " - Use the Function Locator to verify the package name and find the function name." + " - Keyword search alone is NOT sufficient -- you must trace the call chain.\n" "4. If MULTIPLE packages are listed, repeat steps 2-3 for EACH package.\n" "5. ASSESS: only after completing reachability checks, determine exploitability.\n" "STOPPING RULES:\n" @@ -80,7 +81,10 @@ class AgentState(MessagesState): "- Base conclusions ONLY on tool results, not assumptions.\n" "- If a search returns no results, that is evidence the code is absent.\n" "- Do NOT claim a function is used unless a tool confirmed it.\n" - "- Do NOT set mode='finish' until you have used at least one reachability tool." + "- Code Keyword Search proves code PRESENCE in the container, NOT reachability.\n" + "- Function Locator validates package/function NAMES, NOT reachability. 
It confirms the name exists, not that it is called.\n" + "- Only Call Chain Analyzer can confirm reachability. The application may contain code it never calls.\n" + "- When the question asks whether a function is called or reachable, do NOT conclude based on Code Keyword Search or Function Locator alone -- you MUST use Call Chain Analyzer." ) # Update LANGGRAPH_SYSTEM_PROMPT_TEMPLATE in react_internals.py @@ -104,22 +108,45 @@ class AgentState(MessagesState): 2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer. 3. Function Locator, Function Caller Finder, Call Chain Analyzer: MUST set package_name AND function_name. Do NOT use query. 4. Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search: use query field. -5. If reachability is POSITIVE (reachable), you may finish. If NEGATIVE, you MUST check remaining packages before finishing. +5. If Call Chain Analyzer returns POSITIVE (reachable), you may finish. If NEGATIVE, you MUST check remaining packages before finishing. Code Keyword Search and Function Locator are NOT reachability proof -- only Call Chain Analyzer is. 6. Do NOT call the same tool with the same input twice. Check KNOWLEDGE for "CALLED:" entries. If already tried, use a DIFFERENT tool or different input. {{"thought": "Search for the vulnerable function in the codebase first", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "PQescapeLiteral", "tool_input": null, "reason": "Check if vulnerable function is present"}}, "final_answer": null}} -{{"thought": "Found the function. Now I must check if it is reachable from application code", "mode": "act", "actions": {{"tool": "Function Locator", "package_name": "libpq", "function_name": "PQescapeLiteral", "query": null, "tool_input": null, "reason": "Validate function location before tracing call chain"}}, "final_answer": null}} -""" +{{"thought": "Found the function. 
Now use Function Locator to verify the package name and function", "mode": "act", "actions": {{"tool": "Function Locator", "package_name": "libpq", "function_name": "PQescapeLiteral", "query": null, "tool_input": null, "reason": "Validate package and function name before Call Chain Analyzer"}}, "final_answer": null}} + + +{{"thought": "Function Locator confirmed the package. Now trace reachability with Call Chain Analyzer", "mode": "act", "actions": {{"tool": "Call Chain Analyzer", "package_name": "libpq", "function_name": "PQescapeLiteral", "query": null, "tool_input": null, "reason": "Check if function is reachable from application code"}}, "final_answer": null}} +""" FORCED_FINISH_PROMPT = """Maximum steps reached. You MUST set mode="finish" and provide final_answer NOW. Do NOT call any more tools. Summarize your evidence in 3-5 sentences. RESPONSE: {{""" -OBSERVATION_NODE_PROMPT = """ -""" +OBSERVATION_NODE_PROMPT = """Update the investigation memory based on new tool output. +GOAL: {goal} +PREVIOUS MEMORY: {previous_memory} +TOOL USED: {tool_used} +TOOL INPUT: {tool_input_detail} +THOUGHT: {last_thought_text} +NEW OUTPUT: +{tool_output} + +RULES: +- memory: Append new facts from OUTPUT. Record absence if nothing found. No duplicates. Factual only. +- If NEW OUTPUT is empty, contains an error, or indicates tool failure, record in memory: "FAILED: {tool_used} [{tool_input_detail}] - [reason]". Do NOT infer or fabricate positive findings from failed tool output. +- results from a failed tool call must only state the failure, not speculate about what the tool might have found. +- Reachability tags apply ONLY to Call Chain Analyzer results. Code Keyword Search and Function Locator do NOT determine reachability. +- If Call Chain Analyzer returned NEGATIVE (False), add to memory: "NOT reachable via [package]. Must check remaining packages." +- If Call Chain Analyzer returned POSITIVE (True), add: "REACHABLE via [package] - sufficient evidence." 
+- For Function Locator results, record: "VALIDATED: [package],[function] exists" -- this is NOT reachability. +- Always record in memory: "CALLED: [tool] with [input] -> [outcome]" so future steps know what was already tried. +- results: 3-5 key technical facts from this OUTPUT only. +- Keep only CVE-exploitability-relevant information. +RESPONSE: +{{""" ### --- End of REACT Prompt Templates ----# def build_system_prompt( diff --git a/src/vuln_analysis/utils/prompt_factory.py b/src/vuln_analysis/utils/prompt_factory.py index 7825a844..bd88b59f 100644 --- a/src/vuln_analysis/utils/prompt_factory.py +++ b/src/vuln_analysis/utils/prompt_factory.py @@ -144,11 +144,10 @@ "Use Docs/Code Semantic Search for middleware, API, and npm package usage patterns." ), "c": ( - "Code Keyword Search:Exact text matching for function names, class names, or imports" - "For reachability: Use 'Function Caller Finder' first with library_name,function_name (e.g. openssl,EVP_EncryptInit_ex2), then Call Chain Analyzer with validated names. " - "together to trace function reachability." - "CVE Web Search: External vulnerability information lookup" - ), + "Use Code Keyword Search first for exact function name lookups (e.g. PQescapeLiteral, EVP_EncryptInit_ex2). " + "For reachability: call Function Locator first with library_name,function_name (e.g. libpq,PQescapeLiteral) to validate the package and function, " + "then Call Chain Analyzer with validated names to check if it is reachable from application code." 
+ ), } TOOL_GENERAL_DESCRIPTIONS: dict[str, str] = { From ec96fc35520b308cb623649b0b4e1789e9710674 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Mon, 2 Mar 2026 11:33:01 +0000 Subject: [PATCH 26/60] Second round of prompt improvments --- ci/scripts/collect_traces_local.py | 9 +++++++-- src/vuln_analysis/functions/cve_agent.py | 7 +++++++ src/vuln_analysis/functions/react_internals.py | 4 ++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/ci/scripts/collect_traces_local.py b/ci/scripts/collect_traces_local.py index c4036a43..5f29b49d 100644 --- a/ci/scripts/collect_traces_local.py +++ b/ci/scripts/collect_traces_local.py @@ -201,8 +201,13 @@ def process_message(http_provider, tempo_provider, traces_table, flat_spans_tabl ) span_name = metadata.get('name') - if span_name == "": - print("finish scan workflow - sending to endpoint + writing trace file") + func_name = _otlp_str(attrs.get("nat.function.name")) + if span_name == "" or func_name == "agent_finish": + if trace_id not in traces_table: + print(f"Trace {trace_id} already flushed, skipping duplicate trigger.") + continue + trigger = span_name if span_name == "" else f"agent_finish (func={func_name})" + print(f"Trace write triggered by: {trigger}") traces = traces_table.get(trace_id, []) traces_container = TracesContainer(traces) #status = http_provider.send_post(traces_container.model_dump_json()) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index dbe2f181..1e0d198b 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -562,6 +562,13 @@ async def _arun(state: AgentMorpheusEngineState) -> AgentMorpheusEngineState: results = _postprocess_results(results, config.replace_exceptions, config.replace_exceptions_value, list(checklist_plans.values())) state.checklist_results = dict(zip(checklist_plans.keys(), results)) + + with AGENT_TRACER.push_active_function("agent_finish", input_data={ + 
"cve_count": len(checklist_plans), + "scan_id": state.original_input.input.scan.id, + }): + pass + return state yield FunctionInfo.from_fn( diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 32bda817..2377f3ea 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -108,7 +108,7 @@ class AgentState(MessagesState): 2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer. 3. Function Locator, Function Caller Finder, Call Chain Analyzer: MUST set package_name AND function_name. Do NOT use query. 4. Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search: use query field. -5. If Call Chain Analyzer returns POSITIVE (reachable), you may finish. If NEGATIVE, you MUST check remaining packages before finishing. Code Keyword Search and Function Locator are NOT reachability proof -- only Call Chain Analyzer is. +5. If Call Chain Analyzer returns POSITIVE (reachable), you may finish. If NEGATIVE, check KNOWLEDGE for "INVESTIGATE EACH package:" list -- you MUST try the next unchecked package before finishing. Code Keyword Search and Function Locator are NOT reachability proof -- only Call Chain Analyzer is. 6. Do NOT call the same tool with the same input twice. Check KNOWLEDGE for "CALLED:" entries. If already tried, use a DIFFERENT tool or different input. @@ -139,7 +139,7 @@ class AgentState(MessagesState): - If NEW OUTPUT is empty, contains an error, or indicates tool failure, record in memory: "FAILED: {tool_used} [{tool_input_detail}] - [reason]". Do NOT infer or fabricate positive findings from failed tool output. - results from a failed tool call must only state the failure, not speculate about what the tool might have found. - Reachability tags apply ONLY to Call Chain Analyzer results. Code Keyword Search and Function Locator do NOT determine reachability. 
-- If Call Chain Analyzer returned NEGATIVE (False), add to memory: "NOT reachable via [package]. Must check remaining packages." +- If Call Chain Analyzer returned NEGATIVE (False), add to memory: "NOT reachable via [package]. Check INVESTIGATE list for next unchecked package." - If Call Chain Analyzer returned POSITIVE (True), add: "REACHABLE via [package] - sufficient evidence." - For Function Locator results, record: "VALIDATED: [package],[function] exists" -- this is NOT reachability. - Always record in memory: "CALLED: [tool] with [input] -> [outcome]" so future steps know what was already tried. From 0cc0f6073baa2bc4cb0f4921ea04a6542a694228 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Mon, 2 Mar 2026 18:14:39 +0200 Subject: [PATCH 27/60] Java scripts improvments --- ci/scripts/collect_traces_local.py | 73 +++++++++++++------ .../functions/react_internals.py | 6 +- .../utils/function_name_locator.py | 69 ++++++++++-------- 3 files changed, 94 insertions(+), 54 deletions(-) diff --git a/ci/scripts/collect_traces_local.py b/ci/scripts/collect_traces_local.py index 5f29b49d..f299c273 100644 --- a/ci/scripts/collect_traces_local.py +++ b/ci/scripts/collect_traces_local.py @@ -29,6 +29,7 @@ KAFKA_TOPIC = 'otel-traces' KAFKA_BROKER = 'localhost:9092' TRACES_OUTPUT_DIR = 'traces_output' +FLUSH_GRACE_PERIOD_SEC = 15 class LocalDateTime(RootModel[datetime]): @@ -144,7 +145,7 @@ def write_trace_file(trace_id: str, job_id: str, flat_spans: list[dict], output_ def process_message(http_provider, tempo_provider, traces_table, flat_spans_table, - map_trace_id_to_job_id, data, output_dir): + map_trace_id_to_job_id, pending_flushes, data, output_dir): # Forward original trace to Tempo try: for rs in data.get('resourceSpans', []): @@ -206,22 +207,41 @@ def process_message(http_provider, tempo_provider, traces_table, flat_spans_tabl if trace_id not in traces_table: print(f"Trace {trace_id} already flushed, skipping duplicate trigger.") continue - trigger = span_name if 
span_name == "" else f"agent_finish (func={func_name})" - print(f"Trace write triggered by: {trigger}") - traces = traces_table.get(trace_id, []) - traces_container = TracesContainer(traces) - #status = http_provider.send_post(traces_container.model_dump_json()) - status = 200 - print(f"send traces(job_id: {job_id}, trace_id: {trace_id}) to endpoint status: {status}") - - write_trace_file( - trace_id, job_id, - flat_spans_table.get(trace_id, []), - output_dir, - ) - - del traces_table[trace_id] - del flat_spans_table[trace_id] + if trace_id not in pending_flushes: + trigger = span_name if span_name == "" else f"agent_finish (func={func_name})" + pending_flushes[trace_id] = time.time() + print(f"Flush scheduled for trace {trace_id} (trigger: {trigger}, " + f"grace period: {FLUSH_GRACE_PERIOD_SEC}s, " + f"spans so far: {len(flat_spans_table.get(trace_id, []))})") + + +def flush_pending_traces(pending_flushes, traces_table, flat_spans_table, + map_trace_id_to_job_id, output_dir, force=False): + """Flush traces whose grace period has expired (or all if force=True).""" + now = time.time() + flushed = [] + for trace_id, trigger_time in list(pending_flushes.items()): + if not force and (now - trigger_time) < FLUSH_GRACE_PERIOD_SEC: + continue + if trace_id not in traces_table: + flushed.append(trace_id) + continue + job_id = map_trace_id_to_job_id.get(trace_id, "unknown") + waited = now - trigger_time + span_count = len(flat_spans_table.get(trace_id, [])) + print(f"Flushing trace {trace_id} (waited {waited:.1f}s, {span_count} spans)") + + write_trace_file( + trace_id, job_id, + flat_spans_table.get(trace_id, []), + output_dir, + ) + + del traces_table[trace_id] + del flat_spans_table[trace_id] + flushed.append(trace_id) + for trace_id in flushed: + del pending_flushes[trace_id] def run_worker( @@ -258,11 +278,12 @@ def signal_handler(sig, frame): print("Kafka not ready yet... retrying in 2 seconds.") time.sleep(2) - print(f"Worker Active. 
Writing per-trace files to: {output_dir}/") + print(f"Worker Active. Writing per-trace files to: {output_dir}/ (grace period: {FLUSH_GRACE_PERIOD_SEC}s)") os.makedirs(output_dir, exist_ok=True) traces_table = {} flat_spans_table = {} map_trace_id_to_job_id = {} + pending_flushes = {} endpoint_url = url + "/api/v1/traces" EndpointAuthType = AuthType.BEARER @@ -288,12 +309,22 @@ def signal_handler(sig, frame): process_message( http_provider, tempo_provider, traces_table, flat_spans_table, - map_trace_id_to_job_id, data, - output_dir, + map_trace_id_to_job_id, pending_flushes, + data, output_dir, ) + flush_pending_traces( + pending_flushes, traces_table, flat_spans_table, + map_trace_id_to_job_id, output_dir, + ) finally: + if pending_flushes: + print(f"\nFlushing {len(pending_flushes)} pending traces on shutdown...") + flush_pending_traces( + pending_flushes, traces_table, flat_spans_table, + map_trace_id_to_job_id, output_dir, force=True, + ) consumer.close() - print("\nStopping worker...") + print("Stopping worker...") def parse_args(): diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 2377f3ea..6f712362 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -112,13 +112,13 @@ class AgentState(MessagesState): 6. Do NOT call the same tool with the same input twice. Check KNOWLEDGE for "CALLED:" entries. If already tried, use a DIFFERENT tool or different input. 
-{{"thought": "Search for the vulnerable function in the codebase first", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "PQescapeLiteral", "tool_input": null, "reason": "Check if vulnerable function is present"}}, "final_answer": null}} +{{"thought": "Search for the vulnerable function in the codebase first", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "", "tool_input": null, "reason": "Check if vulnerable function is present"}}, "final_answer": null}} -{{"thought": "Found the function. Now use Function Locator to verify the package name and function", "mode": "act", "actions": {{"tool": "Function Locator", "package_name": "libpq", "function_name": "PQescapeLiteral", "query": null, "tool_input": null, "reason": "Validate package and function name before Call Chain Analyzer"}}, "final_answer": null}} +{{"thought": "Found the function. Now use Function Locator to verify the package name and function", "mode": "act", "actions": {{"tool": "Function Locator", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Validate package and function name before Call Chain Analyzer"}}, "final_answer": null}} -{{"thought": "Function Locator confirmed the package. Now trace reachability with Call Chain Analyzer", "mode": "act", "actions": {{"tool": "Call Chain Analyzer", "package_name": "libpq", "function_name": "PQescapeLiteral", "query": null, "tool_input": null, "reason": "Check if function is reachable from application code"}}, "final_answer": null}} +{{"thought": "Function Locator confirmed the package. 
Now trace reachability with Call Chain Analyzer", "mode": "act", "actions": {{"tool": "Call Chain Analyzer", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Check if function is reachable from application code"}}, "final_answer": null}} """ FORCED_FINISH_PROMPT = """Maximum steps reached. You MUST set mode="finish" and provide final_answer NOW. Do NOT call any more tools. Summarize your evidence in 3-5 sentences. diff --git a/src/vuln_analysis/utils/function_name_locator.py b/src/vuln_analysis/utils/function_name_locator.py index 13f3adbc..8ffa670b 100644 --- a/src/vuln_analysis/utils/function_name_locator.py +++ b/src/vuln_analysis/utils/function_name_locator.py @@ -138,11 +138,15 @@ def search_in_third_party_packages(self, package: str) -> bool: Returns: True if package is found in third party packages, False otherwise """ - + def is_same_package(package_name_from_input, package_name_from_tree): + return package_name_from_input.lower() == package_name_from_tree.lower() # Check in supported_packages list (handles None case) long_path = False - if self.coc_retriever.supported_packages is not None: - long_path = package in self.coc_retriever.supported_packages + + for supported_package in self.coc_retriever.supported_packages: + if is_same_package(package, supported_package): + long_path = True + # Check in short_go_package_name dict (keys) short_path = package in (self.short_go_package_name or {}) @@ -199,7 +203,7 @@ async def locate_functions(self, query: str) -> list[str]: ) ] else: - is_standard_lib_api = await quick_standard_lib_check(package, self.coc_retriever.ecosystem) + is_standard_lib_api,is_error = await quick_standard_lib_check(package, self.coc_retriever.ecosystem) if is_standard_lib_api: self.is_std_package = True self.is_package_valid = True @@ -219,12 +223,15 @@ async def locate_functions(self, query: str) -> list[str]: f"{self.coc_retriever.ecosystem.name if self.coc_retriever.ecosystem else 'Unknown'}" ) 
else: - error_msg = ( - f"ERROR: Package '{package}' not found in available packages. " - f"No close matches found. " - f"Available ecosystem: " - f"{self.coc_retriever.ecosystem.name if self.coc_retriever.ecosystem else 'Unknown'}" - ) + if not is_error: + error_msg = ( + f"ERROR: Package '{package}' not found in available packages. " + f"No close matches found. " + f"Available ecosystem: " + f"{self.coc_retriever.ecosystem.name if self.coc_retriever.ecosystem else 'Unknown'}" + ) + else: + error_msg = ("UNKNOWN could not determine if package is standard library") logger.error(error_msg) return [error_msg] # Return error message that LLM can see # Fallback: on cache miss, sparsely verify via API @@ -281,7 +288,7 @@ async def locate_functions(self, query: str) -> list[str]: logger.error("Error locating functions in package '%s': %s", package, e) return [] -async def quick_standard_lib_check(package_name: str, ecosystem: Ecosystem) -> bool: +async def quick_standard_lib_check(package_name: str, ecosystem: Ecosystem) -> tuple[bool, bool]: """Quick check if package is standard library Args: package_name: The package name to check @@ -289,23 +296,25 @@ async def quick_standard_lib_check(package_name: str, ecosystem: Ecosystem) -> b Returns: True if package is standard library, False otherwise """ + try: + search = MorpheusSerpAPIWrapper(max_retries=2) + result = await search.arun(f"Is '{package_name}' part of the {ecosystem.value} standard library?") + logger.info("quick_standard_lib_check Standard library check result: %s", result) + # Normalize result: if list, join into single string + if isinstance(result, list): + text = " ".join(result).lower() + elif isinstance(result, str): + text = result.lower() + else: + text = str(result).lower() - search = MorpheusSerpAPIWrapper(max_retries=2) - result = await search.arun(f"Is '{package_name}' part of the {ecosystem.value} standard library?") - logger.info("quick_standard_lib_check Standard library check result: %s", 
result) -# Normalize result: if list, join into single string - if isinstance(result, list): - text = " ".join(result).lower() - elif isinstance(result, str): - text = result.lower() - else: - text = str(result).lower() - - # Basic positive signals and avoidance of negative phrasing - if "part of the standard library" in text or \ - "is a standard library" in text or \ - ("standard library" in text and "not" not in text): - return True - - return False - + # Basic positive signals and avoidance of negative phrasing + if "part of the standard library" in text or \ + "is a standard library" in text or \ + ("standard library" in text and "not" not in text): + return True, False + + return False, False + except Exception as e: + logger.error("Error checking if package is standard library: %s", e) + return False, True From fc5983ad1be4448527a9edade2307086ba7692d0 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Tue, 3 Mar 2026 11:33:24 +0000 Subject: [PATCH 28/60] javascript improvments round2 --- .tekton/on-pull-request.yaml | 1 + .../javascript_functions_parser.py | 15 +++++++++++++++ .../utils/javascript_extended_segmenter.py | 19 +++++++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/.tekton/on-pull-request.yaml b/.tekton/on-pull-request.yaml index 07bc25c8..b92e5b74 100644 --- a/.tekton/on-pull-request.yaml +++ b/.tekton/on-pull-request.yaml @@ -243,6 +243,7 @@ spec: #clean the java cache #rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* + rm -rf ./.cache/am_cache/pickle/https.github.com.TamarW0.node-example-project-* print_banner "RUNNING UNIT TESTS" make test-unit PYTEST_OPTS="--log-cli-level=DEBUG" diff --git a/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py b/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py index a9e3252d..7529f66f 100644 --- a/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py +++ 
b/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py @@ -49,6 +49,21 @@ def get_function_name(self, function: Document) -> str: if match: return match.group(1) + # Try to match wrapper calls: var name = wrapper(function(...) { or var name = wrapper(() => + match = re.search(r'(?:const|let|var)\s+(\w+)\s*=\s*\w+\s*\(', content) + if match: + return match.group(1) + + # Try to match property assignment functions: obj.name = function(...) { + match = re.search(r'(?:[\w.]+)\.(\w+)\s*=\s*(?:async\s+)?function\s*\(', content) + if match: + return match.group(1) + + # Try to match property assignment arrow functions: obj.name = (...) => + match = re.search(r'(?:[\w.]+)\.(\w+)\s*=\s*(?:async\s*)?\([^)]*\)\s*=>', content) + if match: + return match.group(1) + # Try to match computed property methods: [Symbol.iterator]() { or [expr]() { # Return the computed expression inside brackets match = re.search(r'^\s*\[([^\]]+)\]\s*\([^)]*\)\s*\{', content, re.MULTILINE) diff --git a/src/exploit_iq_commons/utils/javascript_extended_segmenter.py b/src/exploit_iq_commons/utils/javascript_extended_segmenter.py index eb4fc9ab..d5168a36 100644 --- a/src/exploit_iq_commons/utils/javascript_extended_segmenter.py +++ b/src/exploit_iq_commons/utils/javascript_extended_segmenter.py @@ -106,6 +106,16 @@ def extract_functions_classes(self) -> List[str]: esprima.nodes.ClassDeclaration)): functions_classes.append(self._extract_code(node)) + # Handle property assignments with function expressions, + # e.g. hb.compile = function(input, options) { ... 
} + elif isinstance(node, esprima.nodes.ExpressionStatement): + if isinstance(node.expression, esprima.nodes.AssignmentExpression): + if isinstance(node.expression.right, ( + esprima.nodes.FunctionExpression, + esprima.nodes.ArrowFunctionExpression + )): + functions_classes.append(self._extract_code(node)) + # Extract individual class methods class_methods = self._extract_all_class_methods() functions_classes.extend(class_methods) @@ -178,6 +188,15 @@ def _extract_arrow_functions(self, var_node: esprima.nodes.VariableDeclaration) esprima.nodes.FunctionExpression )) + # Check if init is a CallExpression wrapping a function, + # e.g. var defaultsDeep = baseRest(function(args) { ... }) + if not is_function_expr and isinstance(declarator.init, esprima.nodes.CallExpression): + is_function_expr = any( + isinstance(arg, (esprima.nodes.FunctionExpression, + esprima.nodes.ArrowFunctionExpression)) + for arg in declarator.init.arguments + ) + if is_function_expr: # Extract the entire variable declaration including the assignment code = self._extract_code(var_node) From b3c99bae56857feb1b54deaf82eba4e90a72c601 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 4 Mar 2026 09:19:53 +0200 Subject: [PATCH 29/60] before cleanup and rollback --- src/vuln_analysis/functions/cve_agent.py | 24 ++++++++++++++++--- .../functions/react_internals.py | 8 +++++++ .../utils/function_name_locator.py | 1 + 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 1e0d198b..2e9fed9a 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -36,7 +36,7 @@ from vuln_analysis.utils.prompting import get_agent_prompt from exploit_iq_commons.logging.loggers_factory import LoggingFactory, trace_id -from vuln_analysis.functions.react_internals import build_system_prompt, AgentState, Thought,Observation, _build_tool_arguments, FORCED_FINISH_PROMPT, 
OBSERVATION_NODE_PROMPT +from vuln_analysis.functions.react_internals import build_system_prompt, AgentState, Thought, Observation, Classification, _build_tool_arguments, FORCED_FINISH_PROMPT, OBSERVATION_NODE_PROMPT from vuln_analysis.utils.prompting import build_tool_descriptions from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_ECOSYSTEM_REGISTRY, FEW_SHOT_EXAMPLES from langgraph.graph import StateGraph, END, START @@ -113,6 +113,7 @@ async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Build llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) thought_llm = llm.with_structured_output(Thought) observation_llm = llm.with_structured_output(Observation) + reachability_llm = llm.with_structured_output(Classification) tool_guidance = "\n".join(tool_guidance_list) descriptions = "\n".join(tool_descriptions_list) default_system_prompt = build_system_prompt(descriptions, tool_guidance) @@ -234,11 +235,27 @@ async def pre_process_node(state: AgentState) -> AgentState: "Use the vulnerable module name from GHSA as primary investigation target." ) + question = state.get("input") or "" + context_block = "\n".join(critical_context) + classification_prompt = ( + "You are classifying a CVE investigation question.\n\n" + "Context (CVE / vulnerable packages):\n" + f"{context_block}\n\n" + f"Question: {question}\n\n" + "A reachability question asks whether the vulnerable code/symbol is called or reachable in the codebase, " + "or whether untrusted data can reach it. Is this a reachability question? Answer only yes or no." 
+ ) + classification_result: Classification = await reachability_llm.ainvoke([HumanMessage(content=classification_prompt)]) + span.set_output({ + "critical_context": critical_context, + "reachability_question": classification_result.is_reachability, + }) + tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) - span.set_output({"critical_context": critical_context}) return {"ecosystem": ecosystem, "runtime_prompt": runtime_prompt, "observation": Observation(memory=critical_context, results=[])} + async def thought_node(state: AgentState) -> AgentState: active_prompt = state.get("runtime_prompt") or default_system_prompt messages = [SystemMessage(content=active_prompt)] + state["messages"] @@ -393,9 +410,10 @@ async def create_graph(): flow.add_edge(TOOL_NODE, OBSERVATION_NODE) flow.add_edge(OBSERVATION_NODE, THOUGHT_NODE) flow.add_edge(FORCED_FINISH_NODE, END) + app = flow.compile() if config.verbose: - app.get_graph().draw_mermaid_png(output_file_path="flow.png") + app.get_graph().draw_mermaid_png(output_file_path="flow.png") return app return await create_graph() async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder, diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 6f712362..c07cc5c0 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -49,6 +49,14 @@ class Observation(BaseModel): description="A list of cumulative factual findings. Each item is a single, discrete technical fact." ) + +class Classification(BaseModel): + """Structured output for reachability-question classification.""" + is_reachability: Literal["yes", "no"] = Field( + description="Answer 'yes' if the question is a reachability question, 'no' otherwise." 
+ ) + + class AgentState(MessagesState): input: str = "" step: int = 0 diff --git a/src/vuln_analysis/utils/function_name_locator.py b/src/vuln_analysis/utils/function_name_locator.py index 8ffa670b..55e866ab 100644 --- a/src/vuln_analysis/utils/function_name_locator.py +++ b/src/vuln_analysis/utils/function_name_locator.py @@ -146,6 +146,7 @@ def is_same_package(package_name_from_input, package_name_from_tree): for supported_package in self.coc_retriever.supported_packages: if is_same_package(package, supported_package): long_path = True + break # Check in short_go_package_name dict (keys) From 352a60087dcf3deb8d2719bd1bcbd5f545f20f4a Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 4 Mar 2026 09:24:38 +0200 Subject: [PATCH 30/60] clean: move classification query prompt to internals --- src/vuln_analysis/functions/cve_agent.py | 11 ++--------- src/vuln_analysis/functions/react_internals.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 2e9fed9a..9d441692 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -36,7 +36,7 @@ from vuln_analysis.utils.prompting import get_agent_prompt from exploit_iq_commons.logging.loggers_factory import LoggingFactory, trace_id -from vuln_analysis.functions.react_internals import build_system_prompt, AgentState, Thought, Observation, Classification, _build_tool_arguments, FORCED_FINISH_PROMPT, OBSERVATION_NODE_PROMPT +from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, AgentState, Thought, Observation, Classification, _build_tool_arguments, FORCED_FINISH_PROMPT, OBSERVATION_NODE_PROMPT from vuln_analysis.utils.prompting import build_tool_descriptions from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_ECOSYSTEM_REGISTRY, FEW_SHOT_EXAMPLES from langgraph.graph import StateGraph, 
END, START @@ -237,14 +237,7 @@ async def pre_process_node(state: AgentState) -> AgentState: question = state.get("input") or "" context_block = "\n".join(critical_context) - classification_prompt = ( - "You are classifying a CVE investigation question.\n\n" - "Context (CVE / vulnerable packages):\n" - f"{context_block}\n\n" - f"Question: {question}\n\n" - "A reachability question asks whether the vulnerable code/symbol is called or reachable in the codebase, " - "or whether untrusted data can reach it. Is this a reachability question? Answer only yes or no." - ) + classification_prompt = build_classification_prompt(context_block, question) classification_result: Classification = await reachability_llm.ainvoke([HumanMessage(content=classification_prompt)]) span.set_output({ "critical_context": critical_context, diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index c07cc5c0..ef5ed1e9 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -128,6 +128,15 @@ class AgentState(MessagesState): {{"thought": "Function Locator confirmed the package. Now trace reachability with Call Chain Analyzer", "mode": "act", "actions": {{"tool": "Call Chain Analyzer", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Check if function is reachable from application code"}}, "final_answer": null}} """ +CLASSIFICATION_PROMPT_TEMPLATE = """You are classifying a CVE investigation question. + +Context (CVE / vulnerable packages): +{context_block} + +Question: {question} + +A reachability question asks whether the vulnerable code/symbol is called or reachable in the codebase, or whether untrusted data can reach it. Is this a reachability question? Answer only yes or no.""" + FORCED_FINISH_PROMPT = """Maximum steps reached. You MUST set mode="finish" and provide final_answer NOW. Do NOT call any more tools. 
Summarize your evidence in 3-5 sentences. RESPONSE: @@ -171,6 +180,14 @@ def build_system_prompt( tool_selection_strategy=tool_guidance, ) + +def build_classification_prompt(context_block: str, question: str) -> str: + """Build the reachability-question classification prompt from context and user question.""" + return CLASSIFICATION_PROMPT_TEMPLATE.format( + context_block=context_block, + question=question, + ) + def _build_tool_arguments(actions: ToolCall)->dict[str, Any]: pkg_tools = {"Function Locator", "Function Caller Finder", "Call Chain Analyzer"} if actions.tool in pkg_tools and actions.package_name and actions.function_name: From b8b0e36e243ef30043b89ff0188f14257d417006 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 4 Mar 2026 16:01:05 +0000 Subject: [PATCH 31/60] Round 1 cleanup Inteli don't support multi package --- .../utils/chain_of_calls_retriever.py | 6 +- src/vuln_analysis/functions/cve_agent.py | 126 ++++++++++++++++-- .../functions/react_internals.py | 104 +++++++++++++-- 3 files changed, 217 insertions(+), 19 deletions(-) diff --git a/src/exploit_iq_commons/utils/chain_of_calls_retriever.py b/src/exploit_iq_commons/utils/chain_of_calls_retriever.py index 277ac439..f14935a2 100644 --- a/src/exploit_iq_commons/utils/chain_of_calls_retriever.py +++ b/src/exploit_iq_commons/utils/chain_of_calls_retriever.py @@ -571,7 +571,11 @@ def __find_initial_function(self, function_name: str, package_name: str, documen package_exclusions = self.tree_dict.get(package_name)[EXCLUSIONS_INDEX] #for index, document in enumerate(get_functions_for_package(package_name, relevant_docs, language_parser)): - for document in self.get_functions_for_package(package_name, relevant_docs): + from itertools import chain + for document in chain( + self.get_functions_for_package(package_name, relevant_docs, sources_location_packages=True), + self.get_functions_for_package(package_name, relevant_docs, sources_location_packages=False), + ): # 
document_function_calls_input_function = True if function_name.lower() == self.language_parser.get_function_name(document).lower(): # if language_parser.search_for_called_function(document, callee_function=function_name): diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 9d441692..258762bc 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -36,12 +36,20 @@ from vuln_analysis.utils.prompting import get_agent_prompt from exploit_iq_commons.logging.loggers_factory import LoggingFactory, trace_id -from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, AgentState, Thought, Observation, Classification, _build_tool_arguments, FORCED_FINISH_PROMPT, OBSERVATION_NODE_PROMPT +from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, build_package_filter_prompt, AgentState, Thought, Observation, Classification, PackageSelection, _build_tool_arguments, FORCED_FINISH_PROMPT, OBSERVATION_NODE_PROMPT from vuln_analysis.utils.prompting import build_tool_descriptions from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_ECOSYSTEM_REGISTRY, FEW_SHOT_EXAMPLES from langgraph.graph import StateGraph, END, START from langgraph.prebuilt import ToolNode from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, RemoveMessage + +ECOSYSTEM_DEP_DIRS = { + "c": "rpm_libs", + "go": "vendor", + "javascript": "node_modules", + "python": "transitive_env", + "java": "dependencies-sources", +} import uuid import tiktoken from nat.builder.context import Context @@ -114,6 +122,7 @@ async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Build thought_llm = llm.with_structured_output(Thought) observation_llm = llm.with_structured_output(Observation) reachability_llm = llm.with_structured_output(Classification) + package_filter_llm = 
llm.with_structured_output(PackageSelection) tool_guidance = "\n".join(tool_guidance_list) descriptions = "\n".join(tool_descriptions_list) default_system_prompt = build_system_prompt(descriptions, tool_guidance) @@ -168,9 +177,16 @@ def _build_tool_guidance_for_ecosystem(ecosystem: str, available_tools: list) -> descriptions_local = "\n".join(list_of_tool_descriptions) return tool_guidance_local, descriptions_local - def _build_critical_context(cve_intel_list) -> list[str]: - """Extract key facts from all available intel sources into a compact context.""" + def _build_critical_context(cve_intel_list) -> tuple[list[str], list[dict]]: + """Extract key facts from all available intel sources into a compact context. + + Returns (critical_context, candidate_packages) where candidate_packages + contains dicts with 'name', 'source', and optional 'ecosystem' keys. + """ critical_context = [] + candidate_packages: list[dict] = [] + seen_packages: set[str] = set() + for cve_intel in cve_intel_list: if cve_intel.nvd is not None: if cve_intel.nvd.cve_description: @@ -187,14 +203,23 @@ def _build_critical_context(cve_intel_list) -> list[str]: if vf: critical_context.append(f"Vulnerable functions (GHSA): {', '.join(vf)}") + short_names = [f.rsplit('.', 1)[-1] for f in vf if '.' 
in f] + if short_names: + critical_context.append(f"Search keywords: {', '.join(short_names)}") if pkg: if isinstance(pkg, dict): pkg_name = pkg.get("name", "") pkg_eco = pkg.get("ecosystem", "") if pkg_name: critical_context.append(f"Vulnerable module ({pkg_eco}): {pkg_name}") + if pkg_name not in seen_packages: + seen_packages.add(pkg_name) + candidate_packages.append({"name": pkg_name, "source": "ghsa", "ecosystem": pkg_eco}) elif isinstance(pkg, str): critical_context.append(f"Affected package: {pkg}") + if pkg not in seen_packages: + seen_packages.add(pkg) + candidate_packages.append({"name": pkg, "source": "ghsa"}) if cve_intel.ghsa.description and not any("CVE Description" in c for c in critical_context): critical_context.append(f"CVE Description: {cve_intel.ghsa.description[:400]}") @@ -203,6 +228,10 @@ def _build_critical_context(cve_intel_list) -> list[str]: critical_context.append(f"RHSA Statement: {cve_intel.rhsa.statement[:300]}") if cve_intel.rhsa.package_state: pkgs = list(set(p.package_name for p in cve_intel.rhsa.package_state if p.package_name)) + for p in pkgs: + if p not in seen_packages: + seen_packages.add(p) + candidate_packages.append({"name": p, "source": "rhsa"}) if len(pkgs) > 5: critical_context.append( f"Affected across {len(pkgs)} Red Hat products (sample: {', '.join(pkgs[:5])}). " @@ -222,14 +251,82 @@ def _build_critical_context(cve_intel_list) -> list[str]: if not critical_context: critical_context = ["No CVE intel available. 
Investigate using tools."] - return critical_context + return critical_context, candidate_packages + + def _filter_context_to_package(critical_context: list[str], selected: str, all_candidates: list[dict]) -> list[str]: + """Remove context entries and rejected package name references for non-selected packages.""" + rejected_names = {c["name"] for c in all_candidates if c["name"] != selected} + filtered = [] + for entry in critical_context: + if entry.startswith("INVESTIGATE EACH package:"): + filtered.append(f"Target package: {selected}") + continue + if entry.startswith("Vulnerable module (") or entry.startswith("Affected package:"): + if any(rn in entry for rn in rejected_names): + continue + for rn in rejected_names: + entry = entry.replace(f" {rn} ", " ") + entry = entry.replace(f" {rn},", ",") + filtered.append(entry) + return filtered + + def _group_code_search_results(tool_output: str, ecosystem: str, app_package: str) -> str: + """Split Code Keyword Search results into main-application vs dependency groups.""" + dep_dir = ECOSYSTEM_DEP_DIRS.get(ecosystem) + if not dep_dir: + return tool_output + try: + parsed = json.loads(tool_output) + if not isinstance(parsed, list): + return tool_output + except (json.JSONDecodeError, TypeError): + return tool_output + + main_app_results = [] + dep_results = [] + dep_prefix = dep_dir + "/" + for item in parsed: + source = item.get("source", "") if isinstance(item, dict) else "" + if source.startswith(dep_prefix): + dep_results.append(item) + else: + main_app_results.append(item) + + parts = [] + if main_app_results: + parts.append(f"Main application - {app_package}\n{json.dumps(main_app_results)}") + if dep_results: + parts.append(f"Application library dependencies\n{json.dumps(dep_results)}") + if not parts: + return tool_output + return "\n\n".join(parts) async def pre_process_node(state: AgentState) -> AgentState: workflow_state = ctx_state.get() ecosystem = workflow_state.original_input.input.image.ecosystem.value 
with AGENT_TRACER.push_active_function("pre_process node", input_data=f"ecosystem:{ecosystem}") as span: - critical_context = _build_critical_context(workflow_state.cve_intel) + critical_context, candidate_packages = _build_critical_context(workflow_state.cve_intel) + + selected_package = None + app_package = None + if len(candidate_packages) > 1: + image_input = workflow_state.original_input.input.image + image_name = image_input.name + source_repos = image_input.source_info + image_repo = source_repos[0].git_repo if source_repos else None + filter_prompt = build_package_filter_prompt( + ecosystem, candidate_packages, + image_name=image_name, image_repo=image_repo, + critical_context=critical_context, + ) + selection: PackageSelection = await package_filter_llm.ainvoke([HumanMessage(content=filter_prompt)]) + selected_package = selection.selected_package + app_package = selected_package + logger.info("Package filter selected '%s' from %d candidates (reason: %s)", + selected_package, len(candidate_packages), selection.reason) + critical_context = _filter_context_to_package(critical_context, selected_package, candidate_packages) + critical_context.append( "TASK: Investigate usage and reachability of the vulnerable function/module in the container. " "Use the vulnerable module name from GHSA as primary investigation target." 
@@ -241,12 +338,20 @@ async def pre_process_node(state: AgentState) -> AgentState: classification_result: Classification = await reachability_llm.ainvoke([HumanMessage(content=classification_prompt)]) span.set_output({ "critical_context": critical_context, + "candidate_packages": candidate_packages, + "selected_package": selected_package, + "app_package": app_package if selected_package else None, "reachability_question": classification_result.is_reachability, }) tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) - return {"ecosystem": ecosystem, "runtime_prompt": runtime_prompt, "observation": Observation(memory=critical_context, results=[])} + return { + "ecosystem": ecosystem, + "runtime_prompt": runtime_prompt, + "observation": Observation(memory=critical_context, results=[]), + "app_package": app_package if selected_package else None, + } async def thought_node(state: AgentState) -> AgentState: @@ -294,7 +399,6 @@ async def thought_node(state: AgentState) -> AgentState: "thought": response, "step": state.get("step", 0) + 1, "max_steps": config.max_iterations, - "observation": None, "output": final_answer } @@ -351,13 +455,19 @@ async def observation_node(state: AgentState) -> AgentState: tool_input_detail = actions.tool_input previous_memory = state.get("observation").memory if state.get("observation") else ["No data gathered yet."] with AGENT_TRACER.push_active_function("observation node", input_data=f"tool used:{tool_used}") as span: + tool_output_for_llm = tool_message.content + if tool_used == "Code Keyword Search" and state.get("app_package") and state.get("ecosystem"): + tool_output_for_llm = _group_code_search_results( + tool_message.content, state["ecosystem"], state["app_package"], + ) prompt = OBSERVATION_NODE_PROMPT.format( goal=state.get('input'), + selected_package=state.get('app_package') or "N/A", 
previous_memory=previous_memory, tool_used=tool_used, tool_input_detail=tool_input_detail, last_thought_text=last_thought_text, - tool_output=tool_message.content, + tool_output=tool_output_for_llm, ) new_observation: Observation = await observation_llm.ainvoke([SystemMessage(content=prompt)]) diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index ef5ed1e9..108d859b 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -57,6 +57,16 @@ class Classification(BaseModel): ) +class PackageSelection(BaseModel): + """Structured output for selecting the most relevant package from multiple candidates.""" + selected_package: str = Field( + description="The exact name of the single most relevant package to investigate, copied verbatim from the candidates list." + ) + reason: str = Field( + description="One-sentence justification for why this package is the best investigation target." + ) + + class AgentState(MessagesState): input: str = "" step: int = 0 @@ -67,6 +77,7 @@ class AgentState(MessagesState): output: str = "" ecosystem: str | None = None runtime_prompt: str | None = None + app_package: str | None = None ### --- End of REACT Schemas ----# #---- REACT Prompt Templates ----# @@ -78,13 +89,7 @@ class AgentState(MessagesState): "3. TRACE reachability using Call Chain Analyzer. " " - Use the Function Locator to verify the package name and find the function name." " - Keyword search alone is NOT sufficient -- you must trace the call chain.\n" - "4. If MULTIPLE packages are listed, repeat steps 2-3 for EACH package.\n" - "5. ASSESS: only after completing reachability checks, determine exploitability.\n" - "STOPPING RULES:\n" - "- POSITIVE reachability (function IS reachable): you MAY conclude exploitable and finish.\n" - "- NEGATIVE reachability (function NOT reachable in a package): you MUST continue " - "and check the NEXT package. 
A negative result in one package does not prove non-exploitability.\n" - "- You may only conclude 'not exploitable' after ALL packages have been checked and ALL returned negative.\n" + "4. ASSESS: only after completing reachability checks, determine exploitability.\n" "GENERAL RULES:\n" "- Base conclusions ONLY on tool results, not assumptions.\n" "- If a search returns no results, that is evidence the code is absent.\n" @@ -92,7 +97,11 @@ class AgentState(MessagesState): "- Code Keyword Search proves code PRESENCE in the container, NOT reachability.\n" "- Function Locator validates package/function NAMES, NOT reachability. It confirms the name exists, not that it is called.\n" "- Only Call Chain Analyzer can confirm reachability. The application may contain code it never calls.\n" - "- When the question asks whether a function is called or reachable, do NOT conclude based on Code Keyword Search or Function Locator alone -- you MUST use Call Chain Analyzer." + "- When the question asks whether a function is called or reachable, do NOT conclude based on Code Keyword Search or Function Locator alone -- you MUST use Call Chain Analyzer.\n" + "STOPPING RULES:\n" + "- POSITIVE reachability (Call Chain Analyzer returns True): you MAY conclude exploitable and finish.\n" + "- NEGATIVE reachability (Call Chain Analyzer returns False): record the result. " + "You may only conclude 'not exploitable' after Call Chain Analyzer has confirmed the function is not reachable." ) # Update LANGGRAPH_SYSTEM_PROMPT_TEMPLATE in react_internals.py @@ -116,8 +125,9 @@ class AgentState(MessagesState): 2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer. 3. Function Locator, Function Caller Finder, Call Chain Analyzer: MUST set package_name AND function_name. Do NOT use query. 4. Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search: use query field. -5. If Call Chain Analyzer returns POSITIVE (reachable), you may finish. 
If NEGATIVE, check KNOWLEDGE for "INVESTIGATE EACH package:" list -- you MUST try the next unchecked package before finishing. Code Keyword Search and Function Locator are NOT reachability proof -- only Call Chain Analyzer is. +5. Code Keyword Search and Function Locator are NOT reachability proof -- only Call Chain Analyzer is. 6. Do NOT call the same tool with the same input twice. Check KNOWLEDGE for "CALLED:" entries. If already tried, use a DIFFERENT tool or different input. +7. If Code Keyword Search returns no results and the query contains dots (e.g. a.b.ClassName), retry with just the final component (e.g. ClassName). This does NOT apply to simple names without dots. {{"thought": "Search for the vulnerable function in the codebase first", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "", "tool_input": null, "reason": "Check if vulnerable function is present"}}, "final_answer": null}} @@ -137,6 +147,19 @@ class AgentState(MessagesState): A reachability question asks whether the vulnerable code/symbol is called or reachable in the codebase, or whether untrusted data can reach it. Is this a reachability question? Answer only yes or no.""" +PACKAGE_FILTER_PROMPT_TEMPLATE = """The container runs "{image_context}". +{image_match_note} +Candidate packages: +{candidates} + +Ecosystem: {ecosystem} +{critical_context_section} +Rules: +1. If a candidate was pre-identified as matching the container image/repo (see note above), you MUST select it. This OVERRIDES all other rules. Do NOT select a sub-library or component package. +2. Otherwise, pick the package whose ecosystem matches "{ecosystem}". +3. Discard ecosystem-repackaging wrappers (e.g. maven webjars wrapping an npm library) in favour of the native package. +4. Return the package name exactly as it appears in the candidates list above.""" + FORCED_FINISH_PROMPT = """Maximum steps reached. You MUST set mode="finish" and provide final_answer NOW. 
Do NOT call any more tools. Summarize your evidence in 3-5 sentences. RESPONSE: @@ -144,6 +167,7 @@ class AgentState(MessagesState): OBSERVATION_NODE_PROMPT = """Update the investigation memory based on new tool output. GOAL: {goal} +TARGET PACKAGE (vulnerability): {selected_package} PREVIOUS MEMORY: {previous_memory} TOOL USED: {tool_used} TOOL INPUT: {tool_input_detail} @@ -156,9 +180,10 @@ class AgentState(MessagesState): - If NEW OUTPUT is empty, contains an error, or indicates tool failure, record in memory: "FAILED: {tool_used} [{tool_input_detail}] - [reason]". Do NOT infer or fabricate positive findings from failed tool output. - results from a failed tool call must only state the failure, not speculate about what the tool might have found. - Reachability tags apply ONLY to Call Chain Analyzer results. Code Keyword Search and Function Locator do NOT determine reachability. -- If Call Chain Analyzer returned NEGATIVE (False), add to memory: "NOT reachable via [package]. Check INVESTIGATE list for next unchecked package." +- If Call Chain Analyzer returned NEGATIVE (False), add to memory: "NOT reachable via [package]." - If Call Chain Analyzer returned POSITIVE (True), add: "REACHABLE via [package] - sufficient evidence." - For Function Locator results, record: "VALIDATED: [package],[function] exists" -- this is NOT reachability. +- When Code Keyword Search results are grouped into "Main application" and "Application library dependencies", prioritize findings from the main application group and use its package name as package_name for subsequent tool calls. - Always record in memory: "CALLED: [tool] with [input] -> [outcome]" so future steps know what was already tried. - results: 3-5 key technical facts from this OUTPUT only. - Keep only CVE-exploitability-relevant information. 
@@ -188,6 +213,65 @@ def build_classification_prompt(context_block: str, question: str) -> str: question=question, ) + +def _find_image_matching_candidate( + candidates: list[dict], + image_name: str | None, + image_repo: str | None, +) -> str | None: + """Return the candidate name that appears in the image/repo string, or None.""" + context = " ".join( + part.lower() for part in (image_name, image_repo) if part + ) + if not context: + return None + for c in candidates: + name = c["name"].lower() + if len(name) >= 3 and name in context: + return c["name"] + return None + + +def build_package_filter_prompt( + ecosystem: str, + candidates: list[dict], + image_name: str | None = None, + image_repo: str | None = None, + critical_context: list[str] | None = None, +) -> str: + """Build the package-selection prompt from ecosystem, candidate packages, and container image info.""" + candidate_lines = "\n".join( + f"- {c['name']} (source: {c.get('source', 'unknown')}, ecosystem: {c.get('ecosystem', 'N/A')})" + for c in candidates + ) + parts = [] + if image_name: + parts.append(image_name) + if image_repo: + parts.append(f"repo: {image_repo}") + image_context = ", ".join(parts) if parts else "unknown" + + matched = _find_image_matching_candidate(candidates, image_name, image_repo) + if matched: + image_match_note = f'MATCH DETECTED: candidate "{matched}" matches the container image/repo. Select it (Rule 1).' + critical_context_section = "" + else: + image_match_note = "NO MATCH: no candidate package name was found in the image/repo identifier. Rule 1 does not apply — use Rule 2." 
+ if critical_context: + context_block = "\n".join(critical_context) + critical_context_section = f"\nVulnerability context (use to disambiguate candidates):\n{context_block}\n" + else: + critical_context_section = "" + + return PACKAGE_FILTER_PROMPT_TEMPLATE.format( + ecosystem=ecosystem, + candidates=candidate_lines, + image_context=image_context, + image_match_note=image_match_note, + critical_context_section=critical_context_section, + ) + + def _build_tool_arguments(actions: ToolCall)->dict[str, Any]: pkg_tools = {"Function Locator", "Function Caller Finder", "Call Chain Analyzer"} if actions.tool in pkg_tools and actions.package_name and actions.function_name: From 5e16242abda7f5da6f304cebe45cb6d50d90948b Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 4 Mar 2026 16:14:51 +0000 Subject: [PATCH 32/60] remove del --- .tekton/on-pull-request.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.tekton/on-pull-request.yaml b/.tekton/on-pull-request.yaml index d7f4ef99..38c4dcb9 100644 --- a/.tekton/on-pull-request.yaml +++ b/.tekton/on-pull-request.yaml @@ -247,7 +247,7 @@ spec: #clean the java cache #rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* - rm -rf ./.cache/am_cache/pickle/https.github.com.TamarW0.node-example-project-* + #rm -rf ./.cache/am_cache/pickle/https.github.com.TamarW0.node-example-project-* print_banner "RUNNING UNIT TESTS" make test-unit PYTEST_OPTS="--log-cli-level=DEBUG" From 962b81f1d9c4e2088590d22efa935c84a8e56319 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Thu, 5 Mar 2026 07:20:11 +0000 Subject: [PATCH 33/60] Add Rule Tracker check for rule no. 
7 --- src/vuln_analysis/functions/cve_agent.py | 11 ++++- .../functions/react_internals.py | 47 ++++++++++++++++++- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 258762bc..858a03f2 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -36,7 +36,7 @@ from vuln_analysis.utils.prompting import get_agent_prompt from exploit_iq_commons.logging.loggers_factory import LoggingFactory, trace_id -from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, build_package_filter_prompt, AgentState, Thought, Observation, Classification, PackageSelection, _build_tool_arguments, FORCED_FINISH_PROMPT, OBSERVATION_NODE_PROMPT +from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, build_package_filter_prompt, AgentState, Thought, Observation, Classification, PackageSelection, SystemRulesTracker, _build_tool_arguments, FORCED_FINISH_PROMPT, OBSERVATION_NODE_PROMPT from vuln_analysis.utils.prompting import build_tool_descriptions from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_ECOSYSTEM_REGISTRY, FEW_SHOT_EXAMPLES from langgraph.graph import StateGraph, END, START @@ -454,12 +454,18 @@ async def observation_node(state: AgentState) -> AgentState: elif actions.tool_input: tool_input_detail = actions.tool_input previous_memory = state.get("observation").memory if state.get("observation") else ["No data gathered yet."] + rules_tracker = state.get("rules_tracker") with AGENT_TRACER.push_active_function("observation node", input_data=f"tool used:{tool_used}") as span: tool_output_for_llm = tool_message.content if tool_used == "Code Keyword Search" and state.get("app_package") and state.get("ecosystem"): tool_output_for_llm = _group_code_search_results( tool_message.content, state["ecosystem"], state["app_package"], ) + 
result, error_message = rules_tracker.check_thought_behavior(tool_used, tool_input_detail, tool_output_for_llm) + if result: + span.set_output({"rule_error": error_message}) + return {"messages": [HumanMessage(content=error_message)]} + prompt = OBSERVATION_NODE_PROMPT.format( goal=state.get('input'), selected_package=state.get('app_package') or "N/A", @@ -580,7 +586,8 @@ async def call_agent(initial_state,config=None): "max_steps": max_iterations, "thought": None, "observation": None, - "output": "waiting for the agent to respond" + "output": "waiting for the agent to respond", + "rules_tracker": SystemRulesTracker(), } config = { "recursion_limit": 50 diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 108d859b..dede20c5 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -66,7 +66,51 @@ class PackageSelection(BaseModel): description="One-sentence justification for why this package is the best investigation target." ) - +class SystemRulesTracker: + def __init__(self): + self.action_history = {} + + @staticmethod + def _is_empty_result(output) -> bool: + """Check if a tool output represents an empty/no-results response. + + Handles both string ('[]', '') and list ([]) formats since + ToolMessage.content can be either type. + """ + if isinstance(output, list): + return len(output) == 0 + if isinstance(output, str): + return output.strip() in ("[]", "") + return False + + def add_action(self, action: str, action_input: str, output): + entry = {"input": action_input, "output": output} + if action not in self.action_history: + self.action_history[action] = [entry] + else: + self.action_history[action].append(entry) + + def _rule_number_7(self, action: str, action_input: str, output) -> bool: + if action != "Code Keyword Search": + return False + if "." 
not in action_input: + return False + if not self._is_empty_result(output): + return False + if action not in self.action_history: + return False + prev = self.action_history[action][-1] + if "." in prev["input"] and self._is_empty_result(prev["output"]): + return True + return False + + def check_thought_behavior(self, action: str, action_input: str, output) -> tuple[bool, str]: + if self._rule_number_7(action, action_input, output): + return True, ("You are NOT following Rule 7. Your query contains dots and returned " + "no results. You MUST retry with just the final component. Follow the rules.") + self.add_action(action, action_input, output) + return False, "" + class AgentState(MessagesState): input: str = "" step: int = 0 @@ -78,6 +122,7 @@ class AgentState(MessagesState): ecosystem: str | None = None runtime_prompt: str | None = None app_package: str | None = None + rules_tracker: SystemRulesTracker = SystemRulesTracker() ### --- End of REACT Schemas ----# #---- REACT Prompt Templates ----# From dcdeaf4363c61cabea0de08c267bccd3970ff0a4 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Thu, 5 Mar 2026 07:54:14 +0000 Subject: [PATCH 34/60] exception protection and python fix and rule tracker --- ci/scripts/analyze_traces.py | 55 +++- src/exploit_iq_commons/utils/dep_tree.py | 11 +- src/vuln_analysis/functions/cve_agent.py | 353 ++++++++++++----------- 3 files changed, 245 insertions(+), 174 deletions(-) diff --git a/ci/scripts/analyze_traces.py b/ci/scripts/analyze_traces.py index 9d076de4..5824c128 100644 --- a/ci/scripts/analyze_traces.py +++ b/ci/scripts/analyze_traces.py @@ -20,6 +20,7 @@ REACHABILITY_TOOLS = {"Function Locator", "Call Chain Analyzer", "Function Caller Finder"} SEARCH_ONLY_TOOLS = {"Code Keyword Search", "Code Semantic Search", "Docs Semantic Search"} +NODE_FUNCTION_NAMES = {"thought node", "observation node", "pre_process node", "forced_finish node"} TOKEN_BUDGET_THRESHOLD = 6000 # 75% of 8k window TOKEN_BUDGET_WARNING = 5800 # early 
warning before hitting threshold CONTEXT_WINDOW_LIMIT = 8192 @@ -42,7 +43,7 @@ class QualityFlag: @dataclass class InvestigationStep: step_num: int - span_kind: str # LLM or TOOL + span_kind: str # LLM, TOOL, or FUNCTION name: str input_preview: str output_preview: str @@ -51,6 +52,8 @@ class InvestigationStep: token_total: int = 0 llm_role: str = "" # "thought" or "observation" for LLM spans thought_mode: str = "" # "act" or "finish" for thought spans + has_error: bool = False + error_detail: str = "" @dataclass @@ -136,6 +139,22 @@ def load_trace_files(input_path: str) -> list[dict]: return traces +def _check_span_error(span: dict) -> tuple[bool, str]: + """Check if a span's output_value contains an error recorded by try/except tracing.""" + output = span.get("output_value", "") + if not output: + return False, "" + try: + parsed = json.loads(output) + if isinstance(parsed, dict) and "error" in parsed: + exc_type = parsed.get("exception_type", "Exception") + error_msg = parsed["error"] + return True, f"{exc_type}: {error_msg}" + except (json.JSONDecodeError, TypeError): + pass + return False, "" + + def _classify_llm_span(span: dict) -> str: """Classify an LLM span as 'thought' or 'observation' based on output format.""" output = span.get("output_value", "") @@ -219,7 +238,7 @@ def _split_by_question(spans: list[dict]) -> dict[str, list[dict]]: elif last_thought_question: groups[last_thought_question].append(span) - elif kind == "TOOL": + elif kind in ("TOOL", "FUNCTION"): if last_thought_question: groups[last_thought_question].append(span) @@ -250,6 +269,8 @@ def _build_flow_from_spans( if role == "thought": mode = _extract_thought_mode(span.get("output_value", "")) + has_error, error_detail = _check_span_error(span) + step = InvestigationStep( step_num=i + 1, span_kind=kind, @@ -261,6 +282,8 @@ def _build_flow_from_spans( token_total=span.get("token_total", 0), llm_role=role, thought_mode=mode, + has_error=has_error, + error_detail=error_detail, ) 
flow.steps.append(step) @@ -294,8 +317,13 @@ def build_investigation_flows(trace_doc: dict) -> list[InvestigationFlow]: agent_spans = [ s for s in spans - if s.get("function_name") in ("cve_agent_executor", "checklist_question") - and s.get("span_kind") in ("LLM", "TOOL") + if ( + s.get("function_name") in ("cve_agent_executor", "checklist_question") + and s.get("span_kind") in ("LLM", "TOOL") + ) or ( + s.get("function_name") in NODE_FUNCTION_NAMES + and s.get("span_kind") == "FUNCTION" + ) ] parent_groups: dict[str, list[dict]] = {} @@ -328,6 +356,14 @@ def build_investigation_flows(trace_doc: dict) -> list[InvestigationFlow]: def run_quality_checks(flow: InvestigationFlow) -> list[QualityFlag]: flags = [] + error_steps = [s for s in flow.steps if s.has_error] + for es in error_steps: + flags.append(QualityFlag( + name="NODE_EXCEPTION", + severity=Severity.HIGH, + detail=f"Exception in {es.name} (step {es.step_num}): {es.error_detail}", + )) + tool_set = set(flow.tools_used) if not tool_set: @@ -453,17 +489,24 @@ def print_flow_report(flow: InvestigationFlow): print() for step in flow.steps: + error_marker = " *** ERROR ***" if step.has_error else "" if step.span_kind == "LLM": role_label = f" [{step.llm_role}]" if step.llm_role else "" mode_label = f" mode={step.thought_mode}" if step.thought_mode else "" - print(f"--- Step {step.step_num}: LLM{role_label}{mode_label} ---") + print(f"--- Step {step.step_num}: LLM{role_label}{mode_label}{error_marker} ---") if step.token_prompt: print(f" Prompt tokens: {step.token_prompt} | Completion: {step.token_completion}") print(f" Output: \"{step.output_preview}\"") elif step.span_kind == "TOOL": - print(f"--- Step {step.step_num}: TOOL ({step.name}) ---") + print(f"--- Step {step.step_num}: TOOL ({step.name}){error_marker} ---") print(f" Input: {step.input_preview}") print(f" Output: \"{step.output_preview}\"") + elif step.span_kind == "FUNCTION": + print(f"--- Step {step.step_num}: NODE ({step.name}){error_marker} ---") + 
if step.has_error: + print(f" Exception: {step.error_detail}") + else: + print(f" Output: \"{step.output_preview}\"") print() diff --git a/src/exploit_iq_commons/utils/dep_tree.py b/src/exploit_iq_commons/utils/dep_tree.py index b84a2051..9eb47868 100644 --- a/src/exploit_iq_commons/utils/dep_tree.py +++ b/src/exploit_iq_commons/utils/dep_tree.py @@ -996,15 +996,18 @@ def __parse_dependency_line(self, line: str) -> Tuple[Optional[int], Optional[st class PythonDependencyTreeBuilder(DependencyTreeBuilder): def build_tree(self, manifest_path: Path) -> defaultdict[Any, list]: - cmd = f'{manifest_path}/{TRANSITIVE_ENV_NAME}/bin/python -m pip install deptree' - run_command(cmd) - cmd = f'{manifest_path}/{TRANSITIVE_ENV_NAME}/bin/deptree', + venv_python = f'{manifest_path}/{TRANSITIVE_ENV_NAME}/bin/python' + run_command(f'{venv_python} -m pip install "setuptools<81" deptree') + cmd = f'{manifest_path}/{TRANSITIVE_ENV_NAME}/bin/deptree' dependencies = run_command(cmd) + if not dependencies or not dependencies.strip(): + logger.error("deptree returned empty output — third-party dependency tree will be incomplete. 
" + "Check that the virtual environment at %s has packages installed.", manifest_path) parent_stack = [] tree = defaultdict(set) ROOT_PROJECT = 'root_project' tree[ROOT_PROJECT] = [ROOT_LEVEL_SENTINEL] - for line in dependencies.split(os.linesep): + for line in dependencies.split(os.linesep) if dependencies else []: level = 0 while line.startswith(' '): level += 1 diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 858a03f2..6c012464 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -305,102 +305,114 @@ async def pre_process_node(state: AgentState) -> AgentState: workflow_state = ctx_state.get() ecosystem = workflow_state.original_input.input.image.ecosystem.value with AGENT_TRACER.push_active_function("pre_process node", input_data=f"ecosystem:{ecosystem}") as span: - - critical_context, candidate_packages = _build_critical_context(workflow_state.cve_intel) - - selected_package = None - app_package = None - if len(candidate_packages) > 1: - image_input = workflow_state.original_input.input.image - image_name = image_input.name - source_repos = image_input.source_info - image_repo = source_repos[0].git_repo if source_repos else None - filter_prompt = build_package_filter_prompt( - ecosystem, candidate_packages, - image_name=image_name, image_repo=image_repo, - critical_context=critical_context, + try: + critical_context, candidate_packages = _build_critical_context(workflow_state.cve_intel) + + selected_package = None + app_package = None + if len(candidate_packages) > 1: + image_input = workflow_state.original_input.input.image + image_name = image_input.name + source_repos = image_input.source_info + image_repo = source_repos[0].git_repo if source_repos else None + filter_prompt = build_package_filter_prompt( + ecosystem, candidate_packages, + image_name=image_name, image_repo=image_repo, + critical_context=critical_context, + ) + selection: PackageSelection = 
await package_filter_llm.ainvoke([HumanMessage(content=filter_prompt)]) + selected_package = selection.selected_package + app_package = selected_package + logger.info("Package filter selected '%s' from %d candidates (reason: %s)", + selected_package, len(candidate_packages), selection.reason) + critical_context = _filter_context_to_package(critical_context, selected_package, candidate_packages) + + critical_context.append( + "TASK: Investigate usage and reachability of the vulnerable function/module in the container. " + "Use the vulnerable module name from GHSA as primary investigation target." ) - selection: PackageSelection = await package_filter_llm.ainvoke([HumanMessage(content=filter_prompt)]) - selected_package = selection.selected_package - app_package = selected_package - logger.info("Package filter selected '%s' from %d candidates (reason: %s)", - selected_package, len(candidate_packages), selection.reason) - critical_context = _filter_context_to_package(critical_context, selected_package, candidate_packages) - - critical_context.append( - "TASK: Investigate usage and reachability of the vulnerable function/module in the container. " - "Use the vulnerable module name from GHSA as primary investigation target." 
- ) - - question = state.get("input") or "" - context_block = "\n".join(critical_context) - classification_prompt = build_classification_prompt(context_block, question) - classification_result: Classification = await reachability_llm.ainvoke([HumanMessage(content=classification_prompt)]) - span.set_output({ - "critical_context": critical_context, - "candidate_packages": candidate_packages, - "selected_package": selected_package, - "app_package": app_package if selected_package else None, - "reachability_question": classification_result.is_reachability, - }) - - tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) - runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) - return { - "ecosystem": ecosystem, - "runtime_prompt": runtime_prompt, - "observation": Observation(memory=critical_context, results=[]), - "app_package": app_package if selected_package else None, - } + + question = state.get("input") or "" + context_block = "\n".join(critical_context) + classification_prompt = build_classification_prompt(context_block, question) + classification_result: Classification = await reachability_llm.ainvoke([HumanMessage(content=classification_prompt)]) + span.set_output({ + "critical_context": critical_context, + "candidate_packages": candidate_packages, + "selected_package": selected_package, + "app_package": app_package if selected_package else None, + "reachability_question": classification_result.is_reachability, + }) + + tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) + runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) + return { + "ecosystem": ecosystem, + "runtime_prompt": runtime_prompt, + "observation": Observation(memory=critical_context, results=[]), + "app_package": app_package if selected_package else None, + } + except Exception as e: + logger.exception("pre_process_node failed") + span.set_output({"error": str(e), 
"exception_type": type(e).__name__}) + raise async def thought_node(state: AgentState) -> AgentState: - active_prompt = state.get("runtime_prompt") or default_system_prompt - messages = [SystemMessage(content=active_prompt)] + state["messages"] - obs = state.get("observation", None) - if obs is not None: - memory_list = obs.memory if obs.memory else ["No prior knowledge."] - recent_findings = obs.results if obs.results else ["No recent findings."] - memory_context = "\n".join(f"- {m}" for m in memory_list) - findings_context = "\n".join(f"- {f}" for f in recent_findings) - context_block = f"KNOWLEDGE:\n{memory_context}\nLATEST FINDINGS:\n{findings_context}" - messages.append(SystemMessage(content=context_block)) - response: Thought = await thought_llm.ainvoke(messages) - - final_answer = "waiting for the agent to respond" - if response.mode == "finish": - ai_message = AIMessage(content=response.final_answer) - final_answer = response.final_answer - elif response.actions is None: - logger.warning("LLM returned mode='act' but actions is None, forcing finish") - ai_message = AIMessage(content=response.thought or "No actions provided, finishing.") - response = Thought( - thought=response.thought or "No actions provided", - mode="finish", - actions=None, - final_answer=response.thought or "Insufficient evidence to provide a definitive answer." 
- ) - final_answer = response.final_answer - else: - tool_name = response.actions.tool - arguments = _build_tool_arguments(response.actions) - tool_call_id = str(uuid.uuid4()) - ai_message = AIMessage( - content=response.thought, - tool_calls=[{ - "name": tool_name, - "args": arguments, - "id": tool_call_id - }] - ) - - return { - "messages": [ai_message], - "thought": response, - "step": state.get("step", 0) + 1, - "max_steps": config.max_iterations, - "output": final_answer - } + step_num = state.get("step", 0) + with AGENT_TRACER.push_active_function("thought node", input_data=f"step:{step_num}") as span: + try: + active_prompt = state.get("runtime_prompt") or default_system_prompt + messages = [SystemMessage(content=active_prompt)] + state["messages"] + obs = state.get("observation", None) + if obs is not None: + memory_list = obs.memory if obs.memory else ["No prior knowledge."] + recent_findings = obs.results if obs.results else ["No recent findings."] + memory_context = "\n".join(f"- {m}" for m in memory_list) + findings_context = "\n".join(f"- {f}" for f in recent_findings) + context_block = f"KNOWLEDGE:\n{memory_context}\nLATEST FINDINGS:\n{findings_context}" + messages.append(SystemMessage(content=context_block)) + response: Thought = await thought_llm.ainvoke(messages) + + final_answer = "waiting for the agent to respond" + if response.mode == "finish": + ai_message = AIMessage(content=response.final_answer) + final_answer = response.final_answer + elif response.actions is None: + logger.warning("LLM returned mode='act' but actions is None, forcing finish") + ai_message = AIMessage(content=response.thought or "No actions provided, finishing.") + response = Thought( + thought=response.thought or "No actions provided", + mode="finish", + actions=None, + final_answer=response.thought or "Insufficient evidence to provide a definitive answer." 
+ ) + final_answer = response.final_answer + else: + tool_name = response.actions.tool + arguments = _build_tool_arguments(response.actions) + tool_call_id = str(uuid.uuid4()) + ai_message = AIMessage( + content=response.thought, + tool_calls=[{ + "name": tool_name, + "args": arguments, + "id": tool_call_id + }] + ) + + span.set_output({"mode": response.mode, "step": step_num + 1}) + return { + "messages": [ai_message], + "thought": response, + "step": step_num + 1, + "max_steps": config.max_iterations, + "output": final_answer + } + except Exception as e: + logger.exception("thought_node failed at step %d", step_num) + span.set_output({"error": str(e), "exception_type": type(e).__name__, "step": step_num}) + raise async def should_continue(state: AgentState) -> str: thought = state.get("thought", None) @@ -411,34 +423,42 @@ async def should_continue(state: AgentState) -> str: return TOOL_NODE async def forced_finish_node(state: AgentState) -> AgentState: - active_prompt = state.get("runtime_prompt") or default_system_prompt - messages = [SystemMessage(content=active_prompt)] + state["messages"] - messages.append(HumanMessage(content=FORCED_FINISH_PROMPT)) - obs = state.get("observation", None) - if obs is not None and obs.memory: - memory_context = "\n".join(f"- {m}" for m in obs.memory) - messages.append(SystemMessage(content=f"KNOWLEDGE:\n{memory_context}")) - response: Thought = await thought_llm.ainvoke(messages) - if response.mode == "finish" and response.final_answer: - ai_message = AIMessage(content=response.final_answer) - final_answer = response.final_answer - else: - final_answer = "Failed to generate a final answer within the maximum allowed steps." 
- ai_message = AIMessage(content=final_answer) - response = Thought( - thought=response.thought or "Max steps exceeded", - mode="finish", - actions=None, - final_answer=final_answer - ) - return { - "messages": [ai_message], - "thought": response, - "step": state.get("step", 0), - "max_steps": state.get("max_steps", config.max_iterations), - "observation": state.get("observation", None), - "output": final_answer - } + step_num = state.get("step", 0) + with AGENT_TRACER.push_active_function("forced_finish node", input_data=f"step:{step_num}") as span: + try: + active_prompt = state.get("runtime_prompt") or default_system_prompt + messages = [SystemMessage(content=active_prompt)] + state["messages"] + messages.append(HumanMessage(content=FORCED_FINISH_PROMPT)) + obs = state.get("observation", None) + if obs is not None and obs.memory: + memory_context = "\n".join(f"- {m}" for m in obs.memory) + messages.append(SystemMessage(content=f"KNOWLEDGE:\n{memory_context}")) + response: Thought = await thought_llm.ainvoke(messages) + if response.mode == "finish" and response.final_answer: + ai_message = AIMessage(content=response.final_answer) + final_answer = response.final_answer + else: + final_answer = "Failed to generate a final answer within the maximum allowed steps." 
+ ai_message = AIMessage(content=final_answer) + response = Thought( + thought=response.thought or "Max steps exceeded", + mode="finish", + actions=None, + final_answer=final_answer + ) + span.set_output({"final_answer_length": len(final_answer), "step": step_num}) + return { + "messages": [ai_message], + "thought": response, + "step": step_num, + "max_steps": state.get("max_steps", config.max_iterations), + "observation": state.get("observation", None), + "output": final_answer + } + except Exception as e: + logger.exception("forced_finish_node failed at step %d", step_num) + span.set_output({"error": str(e), "exception_type": type(e).__name__, "step": step_num}) + raise async def observation_node(state: AgentState) -> AgentState: tool_message = state["messages"][-1] @@ -456,51 +476,56 @@ async def observation_node(state: AgentState) -> AgentState: previous_memory = state.get("observation").memory if state.get("observation") else ["No data gathered yet."] rules_tracker = state.get("rules_tracker") with AGENT_TRACER.push_active_function("observation node", input_data=f"tool used:{tool_used}") as span: - tool_output_for_llm = tool_message.content - if tool_used == "Code Keyword Search" and state.get("app_package") and state.get("ecosystem"): - tool_output_for_llm = _group_code_search_results( - tool_message.content, state["ecosystem"], state["app_package"], + try: + tool_output_for_llm = tool_message.content + if tool_used == "Code Keyword Search" and state.get("app_package") and state.get("ecosystem"): + tool_output_for_llm = _group_code_search_results( + tool_message.content, state["ecosystem"], state["app_package"], + ) + result, error_message = rules_tracker.check_thought_behavior(tool_used, tool_input_detail, tool_output_for_llm) + if result: + span.set_output({"rule_error": error_message}) + return {"messages": [HumanMessage(content=error_message)]} + + prompt = OBSERVATION_NODE_PROMPT.format( + goal=state.get('input'), + 
selected_package=state.get('app_package') or "N/A", + previous_memory=previous_memory, + tool_used=tool_used, + tool_input_detail=tool_input_detail, + last_thought_text=last_thought_text, + tool_output=tool_output_for_llm, ) - result, error_message = rules_tracker.check_thought_behavior(tool_used, tool_input_detail, tool_output_for_llm) - if result: - span.set_output({"rule_error": error_message}) - return {"messages": [HumanMessage(content=error_message)]} - - prompt = OBSERVATION_NODE_PROMPT.format( - goal=state.get('input'), - selected_package=state.get('app_package') or "N/A", - previous_memory=previous_memory, - tool_used=tool_used, - tool_input_detail=tool_input_detail, - last_thought_text=last_thought_text, - tool_output=tool_output_for_llm, - ) - - new_observation: Observation = await observation_llm.ainvoke([SystemMessage(content=prompt)]) - - messages = state["messages"] - active_prompt = state.get("runtime_prompt") or default_system_prompt - estimated = _estimate_tokens(active_prompt, messages, new_observation) - prune_messages = [] - orig_estimated = estimated - - if estimated > config.context_window_token_limit and len(messages) > 3: - prunable = messages[1:-2] - for msg in prunable: - prune_messages.append(RemoveMessage(id=msg.id)) - estimated -= _count_tokens(msg.content) if hasattr(msg, "content") and isinstance(msg.content, str) else 0 - if estimated <= config.context_window_token_limit: - break - logger.info( - "Context pruning: removed %d messages, estimated tokens now ~%d (limit %d)", - len(prune_messages), estimated, config.context_window_token_limit, - ) - span.set_output({"orig_estimated": orig_estimated, "estimated": estimated}) - return { - "messages": prune_messages, - "observation": new_observation, - "step": state.get("step", 0), - } + + new_observation: Observation = await observation_llm.ainvoke([SystemMessage(content=prompt)]) + + messages = state["messages"] + active_prompt = state.get("runtime_prompt") or default_system_prompt + 
estimated = _estimate_tokens(active_prompt, messages, new_observation) + prune_messages = [] + orig_estimated = estimated + + if estimated > config.context_window_token_limit and len(messages) > 3: + prunable = messages[1:-2] + for msg in prunable: + prune_messages.append(RemoveMessage(id=msg.id)) + estimated -= _count_tokens(msg.content) if hasattr(msg, "content") and isinstance(msg.content, str) else 0 + if estimated <= config.context_window_token_limit: + break + logger.info( + "Context pruning: removed %d messages, estimated tokens now ~%d (limit %d)", + len(prune_messages), estimated, config.context_window_token_limit, + ) + span.set_output({"orig_estimated": orig_estimated, "estimated": estimated}) + return { + "messages": prune_messages, + "observation": new_observation, + "step": state.get("step", 0), + } + except Exception as e: + logger.exception("observation_node failed") + span.set_output({"error": str(e), "exception_type": type(e).__name__}) + raise async def create_graph(): flow = StateGraph(AgentState) From d0bde114c26cf848515ca7392d8e8461d306f229 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Thu, 5 Mar 2026 10:25:42 +0000 Subject: [PATCH 35/60] fix java and cpp --- .../utils/c_segmenter_custom.py | 59 +- src/exploit_iq_commons/utils/dep_tree.py | 15 +- .../tools/tests/test_segmenter.py | 35 +- .../tools/tests/tests_data/jsonpath.c | 1076 +++++++++++++++++ 4 files changed, 1177 insertions(+), 8 deletions(-) create mode 100644 src/vuln_analysis/tools/tests/tests_data/jsonpath.c diff --git a/src/exploit_iq_commons/utils/c_segmenter_custom.py b/src/exploit_iq_commons/utils/c_segmenter_custom.py index b1155150..c842f1e7 100644 --- a/src/exploit_iq_commons/utils/c_segmenter_custom.py +++ b/src/exploit_iq_commons/utils/c_segmenter_custom.py @@ -5,9 +5,12 @@ #class extened CSegmenter class CSegmenterExtended(CSegmenter): + _DEFINE_FUNC_RE = re.compile( + r'^\s*#\s*define\s+(?P[a-zA-Z_]\w*)\s*\((?P[^)]*)\)' + ) + def __init__(self, code: str): - # 
Preprocess code: remove comments and macro blocks - + self._raw_code = code code = self.remove_macro_blocks(code) super().__init__(code) self.structs_enums: List[str] = [] @@ -169,6 +172,57 @@ def find_top_level_blocks(c_code: str): return blocks + @classmethod + def extract_define_functions(cls, code: str) -> List[str]: + """ + Extract function-like #define macros whose name contains at least one + lowercase letter and convert them into dummy C function segments. + + Handles multi-line macros joined by backslash continuation. + Strips ``do { ... } while(0)`` wrappers when present. + """ + lines = code.splitlines() + total = len(lines) + results: List[str] = [] + i = 0 + + while i < total: + m = cls._DEFINE_FUNC_RE.match(lines[i]) + if not m or m.group('name').isupper(): + i += 1 + continue + + name = m.group('name') + args = m.group('args').strip() + + macro_lines = [lines[i]] + while macro_lines[-1].rstrip().endswith('\\') and i + 1 < total: + i += 1 + macro_lines.append(lines[i]) + + full_line = ' '.join( + ln.rstrip().rstrip('\\').strip() for ln in macro_lines + ) + + header_re = re.compile( + r'#\s*define\s+' + re.escape(name) + r'\s*\([^)]*\)\s*' + ) + body = header_re.sub('', full_line, count=1).strip() + + do_while_re = re.compile( + r'^do\s*\{(?P.*)\}\s*while\s*\(\s*0\s*\)\s*;?\s*$', + re.DOTALL, + ) + dw = do_while_re.match(body) + if dw: + body = dw.group('inner').strip() + + dummy = f"void {name}({args}) {{ {body} }}" + results.append(dummy) + i += 1 + + return results + def extract_functions_classes(self) -> List[str]: segments = super().extract_functions_classes() for i, seg in enumerate(segments): @@ -183,6 +237,7 @@ def extract_functions_classes(self) -> List[str]: hidden_segments.append(new_seg) break # support only one segment into 2 segments we only add hidden segments into segments list that might be functions segments.extend(hidden_segments) + segments.extend(self.extract_define_functions(self._raw_code)) return segments \ No newline at end of 
file diff --git a/src/exploit_iq_commons/utils/dep_tree.py b/src/exploit_iq_commons/utils/dep_tree.py index 9eb47868..f48928e5 100644 --- a/src/exploit_iq_commons/utils/dep_tree.py +++ b/src/exploit_iq_commons/utils/dep_tree.py @@ -849,12 +849,17 @@ def install_dependencies(self, manifest_path: Path): full_source_path = manifest_path / source_path for jar in full_source_path.glob("*-sources.jar"): - dest = full_source_path / jar.stem # folder named after jar + if jar.stat().st_size > 0: + dest = full_source_path / jar.stem # folder named after jar + + if not dest.exists(): + dest.mkdir(exist_ok=True) + result = subprocess.run(["jar", "xf", str(jar.resolve())], cwd=dest) + if result.returncode != 0: + logger.warning("Failed to extract sources jar: %s (exit code %d)", jar, result.returncode) + else: + logger.warning("Empty sources jar (size=0), possibly corrupt: %s", jar) - if not dest.exists(): - dest.mkdir(exist_ok=True) - with zipfile.ZipFile(jar, "r") as zf: - zf.extractall(dest) def build_tree(self, manifest_path: Path) -> dict[str, list[str]]: dependency_file = manifest_path / "dependency_tree.txt" diff --git a/src/vuln_analysis/tools/tests/test_segmenter.py b/src/vuln_analysis/tools/tests/test_segmenter.py index 91cd047b..2e4fdcfe 100644 --- a/src/vuln_analysis/tools/tests/test_segmenter.py +++ b/src/vuln_analysis/tools/tests/test_segmenter.py @@ -179,4 +179,37 @@ def test_integration_ess_lib_c(ess_lib_c_code): assert "ESS_CERT_ID_V2_new_init" in names # static function (v2 version) assert "ess_issuer_serial_cmp" in names # static comparison function - \ No newline at end of file + +@pytest.fixture(scope="module") +def jsonpath_c_code(c_test_files_dir): + file_path = c_test_files_dir / "jsonpath.c" + return file_path.read_text(encoding="utf-8") + + +def test_jsonpath_c_define_macros_captured(jsonpath_c_code): + """ + Lowercase function-like #define macros (read_byte, read_int32, read_int32_n) + should be captured as dummy function segments. 
The raw ``#define`` text + must not leak through — only the synthesized ``void name(args) { body }`` form. + Functions that *use* these macros (e.g. jspInitByBuffer) must still be captured. + """ + segmenter = CSegmenterExtended(jsonpath_c_code) + segments = segmenter.extract_functions_classes() + + assert len(segments) > 0, "Should find segments in jsonpath.c" + + for seg in segments: + assert "#define read_byte" not in seg, ( + "Raw #define read_byte should not appear — only the dummy function form" + ) + + names = {_get_function_name_from_segment(seg) for seg in segments if _get_function_name_from_segment(seg)} + + assert "read_byte" in names, "read_byte macro should be captured as a dummy function segment" + assert "read_int32" in names, "read_int32 macro should be captured as a dummy function segment" + assert "read_int32_n" in names, "read_int32_n macro should be captured as a dummy function segment" + + assert "jspInitByBuffer" in names, ( + "jspInitByBuffer (which uses read_byte) should be captured as a function segment" + ) + assert "jspInit" in names diff --git a/src/vuln_analysis/tools/tests/tests_data/jsonpath.c b/src/vuln_analysis/tools/tests/tests_data/jsonpath.c new file mode 100644 index 00000000..db06e6f2 --- /dev/null +++ b/src/vuln_analysis/tools/tests/tests_data/jsonpath.c @@ -0,0 +1,1076 @@ +/*------------------------------------------------------------------------- + * + * jsonpath.c + * Input/output and supporting routines for jsonpath + * + * jsonpath expression is a chain of path items. First path item is $, $var, + * literal or arithmetic expression. Subsequent path items are accessors + * (.key, .*, [subscripts], [*]), filters (? (predicate)) and methods (.type(), + * .size() etc). 
+ * + * For instance, structure of path items for simple expression: + * + * $.a[*].type() + * + * is pretty evident: + * + * $ => .a => [*] => .type() + * + * Some path items such as arithmetic operations, predicates or array + * subscripts may comprise subtrees. For instance, more complex expression + * + * ($.a + $[1 to 5, 7] ? (@ > 3).double()).type() + * + * have following structure of path items: + * + * + => .type() + * ___/ \___ + * / \ + * $ => .a $ => [] => ? => .double() + * _||_ | + * / \ > + * to to / \ + * / \ / @ 3 + * 1 5 7 + * + * Binary encoding of jsonpath constitutes a sequence of 4-bytes aligned + * variable-length path items connected by links. Every item has a header + * consisting of item type (enum JsonPathItemType) and offset of next item + * (zero means no next item). After the header, item may have payload + * depending on item type. For instance, payload of '.key' accessor item is + * length of key name and key name itself. Payload of '>' arithmetic operator + * item is offsets of right and left operands. 
+ * + * So, binary representation of sample expression above is: + * (bottom arrows are next links, top lines are argument links) + * + * _____ + * _____ ___/____ \ __ + * _ /_ \ _____/__/____ \ \ __ _ /_ \ + * / / \ \ / / / \ \ \ / \ / / \ \ + * +(LR) $ .a $ [](* to *, * to *) 1 5 7 ?(A) >(LR) @ 3 .double() .type() + * | | ^ | ^| ^| ^ ^ + * | |__| |__||________________________||___________________| | + * |_______________________________________________________________________| + * + * Copyright (c) 2019-2020, PostgreSQL Global Development Group + * + * IDENTIFICATION + * src/backend/utils/adt/jsonpath.c + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" + +#include "funcapi.h" +#include "lib/stringinfo.h" +#include "libpq/pqformat.h" +#include "miscadmin.h" +#include "utils/builtins.h" +#include "utils/json.h" +#include "utils/jsonpath.h" + + +static Datum jsonPathFromCstring(char *in, int len); +static char *jsonPathToCstring(StringInfo out, JsonPath *in, + int estimated_len); +static int flattenJsonPathParseItem(StringInfo buf, JsonPathParseItem *item, + int nestingLevel, bool insideArraySubscript); +static void alignStringInfoInt(StringInfo buf); +static int32 reserveSpaceForItemPointer(StringInfo buf); +static void printJsonPathItem(StringInfo buf, JsonPathItem *v, bool inKey, + bool printBracketes); +static int operationPriority(JsonPathItemType op); + + +/**************************** INPUT/OUTPUT ********************************/ + +/* + * jsonpath type input function + */ +Datum +jsonpath_in(PG_FUNCTION_ARGS) +{ + char *in = PG_GETARG_CSTRING(0); + int len = strlen(in); + + return jsonPathFromCstring(in, len); +} + +/* + * jsonpath type recv function + * + * The type is sent as text in binary mode, so this is almost the same + * as the input function, but it's prefixed with a version number so we + * can change the binary format sent in future if necessary. For now, + * only version 1 is supported. 
+ */ +Datum +jsonpath_recv(PG_FUNCTION_ARGS) +{ + StringInfo buf = (StringInfo) PG_GETARG_POINTER(0); + int version = pq_getmsgint(buf, 1); + char *str; + int nbytes; + + if (version == JSONPATH_VERSION) + str = pq_getmsgtext(buf, buf->len - buf->cursor, &nbytes); + else + elog(ERROR, "unsupported jsonpath version number: %d", version); + + return jsonPathFromCstring(str, nbytes); +} + +/* + * jsonpath type output function + */ +Datum +jsonpath_out(PG_FUNCTION_ARGS) +{ + JsonPath *in = PG_GETARG_JSONPATH_P(0); + + PG_RETURN_CSTRING(jsonPathToCstring(NULL, in, VARSIZE(in))); +} + +/* + * jsonpath type send function + * + * Just send jsonpath as a version number, then a string of text + */ +Datum +jsonpath_send(PG_FUNCTION_ARGS) +{ + JsonPath *in = PG_GETARG_JSONPATH_P(0); + StringInfoData buf; + StringInfoData jtext; + int version = JSONPATH_VERSION; + + initStringInfo(&jtext); + (void) jsonPathToCstring(&jtext, in, VARSIZE(in)); + + pq_begintypsend(&buf); + pq_sendint8(&buf, version); + pq_sendtext(&buf, jtext.data, jtext.len); + pfree(jtext.data); + + PG_RETURN_BYTEA_P(pq_endtypsend(&buf)); +} + +/* + * Converts C-string to a jsonpath value. + * + * Uses jsonpath parser to turn string into an AST, then + * flattenJsonPathParseItem() does second pass turning AST into binary + * representation of jsonpath. 
+ */ +static Datum +jsonPathFromCstring(char *in, int len) +{ + JsonPathParseResult *jsonpath = parsejsonpath(in, len); + JsonPath *res; + StringInfoData buf; + + initStringInfo(&buf); + enlargeStringInfo(&buf, 4 * len /* estimation */ ); + + appendStringInfoSpaces(&buf, JSONPATH_HDRSZ); + + if (!jsonpath) + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type %s: \"%s\"", "jsonpath", + in))); + + flattenJsonPathParseItem(&buf, jsonpath->expr, 0, false); + + res = (JsonPath *) buf.data; + SET_VARSIZE(res, buf.len); + res->header = JSONPATH_VERSION; + if (jsonpath->lax) + res->header |= JSONPATH_LAX; + + PG_RETURN_JSONPATH_P(res); +} + +/* + * Converts jsonpath value to a C-string. + * + * If 'out' argument is non-null, the resulting C-string is stored inside the + * StringBuffer. The resulting string is always returned. + */ +static char * +jsonPathToCstring(StringInfo out, JsonPath *in, int estimated_len) +{ + StringInfoData buf; + JsonPathItem v; + + if (!out) + { + out = &buf; + initStringInfo(out); + } + enlargeStringInfo(out, estimated_len); + + if (!(in->header & JSONPATH_LAX)) + appendBinaryStringInfo(out, "strict ", 7); + + jspInit(&v, in); + printJsonPathItem(out, &v, false, true); + + return out->data; +} + +/* + * Recursive function converting given jsonpath parse item and all its + * children into a binary representation. + */ +static int +flattenJsonPathParseItem(StringInfo buf, JsonPathParseItem *item, + int nestingLevel, bool insideArraySubscript) +{ + /* position from beginning of jsonpath data */ + int32 pos = buf->len - JSONPATH_HDRSZ; + int32 chld; + int32 next; + int argNestingLevel = 0; + + check_stack_depth(); + CHECK_FOR_INTERRUPTS(); + + appendStringInfoChar(buf, (char) (item->type)); + + /* + * We align buffer to int32 because a series of int32 values often goes + * after the header, and we want to read them directly by dereferencing + * int32 pointer (see jspInitByBuffer()). 
+ */ + alignStringInfoInt(buf); + + /* + * Reserve space for next item pointer. Actual value will be recorded + * later, after next and children items processing. + */ + next = reserveSpaceForItemPointer(buf); + + switch (item->type) + { + case jpiString: + case jpiVariable: + case jpiKey: + appendBinaryStringInfo(buf, (char *) &item->value.string.len, + sizeof(item->value.string.len)); + appendBinaryStringInfo(buf, item->value.string.val, + item->value.string.len); + appendStringInfoChar(buf, '\0'); + break; + case jpiNumeric: + appendBinaryStringInfo(buf, (char *) item->value.numeric, + VARSIZE(item->value.numeric)); + break; + case jpiBool: + appendBinaryStringInfo(buf, (char *) &item->value.boolean, + sizeof(item->value.boolean)); + break; + case jpiAnd: + case jpiOr: + case jpiEqual: + case jpiNotEqual: + case jpiLess: + case jpiGreater: + case jpiLessOrEqual: + case jpiGreaterOrEqual: + case jpiAdd: + case jpiSub: + case jpiMul: + case jpiDiv: + case jpiMod: + case jpiStartsWith: + { + /* + * First, reserve place for left/right arg's positions, then + * record both args and sets actual position in reserved + * places. + */ + int32 left = reserveSpaceForItemPointer(buf); + int32 right = reserveSpaceForItemPointer(buf); + + chld = !item->value.args.left ? pos : + flattenJsonPathParseItem(buf, item->value.args.left, + nestingLevel + argNestingLevel, + insideArraySubscript); + *(int32 *) (buf->data + left) = chld - pos; + + chld = !item->value.args.right ? 
pos : + flattenJsonPathParseItem(buf, item->value.args.right, + nestingLevel + argNestingLevel, + insideArraySubscript); + *(int32 *) (buf->data + right) = chld - pos; + } + break; + case jpiLikeRegex: + { + int32 offs; + + appendBinaryStringInfo(buf, + (char *) &item->value.like_regex.flags, + sizeof(item->value.like_regex.flags)); + offs = reserveSpaceForItemPointer(buf); + appendBinaryStringInfo(buf, + (char *) &item->value.like_regex.patternlen, + sizeof(item->value.like_regex.patternlen)); + appendBinaryStringInfo(buf, item->value.like_regex.pattern, + item->value.like_regex.patternlen); + appendStringInfoChar(buf, '\0'); + + chld = flattenJsonPathParseItem(buf, item->value.like_regex.expr, + nestingLevel, + insideArraySubscript); + *(int32 *) (buf->data + offs) = chld - pos; + } + break; + case jpiFilter: + argNestingLevel++; + /* FALLTHROUGH */ + case jpiIsUnknown: + case jpiNot: + case jpiPlus: + case jpiMinus: + case jpiExists: + case jpiDatetime: + { + int32 arg = reserveSpaceForItemPointer(buf); + + chld = !item->value.arg ? 
pos : + flattenJsonPathParseItem(buf, item->value.arg, + nestingLevel + argNestingLevel, + insideArraySubscript); + *(int32 *) (buf->data + arg) = chld - pos; + } + break; + case jpiNull: + break; + case jpiRoot: + break; + case jpiAnyArray: + case jpiAnyKey: + break; + case jpiCurrent: + if (nestingLevel <= 0) + ereport(ERROR, + (errcode(ERRCODE_SYNTAX_ERROR), + errmsg("@ is not allowed in root expressions"))); + break; + case jpiLast: + if (!insideArraySubscript) + ereport(ERROR, + (errcode(ERRCODE_SYNTAX_ERROR), + errmsg("LAST is allowed only in array subscripts"))); + break; + case jpiIndexArray: + { + int32 nelems = item->value.array.nelems; + int offset; + int i; + + appendBinaryStringInfo(buf, (char *) &nelems, sizeof(nelems)); + + offset = buf->len; + + appendStringInfoSpaces(buf, sizeof(int32) * 2 * nelems); + + for (i = 0; i < nelems; i++) + { + int32 *ppos; + int32 topos; + int32 frompos = + flattenJsonPathParseItem(buf, + item->value.array.elems[i].from, + nestingLevel, true) - pos; + + if (item->value.array.elems[i].to) + topos = flattenJsonPathParseItem(buf, + item->value.array.elems[i].to, + nestingLevel, true) - pos; + else + topos = 0; + + ppos = (int32 *) &buf->data[offset + i * 2 * sizeof(int32)]; + + ppos[0] = frompos; + ppos[1] = topos; + } + } + break; + case jpiAny: + appendBinaryStringInfo(buf, + (char *) &item->value.anybounds.first, + sizeof(item->value.anybounds.first)); + appendBinaryStringInfo(buf, + (char *) &item->value.anybounds.last, + sizeof(item->value.anybounds.last)); + break; + case jpiType: + case jpiSize: + case jpiAbs: + case jpiFloor: + case jpiCeiling: + case jpiDouble: + case jpiKeyValue: + break; + default: + elog(ERROR, "unrecognized jsonpath item type: %d", item->type); + } + + if (item->next) + { + chld = flattenJsonPathParseItem(buf, item->next, nestingLevel, + insideArraySubscript) - pos; + *(int32 *) (buf->data + next) = chld; + } + + return pos; +} + +/* + * Align StringInfo to int by adding zero padding bytes + 
*/ +static void +alignStringInfoInt(StringInfo buf) +{ + switch (INTALIGN(buf->len) - buf->len) + { + case 3: + appendStringInfoCharMacro(buf, 0); + /* FALLTHROUGH */ + case 2: + appendStringInfoCharMacro(buf, 0); + /* FALLTHROUGH */ + case 1: + appendStringInfoCharMacro(buf, 0); + /* FALLTHROUGH */ + default: + break; + } +} + +/* + * Reserve space for int32 JsonPathItem pointer. Now zero pointer is written, + * actual value will be recorded at '(int32 *) &buf->data[pos]' later. + */ +static int32 +reserveSpaceForItemPointer(StringInfo buf) +{ + int32 pos = buf->len; + int32 ptr = 0; + + appendBinaryStringInfo(buf, (char *) &ptr, sizeof(ptr)); + + return pos; +} + +/* + * Prints text representation of given jsonpath item and all its children. + */ +static void +printJsonPathItem(StringInfo buf, JsonPathItem *v, bool inKey, + bool printBracketes) +{ + JsonPathItem elem; + int i; + + check_stack_depth(); + CHECK_FOR_INTERRUPTS(); + + switch (v->type) + { + case jpiNull: + appendStringInfoString(buf, "null"); + break; + case jpiKey: + if (inKey) + appendStringInfoChar(buf, '.'); + escape_json(buf, jspGetString(v, NULL)); + break; + case jpiString: + escape_json(buf, jspGetString(v, NULL)); + break; + case jpiVariable: + appendStringInfoChar(buf, '$'); + escape_json(buf, jspGetString(v, NULL)); + break; + case jpiNumeric: + appendStringInfoString(buf, + DatumGetCString(DirectFunctionCall1(numeric_out, + NumericGetDatum(jspGetNumeric(v))))); + break; + case jpiBool: + if (jspGetBool(v)) + appendBinaryStringInfo(buf, "true", 4); + else + appendBinaryStringInfo(buf, "false", 5); + break; + case jpiAnd: + case jpiOr: + case jpiEqual: + case jpiNotEqual: + case jpiLess: + case jpiGreater: + case jpiLessOrEqual: + case jpiGreaterOrEqual: + case jpiAdd: + case jpiSub: + case jpiMul: + case jpiDiv: + case jpiMod: + case jpiStartsWith: + if (printBracketes) + appendStringInfoChar(buf, '('); + jspGetLeftArg(v, &elem); + printJsonPathItem(buf, &elem, false, + 
operationPriority(elem.type) <= + operationPriority(v->type)); + appendStringInfoChar(buf, ' '); + appendStringInfoString(buf, jspOperationName(v->type)); + appendStringInfoChar(buf, ' '); + jspGetRightArg(v, &elem); + printJsonPathItem(buf, &elem, false, + operationPriority(elem.type) <= + operationPriority(v->type)); + if (printBracketes) + appendStringInfoChar(buf, ')'); + break; + case jpiLikeRegex: + if (printBracketes) + appendStringInfoChar(buf, '('); + + jspInitByBuffer(&elem, v->base, v->content.like_regex.expr); + printJsonPathItem(buf, &elem, false, + operationPriority(elem.type) <= + operationPriority(v->type)); + + appendBinaryStringInfo(buf, " like_regex ", 12); + + escape_json(buf, v->content.like_regex.pattern); + + if (v->content.like_regex.flags) + { + appendBinaryStringInfo(buf, " flag \"", 7); + + if (v->content.like_regex.flags & JSP_REGEX_ICASE) + appendStringInfoChar(buf, 'i'); + if (v->content.like_regex.flags & JSP_REGEX_DOTALL) + appendStringInfoChar(buf, 's'); + if (v->content.like_regex.flags & JSP_REGEX_MLINE) + appendStringInfoChar(buf, 'm'); + if (v->content.like_regex.flags & JSP_REGEX_WSPACE) + appendStringInfoChar(buf, 'x'); + if (v->content.like_regex.flags & JSP_REGEX_QUOTE) + appendStringInfoChar(buf, 'q'); + + appendStringInfoChar(buf, '"'); + } + + if (printBracketes) + appendStringInfoChar(buf, ')'); + break; + case jpiPlus: + case jpiMinus: + if (printBracketes) + appendStringInfoChar(buf, '('); + appendStringInfoChar(buf, v->type == jpiPlus ? 
'+' : '-'); + jspGetArg(v, &elem); + printJsonPathItem(buf, &elem, false, + operationPriority(elem.type) <= + operationPriority(v->type)); + if (printBracketes) + appendStringInfoChar(buf, ')'); + break; + case jpiFilter: + appendBinaryStringInfo(buf, "?(", 2); + jspGetArg(v, &elem); + printJsonPathItem(buf, &elem, false, false); + appendStringInfoChar(buf, ')'); + break; + case jpiNot: + appendBinaryStringInfo(buf, "!(", 2); + jspGetArg(v, &elem); + printJsonPathItem(buf, &elem, false, false); + appendStringInfoChar(buf, ')'); + break; + case jpiIsUnknown: + appendStringInfoChar(buf, '('); + jspGetArg(v, &elem); + printJsonPathItem(buf, &elem, false, false); + appendBinaryStringInfo(buf, ") is unknown", 12); + break; + case jpiExists: + appendBinaryStringInfo(buf, "exists (", 8); + jspGetArg(v, &elem); + printJsonPathItem(buf, &elem, false, false); + appendStringInfoChar(buf, ')'); + break; + case jpiCurrent: + Assert(!inKey); + appendStringInfoChar(buf, '@'); + break; + case jpiRoot: + Assert(!inKey); + appendStringInfoChar(buf, '$'); + break; + case jpiLast: + appendBinaryStringInfo(buf, "last", 4); + break; + case jpiAnyArray: + appendBinaryStringInfo(buf, "[*]", 3); + break; + case jpiAnyKey: + if (inKey) + appendStringInfoChar(buf, '.'); + appendStringInfoChar(buf, '*'); + break; + case jpiIndexArray: + appendStringInfoChar(buf, '['); + for (i = 0; i < v->content.array.nelems; i++) + { + JsonPathItem from; + JsonPathItem to; + bool range = jspGetArraySubscript(v, &from, &to, i); + + if (i) + appendStringInfoChar(buf, ','); + + printJsonPathItem(buf, &from, false, false); + + if (range) + { + appendBinaryStringInfo(buf, " to ", 4); + printJsonPathItem(buf, &to, false, false); + } + } + appendStringInfoChar(buf, ']'); + break; + case jpiAny: + if (inKey) + appendStringInfoChar(buf, '.'); + + if (v->content.anybounds.first == 0 && + v->content.anybounds.last == PG_UINT32_MAX) + appendBinaryStringInfo(buf, "**", 2); + else if (v->content.anybounds.first == 
v->content.anybounds.last) + { + if (v->content.anybounds.first == PG_UINT32_MAX) + appendStringInfo(buf, "**{last}"); + else + appendStringInfo(buf, "**{%u}", + v->content.anybounds.first); + } + else if (v->content.anybounds.first == PG_UINT32_MAX) + appendStringInfo(buf, "**{last to %u}", + v->content.anybounds.last); + else if (v->content.anybounds.last == PG_UINT32_MAX) + appendStringInfo(buf, "**{%u to last}", + v->content.anybounds.first); + else + appendStringInfo(buf, "**{%u to %u}", + v->content.anybounds.first, + v->content.anybounds.last); + break; + case jpiType: + appendBinaryStringInfo(buf, ".type()", 7); + break; + case jpiSize: + appendBinaryStringInfo(buf, ".size()", 7); + break; + case jpiAbs: + appendBinaryStringInfo(buf, ".abs()", 6); + break; + case jpiFloor: + appendBinaryStringInfo(buf, ".floor()", 8); + break; + case jpiCeiling: + appendBinaryStringInfo(buf, ".ceiling()", 10); + break; + case jpiDouble: + appendBinaryStringInfo(buf, ".double()", 9); + break; + case jpiDatetime: + appendBinaryStringInfo(buf, ".datetime(", 10); + if (v->content.arg) + { + jspGetArg(v, &elem); + printJsonPathItem(buf, &elem, false, false); + } + appendStringInfoChar(buf, ')'); + break; + case jpiKeyValue: + appendBinaryStringInfo(buf, ".keyvalue()", 11); + break; + default: + elog(ERROR, "unrecognized jsonpath item type: %d", v->type); + } + + if (jspGetNext(v, &elem)) + printJsonPathItem(buf, &elem, true, true); +} + +const char * +jspOperationName(JsonPathItemType type) +{ + switch (type) + { + case jpiAnd: + return "&&"; + case jpiOr: + return "||"; + case jpiEqual: + return "=="; + case jpiNotEqual: + return "!="; + case jpiLess: + return "<"; + case jpiGreater: + return ">"; + case jpiLessOrEqual: + return "<="; + case jpiGreaterOrEqual: + return ">="; + case jpiPlus: + case jpiAdd: + return "+"; + case jpiMinus: + case jpiSub: + return "-"; + case jpiMul: + return "*"; + case jpiDiv: + return "/"; + case jpiMod: + return "%"; + case jpiStartsWith: + 
return "starts with"; + case jpiLikeRegex: + return "like_regex"; + case jpiType: + return "type"; + case jpiSize: + return "size"; + case jpiKeyValue: + return "keyvalue"; + case jpiDouble: + return "double"; + case jpiAbs: + return "abs"; + case jpiFloor: + return "floor"; + case jpiCeiling: + return "ceiling"; + case jpiDatetime: + return "datetime"; + default: + elog(ERROR, "unrecognized jsonpath item type: %d", type); + return NULL; + } +} + +static int +operationPriority(JsonPathItemType op) +{ + switch (op) + { + case jpiOr: + return 0; + case jpiAnd: + return 1; + case jpiEqual: + case jpiNotEqual: + case jpiLess: + case jpiGreater: + case jpiLessOrEqual: + case jpiGreaterOrEqual: + case jpiStartsWith: + return 2; + case jpiAdd: + case jpiSub: + return 3; + case jpiMul: + case jpiDiv: + case jpiMod: + return 4; + case jpiPlus: + case jpiMinus: + return 5; + default: + return 6; + } +} + +/******************* Support functions for JsonPath *************************/ + +/* + * Support macros to read stored values + */ + +#define read_byte(v, b, p) do { \ + (v) = *(uint8*)((b) + (p)); \ + (p) += 1; \ +} while(0) \ + +#define read_int32(v, b, p) do { \ + (v) = *(uint32*)((b) + (p)); \ + (p) += sizeof(int32); \ +} while(0) \ + +#define read_int32_n(v, b, p, n) do { \ + (v) = (void *)((b) + (p)); \ + (p) += sizeof(int32) * (n); \ +} while(0) \ + +/* + * Read root node and fill root node representation + */ +void +jspInit(JsonPathItem *v, JsonPath *js) +{ + Assert((js->header & ~JSONPATH_LAX) == JSONPATH_VERSION); + jspInitByBuffer(v, js->data, 0); +} + +/* + * Read node from buffer and fill its representation + */ +void +jspInitByBuffer(JsonPathItem *v, char *base, int32 pos) +{ + v->base = base + pos; + + read_byte(v->type, base, pos); + pos = INTALIGN((uintptr_t) (base + pos)) - (uintptr_t) base; + read_int32(v->nextPos, base, pos); + + switch (v->type) + { + case jpiNull: + case jpiRoot: + case jpiCurrent: + case jpiAnyArray: + case jpiAnyKey: + case jpiType: 
+ case jpiSize: + case jpiAbs: + case jpiFloor: + case jpiCeiling: + case jpiDouble: + case jpiKeyValue: + case jpiLast: + break; + case jpiKey: + case jpiString: + case jpiVariable: + read_int32(v->content.value.datalen, base, pos); + /* FALLTHROUGH */ + case jpiNumeric: + case jpiBool: + v->content.value.data = base + pos; + break; + case jpiAnd: + case jpiOr: + case jpiAdd: + case jpiSub: + case jpiMul: + case jpiDiv: + case jpiMod: + case jpiEqual: + case jpiNotEqual: + case jpiLess: + case jpiGreater: + case jpiLessOrEqual: + case jpiGreaterOrEqual: + case jpiStartsWith: + read_int32(v->content.args.left, base, pos); + read_int32(v->content.args.right, base, pos); + break; + case jpiLikeRegex: + read_int32(v->content.like_regex.flags, base, pos); + read_int32(v->content.like_regex.expr, base, pos); + read_int32(v->content.like_regex.patternlen, base, pos); + v->content.like_regex.pattern = base + pos; + break; + case jpiNot: + case jpiExists: + case jpiIsUnknown: + case jpiPlus: + case jpiMinus: + case jpiFilter: + case jpiDatetime: + read_int32(v->content.arg, base, pos); + break; + case jpiIndexArray: + read_int32(v->content.array.nelems, base, pos); + read_int32_n(v->content.array.elems, base, pos, + v->content.array.nelems * 2); + break; + case jpiAny: + read_int32(v->content.anybounds.first, base, pos); + read_int32(v->content.anybounds.last, base, pos); + break; + default: + elog(ERROR, "unrecognized jsonpath item type: %d", v->type); + } +} + +void +jspGetArg(JsonPathItem *v, JsonPathItem *a) +{ + Assert(v->type == jpiFilter || + v->type == jpiNot || + v->type == jpiIsUnknown || + v->type == jpiExists || + v->type == jpiPlus || + v->type == jpiMinus || + v->type == jpiDatetime); + + jspInitByBuffer(a, v->base, v->content.arg); +} + +bool +jspGetNext(JsonPathItem *v, JsonPathItem *a) +{ + if (jspHasNext(v)) + { + Assert(v->type == jpiString || + v->type == jpiNumeric || + v->type == jpiBool || + v->type == jpiNull || + v->type == jpiKey || + v->type == 
jpiAny || + v->type == jpiAnyArray || + v->type == jpiAnyKey || + v->type == jpiIndexArray || + v->type == jpiFilter || + v->type == jpiCurrent || + v->type == jpiExists || + v->type == jpiRoot || + v->type == jpiVariable || + v->type == jpiLast || + v->type == jpiAdd || + v->type == jpiSub || + v->type == jpiMul || + v->type == jpiDiv || + v->type == jpiMod || + v->type == jpiPlus || + v->type == jpiMinus || + v->type == jpiEqual || + v->type == jpiNotEqual || + v->type == jpiGreater || + v->type == jpiGreaterOrEqual || + v->type == jpiLess || + v->type == jpiLessOrEqual || + v->type == jpiAnd || + v->type == jpiOr || + v->type == jpiNot || + v->type == jpiIsUnknown || + v->type == jpiType || + v->type == jpiSize || + v->type == jpiAbs || + v->type == jpiFloor || + v->type == jpiCeiling || + v->type == jpiDouble || + v->type == jpiDatetime || + v->type == jpiKeyValue || + v->type == jpiStartsWith || + v->type == jpiLikeRegex); + + if (a) + jspInitByBuffer(a, v->base, v->nextPos); + return true; + } + + return false; +} + +void +jspGetLeftArg(JsonPathItem *v, JsonPathItem *a) +{ + Assert(v->type == jpiAnd || + v->type == jpiOr || + v->type == jpiEqual || + v->type == jpiNotEqual || + v->type == jpiLess || + v->type == jpiGreater || + v->type == jpiLessOrEqual || + v->type == jpiGreaterOrEqual || + v->type == jpiAdd || + v->type == jpiSub || + v->type == jpiMul || + v->type == jpiDiv || + v->type == jpiMod || + v->type == jpiStartsWith); + + jspInitByBuffer(a, v->base, v->content.args.left); +} + +void +jspGetRightArg(JsonPathItem *v, JsonPathItem *a) +{ + Assert(v->type == jpiAnd || + v->type == jpiOr || + v->type == jpiEqual || + v->type == jpiNotEqual || + v->type == jpiLess || + v->type == jpiGreater || + v->type == jpiLessOrEqual || + v->type == jpiGreaterOrEqual || + v->type == jpiAdd || + v->type == jpiSub || + v->type == jpiMul || + v->type == jpiDiv || + v->type == jpiMod || + v->type == jpiStartsWith); + + jspInitByBuffer(a, v->base, 
v->content.args.right); +} + +bool +jspGetBool(JsonPathItem *v) +{ + Assert(v->type == jpiBool); + + return (bool) *v->content.value.data; +} + +Numeric +jspGetNumeric(JsonPathItem *v) +{ + Assert(v->type == jpiNumeric); + + return (Numeric) v->content.value.data; +} + +char * +jspGetString(JsonPathItem *v, int32 *len) +{ + Assert(v->type == jpiKey || + v->type == jpiString || + v->type == jpiVariable); + + if (len) + *len = v->content.value.datalen; + return v->content.value.data; +} + +bool +jspGetArraySubscript(JsonPathItem *v, JsonPathItem *from, JsonPathItem *to, + int i) +{ + Assert(v->type == jpiIndexArray); + + jspInitByBuffer(from, v->base, v->content.array.elems[i].from); + + if (!v->content.array.elems[i].to) + return false; + + jspInitByBuffer(to, v->base, v->content.array.elems[i].to); + + return true; +} From b7c1cae89685307f8dd5232767046d6c91b3a507 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Thu, 5 Mar 2026 10:27:49 +0000 Subject: [PATCH 36/60] clear postgress cache --- .tekton/on-pull-request.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.tekton/on-pull-request.yaml b/.tekton/on-pull-request.yaml index 38c4dcb9..9e662624 100644 --- a/.tekton/on-pull-request.yaml +++ b/.tekton/on-pull-request.yaml @@ -248,6 +248,7 @@ spec: #clean the java cache #rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* #rm -rf ./.cache/am_cache/pickle/https.github.com.TamarW0.node-example-project-* + rm -rf ./.cache/am_cache/pickle/https.github.com.postgres.postgres-* print_banner "RUNNING UNIT TESTS" make test-unit PYTEST_OPTS="--log-cli-level=DEBUG" From bc1c1ded3dc034552f57ed61b6b6ce2f7fee4851 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Thu, 5 Mar 2026 12:57:41 +0200 Subject: [PATCH 37/60] clean java cache --- .tekton/on-pull-request.yaml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.tekton/on-pull-request.yaml b/.tekton/on-pull-request.yaml index 9e662624..a9978240 100644 --- 
a/.tekton/on-pull-request.yaml +++ b/.tekton/on-pull-request.yaml @@ -246,9 +246,10 @@ spec: make lint-pr TARGET_BRANCH=$TARGET_BRANCH_NAME #clean the java cache - #rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* + rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* + rm -rf ./.cache/am_cache/git/https.github.com.cryostatio.cryostat #rm -rf ./.cache/am_cache/pickle/https.github.com.TamarW0.node-example-project-* - rm -rf ./.cache/am_cache/pickle/https.github.com.postgres.postgres-* + #rm -rf ./.cache/am_cache/pickle/https.github.com.postgres.postgres-* print_banner "RUNNING UNIT TESTS" make test-unit PYTEST_OPTS="--log-cli-level=DEBUG" From d3193bd33b9c1fb03ad535373c736e78ab68d4ba Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Thu, 5 Mar 2026 12:13:37 +0000 Subject: [PATCH 38/60] Add Rule number 8 --- .tekton/on-pull-request.yaml | 4 ++-- src/vuln_analysis/functions/cve_agent.py | 1 + .../functions/react_internals.py | 21 +++++++++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/.tekton/on-pull-request.yaml b/.tekton/on-pull-request.yaml index a9978240..575461d4 100644 --- a/.tekton/on-pull-request.yaml +++ b/.tekton/on-pull-request.yaml @@ -246,8 +246,8 @@ spec: make lint-pr TARGET_BRANCH=$TARGET_BRANCH_NAME #clean the java cache - rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* - rm -rf ./.cache/am_cache/git/https.github.com.cryostatio.cryostat + #rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* + #rm -rf ./.cache/am_cache/git/https.github.com.cryostatio.cryostat #rm -rf ./.cache/am_cache/pickle/https.github.com.TamarW0.node-example-project-* #rm -rf ./.cache/am_cache/pickle/https.github.com.postgres.postgres-* print_banner "RUNNING UNIT TESTS" diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 6c012464..41f5de9d 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ 
b/src/vuln_analysis/functions/cve_agent.py @@ -482,6 +482,7 @@ async def observation_node(state: AgentState) -> AgentState: tool_output_for_llm = _group_code_search_results( tool_message.content, state["ecosystem"], state["app_package"], ) + rules_tracker.set_target_package(state.get("app_package")) result, error_message = rules_tracker.check_thought_behavior(tool_used, tool_input_detail, tool_output_for_llm) if result: span.set_output({"rule_error": error_message}) diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index dede20c5..d40a609a 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -69,6 +69,9 @@ class PackageSelection(BaseModel): class SystemRulesTracker: def __init__(self): self.action_history = {} + self.target_package = None + def set_target_package(self, target_package: str): + self.target_package = target_package @staticmethod def _is_empty_result(output) -> bool: @@ -104,10 +107,27 @@ def _rule_number_7(self, action: str, action_input: str, output) -> bool: return True return False + @staticmethod + def _normalize_package_name(name: str) -> str: + return name.strip().lower().replace("-", "_") + + def _rule_number_8(self, action: str, action_input: str, output) -> bool: + if self.target_package is None: + return False + if action not in ("Function Locator", "Call Chain Analyzer", "Function Caller Finder"): + return False + if action not in self.action_history: + input_pkg = self._normalize_package_name(action_input.split(",")[0]) + target_pkg = self._normalize_package_name(self.target_package) + if input_pkg != target_pkg: + return True + return False def check_thought_behavior(self, action: str, action_input: str, output) -> tuple[bool, str]: if self._rule_number_7(action, action_input, output): return True, ("You are NOT following Rule 7. Your query contains dots and returned " "no results. 
You MUST retry with just the final component. Follow the rules.") + if self._rule_number_8(action, action_input, output): + return True, (f"You are NOT following Rule 8. You are using the wrong package name. You MUST use the target package name {self.target_package} see KNOWLEDGE as the package_name before trying alternative packages. Follow the rules.") self.add_action(action, action_input, output) return False, "" @@ -173,6 +193,7 @@ class AgentState(MessagesState): 5. Code Keyword Search and Function Locator are NOT reachability proof -- only Call Chain Analyzer is. 6. Do NOT call the same tool with the same input twice. Check KNOWLEDGE for "CALLED:" entries. If already tried, use a DIFFERENT tool or different input. 7. If Code Keyword Search returns no results and the query contains dots (e.g. a.b.ClassName), retry with just the final component (e.g. ClassName). This does NOT apply to simple names without dots. +8. When using Function Locator, Call Chain Analyzer, or Function Caller Finder, always start with the TARGET PACKAGE from KNOWLEDGE as the package_name before trying alternative packages. 
{{"thought": "Search for the vulnerable function in the codebase first", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "", "tool_input": null, "reason": "Check if vulnerable function is present"}}, "final_answer": null}} From 5b14fa83e8bb5153fa998ba0eb8d563c09bf88a1 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Thu, 5 Mar 2026 15:17:55 +0000 Subject: [PATCH 39/60] Route prompts by checklist question if it is reachability question or not --- ci/scripts/analyze_traces.py | 14 +++++- src/vuln_analysis/functions/cve_agent.py | 44 ++++++++++++----- .../functions/react_internals.py | 49 +++++++++++++++++++ src/vuln_analysis/utils/prompt_factory.py | 36 ++++++++++++++ 4 files changed, 130 insertions(+), 13 deletions(-) diff --git a/ci/scripts/analyze_traces.py b/ci/scripts/analyze_traces.py index 5824c128..b9ea3281 100644 --- a/ci/scripts/analyze_traces.py +++ b/ci/scripts/analyze_traces.py @@ -71,6 +71,7 @@ class InvestigationFlow: cve_id: str = "" ecosystem: str = "" question: str = "" + is_reachability: str = "" def _preview(text: str, max_len: int = 120) -> str: @@ -299,6 +300,13 @@ def _build_flow_from_spans( elif kind == "TOOL": flow.tool_call_count += 1 flow.tools_used.append(span.get("name", "")) + elif kind == "FUNCTION" and span.get("name", "") == "pre_process node": + try: + output = json.loads(span.get("output_value", "{}")) + if isinstance(output, dict) and "reachability_question" in output: + flow.is_reachability = output["reachability_question"] + except (json.JSONDecodeError, TypeError): + pass return flow @@ -366,20 +374,22 @@ def run_quality_checks(flow: InvestigationFlow) -> list[QualityFlag]: tool_set = set(flow.tools_used) + is_reachability_question = flow.is_reachability != "no" + if not tool_set: flags.append(QualityFlag( name="NO_TOOL_CALLS", severity=Severity.HIGH, detail="Agent concluded without any tool calls.", )) - elif not tool_set & REACHABILITY_TOOLS: + elif 
is_reachability_question and not tool_set & REACHABILITY_TOOLS: flags.append(QualityFlag( name="NO_REACHABILITY_TOOL", severity=Severity.HIGH, detail=f"No reachability tool used. Tools: {', '.join(tool_set)}", )) - if tool_set and tool_set <= SEARCH_ONLY_TOOLS: + if is_reachability_question and tool_set and tool_set <= SEARCH_ONLY_TOOLS: flags.append(QualityFlag( name="KEYWORD_SEARCH_ONLY", severity=Severity.HIGH, diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 41f5de9d..ab84f01f 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -36,9 +36,9 @@ from vuln_analysis.utils.prompting import get_agent_prompt from exploit_iq_commons.logging.loggers_factory import LoggingFactory, trace_id -from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, build_package_filter_prompt, AgentState, Thought, Observation, Classification, PackageSelection, SystemRulesTracker, _build_tool_arguments, FORCED_FINISH_PROMPT, OBSERVATION_NODE_PROMPT +from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, build_package_filter_prompt, AgentState, Thought, Observation, Classification, PackageSelection, SystemRulesTracker, _build_tool_arguments, FORCED_FINISH_PROMPT, OBSERVATION_NODE_PROMPT, AGENT_SYS_PROMPT_NON_REACHABILITY, AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY from vuln_analysis.utils.prompting import build_tool_descriptions -from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_ECOSYSTEM_REGISTRY, FEW_SHOT_EXAMPLES +from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_SELECTION_STRATEGY_NON_REACHABILITY, TOOL_ECOSYSTEM_REGISTRY, FEW_SHOT_EXAMPLES from langgraph.graph import StateGraph, END, START from langgraph.prebuilt import ToolNode from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, RemoveMessage @@ -155,7 +155,7 @@ 
def _estimate_tokens(runtime_prompt: str, messages: list, observation: Observati parts.append(item) return _count_tokens("\n".join(parts)) - def _build_tool_guidance_for_ecosystem(ecosystem: str, available_tools: list) -> tuple[str, str]: + def _build_tool_guidance_for_ecosystem(ecosystem: str, available_tools: list, is_reachability: str = "yes") -> tuple[str, str]: """Build tool guidance using language-specific strategies when available.""" filtered_tools = [ t for t in available_tools @@ -164,12 +164,14 @@ def _build_tool_guidance_for_ecosystem(ecosystem: str, available_tools: list) -> list_of_tool_names = [t.name for t in filtered_tools] list_of_tool_descriptions = [t.name + ": " + t.description for t in filtered_tools] + strategy = TOOL_SELECTION_STRATEGY if is_reachability == "yes" else TOOL_SELECTION_STRATEGY_NON_REACHABILITY lang = ecosystem.lower() if ecosystem else "" - if lang in TOOL_SELECTION_STRATEGY: - tool_guidance_local = TOOL_SELECTION_STRATEGY[lang] - hint = FEW_SHOT_EXAMPLES.get(lang, "") - if hint: - tool_guidance_local += f"\nHint: {hint}" + if lang in strategy: + tool_guidance_local = strategy[lang] + if is_reachability == "yes": + hint = FEW_SHOT_EXAMPLES.get(lang, "") + if hint: + tool_guidance_local += f"\nHint: {hint}" else: tool_guidance_list_local = build_tool_descriptions(list_of_tool_names) tool_guidance_local = "\n".join(tool_guidance_list_local) @@ -344,11 +346,31 @@ async def pre_process_node(state: AgentState) -> AgentState: "reachability_question": classification_result.is_reachability, }) - tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) - runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) + is_reachability = classification_result.is_reachability + + if is_reachability == "yes": + tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) + runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) + 
active_tool_names = [t.name for t in tools] + else: + reachability_tool_names = {ToolNames.FUNCTION_LOCATOR, ToolNames.CALL_CHAIN_ANALYZER, ToolNames.FUNCTION_CALLER_FINDER} + non_reach_tools = [t for t in tools if t.name not in reachability_tool_names] + tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, non_reach_tools, is_reachability="no") + runtime_prompt = build_system_prompt( + descriptions_local, tool_guidance_local, + instructions=AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY, + sys_prompt=AGENT_SYS_PROMPT_NON_REACHABILITY, + ) + active_tool_names = [t.name for t in non_reach_tools] + logger.info("Non-reachability question detected; removed reachability tools from prompt") + rules_tracker = state.get("rules_tracker") + app_package = app_package if selected_package else None + rules_tracker.set_target_package(app_package) + rules_tracker.set_allowed_tools(active_tool_names) return { "ecosystem": ecosystem, "runtime_prompt": runtime_prompt, + "is_reachability": is_reachability, "observation": Observation(memory=critical_context, results=[]), "app_package": app_package if selected_package else None, } @@ -482,7 +504,7 @@ async def observation_node(state: AgentState) -> AgentState: tool_output_for_llm = _group_code_search_results( tool_message.content, state["ecosystem"], state["app_package"], ) - rules_tracker.set_target_package(state.get("app_package")) + result, error_message = rules_tracker.check_thought_behavior(tool_used, tool_input_detail, tool_output_for_llm) if result: span.set_output({"rule_error": error_message}) diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index d40a609a..211d7763 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -70,6 +70,9 @@ class SystemRulesTracker: def __init__(self): self.action_history = {} self.target_package = None + self.allowed_tools = [] + def 
set_allowed_tools(self, allowed_tools: list[str]): + self.allowed_tools = allowed_tools def set_target_package(self, target_package: str): self.target_package = target_package @@ -122,12 +125,19 @@ def _rule_number_8(self, action: str, action_input: str, output) -> bool: if input_pkg != target_pkg: return True return False + def _rule_use_allowed_tools(self, action: str) -> bool: + if action not in self.allowed_tools: + return True + return False + def check_thought_behavior(self, action: str, action_input: str, output) -> tuple[bool, str]: if self._rule_number_7(action, action_input, output): return True, ("You are NOT following Rule 7. Your query contains dots and returned " "no results. You MUST retry with just the final component. Follow the rules.") if self._rule_number_8(action, action_input, output): return True, (f"You are NOT following Rule 8. You are using the wrong package name. You MUST use the target package name {self.target_package} see KNOWLEDGE as the package_name before trying alternative packages. Follow the rules.") + if self._rule_use_allowed_tools(action): + return True, (f"You are NOT following AVAILABLE_TOOLS. You MUST use the allowed tools {self.allowed_tools}. Follow the rules.") self.add_action(action, action_input, output) return False, "" @@ -142,6 +152,7 @@ class AgentState(MessagesState): ecosystem: str | None = None runtime_prompt: str | None = None app_package: str | None = None + is_reachability: str = "yes" rules_tracker: SystemRulesTracker = SystemRulesTracker() ### --- End of REACT Schemas ----# @@ -204,6 +215,44 @@ class AgentState(MessagesState): {{"thought": "Function Locator confirmed the package. 
Now trace reachability with Call Chain Analyzer", "mode": "act", "actions": {{"tool": "Call Chain Analyzer", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Check if function is reachable from application code"}}, "final_answer": null}} """ +AGENT_SYS_PROMPT_NON_REACHABILITY = ( + "You are a security analyst investigating CVE exploitability in container images.\n" + "This is NOT a reachability question -- do NOT trace call chains.\n" + "MANDATORY STEPS (follow in order, do NOT skip any):\n" + "1. IDENTIFY the vulnerable component/function from the CVE description.\n" + "2. SEARCH for its presence using Code Keyword Search.\n" + "3. DISTINGUISH where the code was found: main application code vs. package dependencies. " + " Results from Code Keyword Search are grouped into 'Main application' and 'Application library dependencies'. " + " Pay close attention to which group contains the match.\n" + "4. ASSESS: determine the answer based on code presence, version info, configuration, " + " or any other evidence relevant to the question. Use Code Semantic Search or Docs Semantic Search " + " for deeper understanding when needed.\n" + "GENERAL RULES:\n" + "- Base conclusions ONLY on tool results, not assumptions.\n" + "- If a search returns no results, that is evidence the code is absent.\n" + "- Do NOT claim a function is used unless a tool confirmed it.\n" + "- Pay attention to whether findings come from the main application or from library dependencies.\n" + "- Use CVE Web Search to gather additional vulnerability context if needed." +) + +AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY = """ +1. Output valid JSON only. thought < 100 words. final_answer < 150 words. +2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer. +3. Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search: use query field. +4. Do NOT call the same tool with the same input twice. Check KNOWLEDGE for "CALLED:" entries. 
If already tried, use a DIFFERENT tool or different input. +5. If Code Keyword Search returns no results and the query contains dots (e.g. a.b.ClassName), retry with just the final component (e.g. ClassName). This does NOT apply to simple names without dots. +6. When Code Keyword Search results are grouped, note whether matches are in "Main application" or "Application library dependencies" -- this distinction is important for your analysis. + + +{{"thought": "Search for the vulnerable function in the codebase first", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "", "tool_input": null, "reason": "Check if vulnerable function is present in the container"}}, "final_answer": null}} + + +{{"thought": "Found the function in dependencies. Search for how it is used in the application code", "mode": "act", "actions": {{"tool": "Code Semantic Search", "package_name": null, "function_name": null, "query": "how is used", "tool_input": null, "reason": "Understand usage context in the application"}}, "final_answer": null}} + + +{{"thought": "Have enough evidence about the vulnerable component presence and context", "mode": "finish", "actions": null, "final_answer": "The vulnerable function was found in the container's dependency libraries under . The main application code does not directly reference it. Based on the CVE description and the evidence gathered, ..."}} +""" + CLASSIFICATION_PROMPT_TEMPLATE = """You are classifying a CVE investigation question. Context (CVE / vulnerable packages): diff --git a/src/vuln_analysis/utils/prompt_factory.py b/src/vuln_analysis/utils/prompt_factory.py index bd88b59f..d17c1577 100644 --- a/src/vuln_analysis/utils/prompt_factory.py +++ b/src/vuln_analysis/utils/prompt_factory.py @@ -150,6 +150,42 @@ ), } +TOOL_SELECTION_STRATEGY_NON_REACHABILITY: dict[str, str] = { + "python": ( + "Use Code Keyword Search first for exact import/function lookups " + "(e.g. 
urllib.parse, PIL.Image). " + "Pay attention to whether results appear in the main application code " + "or in package dependencies (transitive_env/). " + "Use Code Semantic Search to understand how a component is used; " + "Docs Semantic Search for architecture questions. " + "Use CVE Web Search for additional vulnerability context." + ), + "go": ( + "Use Code Keyword Search for import paths and function names. " + "Pay attention to whether results appear in the main application code or in vendor/ dependencies. " + "Use Code Semantic Search to understand usage patterns; Docs Semantic Search for architecture. " + "Use CVE Web Search for additional vulnerability context." + ), + "java": ( + "Use Code Keyword Search for import statements and class/method patterns (e.g. import com., javax.servlet). " + "Pay attention to whether results appear in the main application code or in library dependencies. " + "Use Docs Semantic Search for Spring/servlet architecture; Code Semantic Search for usage patterns. " + "Use CVE Web Search for additional vulnerability context." + ), + "javascript": ( + "Use Code Keyword Search first for require(, import {, and package patterns. " + "Pay attention to whether results appear in the main application code or in node_modules/ dependencies. " + "Use Code/Docs Semantic Search for middleware, API, and npm package usage patterns. " + "Use CVE Web Search for additional vulnerability context." + ), + "c": ( + "Use Code Keyword Search first for exact function name lookups (e.g. PQescapeLiteral, EVP_EncryptInit_ex2). " + "Pay attention to whether results appear in the main application code or in rpm_libs/ dependencies. " + "Use Code Semantic Search to understand usage patterns; Docs Semantic Search for library documentation. " + "Use CVE Web Search for additional vulnerability context." + ), +} + TOOL_GENERAL_DESCRIPTIONS: dict[str, str] = { ToolNames.CODE_SEMANTIC_SEARCH: "Searches container source code using semantic search. 
" "Finds how functions, libraries, or components are used in the codebase. " From 33c593b051019ffa4118dff5d097b0cccdfd2454 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Fri, 6 Mar 2026 09:54:34 +0200 Subject: [PATCH 40/60] Applly feedback from last runs --- src/vuln_analysis/functions/cve_agent.py | 45 +------------------ .../functions/react_internals.py | 21 ++++++--- .../tools/lexical_full_search.py | 2 +- src/vuln_analysis/utils/full_text_search.py | 45 ++++++++++++++++--- src/vuln_analysis/utils/prompt_factory.py | 40 ++++++++++++----- 5 files changed, 87 insertions(+), 66 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index ab84f01f..9f900477 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -43,13 +43,6 @@ from langgraph.prebuilt import ToolNode from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, RemoveMessage -ECOSYSTEM_DEP_DIRS = { - "c": "rpm_libs", - "go": "vendor", - "javascript": "node_modules", - "python": "transitive_env", - "java": "dependencies-sources", -} import uuid import tiktoken from nat.builder.context import Context @@ -272,37 +265,6 @@ def _filter_context_to_package(critical_context: list[str], selected: str, all_c filtered.append(entry) return filtered - def _group_code_search_results(tool_output: str, ecosystem: str, app_package: str) -> str: - """Split Code Keyword Search results into main-application vs dependency groups.""" - dep_dir = ECOSYSTEM_DEP_DIRS.get(ecosystem) - if not dep_dir: - return tool_output - try: - parsed = json.loads(tool_output) - if not isinstance(parsed, list): - return tool_output - except (json.JSONDecodeError, TypeError): - return tool_output - - main_app_results = [] - dep_results = [] - dep_prefix = dep_dir + "/" - for item in parsed: - source = item.get("source", "") if isinstance(item, dict) else "" - if source.startswith(dep_prefix): - dep_results.append(item) - 
else: - main_app_results.append(item) - - parts = [] - if main_app_results: - parts.append(f"Main application - {app_package}\n{json.dumps(main_app_results)}") - if dep_results: - parts.append(f"Application library dependencies\n{json.dumps(dep_results)}") - if not parts: - return tool_output - return "\n\n".join(parts) - async def pre_process_node(state: AgentState) -> AgentState: workflow_state = ctx_state.get() ecosystem = workflow_state.original_input.input.image.ecosystem.value @@ -353,7 +315,7 @@ async def pre_process_node(state: AgentState) -> AgentState: runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) active_tool_names = [t.name for t in tools] else: - reachability_tool_names = {ToolNames.FUNCTION_LOCATOR, ToolNames.CALL_CHAIN_ANALYZER, ToolNames.FUNCTION_CALLER_FINDER} + reachability_tool_names = { ToolNames.CALL_CHAIN_ANALYZER, ToolNames.FUNCTION_CALLER_FINDER} non_reach_tools = [t for t in tools if t.name not in reachability_tool_names] tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, non_reach_tools, is_reachability="no") runtime_prompt = build_system_prompt( @@ -500,11 +462,6 @@ async def observation_node(state: AgentState) -> AgentState: with AGENT_TRACER.push_active_function("observation node", input_data=f"tool used:{tool_used}") as span: try: tool_output_for_llm = tool_message.content - if tool_used == "Code Keyword Search" and state.get("app_package") and state.get("ecosystem"): - tool_output_for_llm = _group_code_search_results( - tool_message.content, state["ecosystem"], state["app_package"], - ) - result, error_message = rules_tracker.check_thought_behavior(tool_used, tool_input_detail, tool_output_for_llm) if result: span.set_output({"rule_error": error_message}) diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 211d7763..eaf8b2dc 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ 
b/src/vuln_analysis/functions/react_internals.py @@ -227,27 +227,36 @@ class AgentState(MessagesState): "4. ASSESS: determine the answer based on code presence, version info, configuration, " " or any other evidence relevant to the question. Use Code Semantic Search or Docs Semantic Search " " for deeper understanding when needed.\n" + "CRITICAL RULE:\n" + "- If a vulnerable function is found ONLY in 'Application library dependencies' " + " and NOT in 'Main application', this means the code exists in the dependency " + " tree but the main application does NOT directly use it. " + " Do NOT conclude the application is vulnerable based solely on presence in " + " dependency libraries. State clearly: 'found in dependency libraries but not " + " directly referenced by the main application.'\n" "GENERAL RULES:\n" "- Base conclusions ONLY on tool results, not assumptions.\n" "- If a search returns no results, that is evidence the code is absent.\n" "- Do NOT claim a function is used unless a tool confirmed it.\n" - "- Pay attention to whether findings come from the main application or from library dependencies.\n" "- Use CVE Web Search to gather additional vulnerability context if needed." ) AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY = """ 1. Output valid JSON only. thought < 100 words. final_answer < 150 words. 2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer. -3. Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search: use query field. -4. Do NOT call the same tool with the same input twice. Check KNOWLEDGE for "CALLED:" entries. If already tried, use a DIFFERENT tool or different input. -5. If Code Keyword Search returns no results and the query contains dots (e.g. a.b.ClassName), retry with just the final component (e.g. ClassName). This does NOT apply to simple names without dots. -6. 
When Code Keyword Search results are grouped, note whether matches are in "Main application" or "Application library dependencies" -- this distinction is important for your analysis. +3. Function Locator: MUST set package_name AND function_name. Do NOT use query. +4. Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search: use query field. +5. Do NOT call the same tool with the same input twice. Check KNOWLEDGE for "CALLED:" entries. If already tried, use a DIFFERENT tool or different input. +6. If Code Keyword Search returns no results and the query contains dots (e.g. a.b.ClassName), retry with just the final component (e.g. ClassName). This does NOT apply to simple names without dots. +7. When Code Keyword Search results are grouped, note whether matches are in "Main application" or "Application library dependencies" -- this distinction is important for your analysis. +8. Function Locator validates package/function NAMES only. It confirms the name exists, not that it is called or reachable. +9. When using Function Locator, always start with the TARGET PACKAGE from KNOWLEDGE as the package_name before trying alternative packages. {{"thought": "Search for the vulnerable function in the codebase first", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "", "tool_input": null, "reason": "Check if vulnerable function is present in the container"}}, "final_answer": null}} -{{"thought": "Found the function in dependencies. Search for how it is used in the application code", "mode": "act", "actions": {{"tool": "Code Semantic Search", "package_name": null, "function_name": null, "query": "how is used", "tool_input": null, "reason": "Understand usage context in the application"}}, "final_answer": null}} +{{"thought": "Found the function in dependencies. 
Use Function Locator to validate the package and function name", "mode": "act", "actions": {{"tool": "Function Locator", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Validate package and function name"}}, "final_answer": null}} {{"thought": "Have enough evidence about the vulnerable component presence and context", "mode": "finish", "actions": null, "final_answer": "The vulnerable function was found in the container's dependency libraries under . The main application code does not directly reference it. Based on the CVE description and the evidence gathered, ..."}} diff --git a/src/vuln_analysis/tools/lexical_full_search.py b/src/vuln_analysis/tools/lexical_full_search.py index 7166b5e5..0b24fcc1 100644 --- a/src/vuln_analysis/tools/lexical_full_search.py +++ b/src/vuln_analysis/tools/lexical_full_search.py @@ -41,7 +41,7 @@ async def lexical_search(config: LexicalSearchToolConfig, builder: Builder): # from vuln_analysis.utils.full_text_search import FullTextSearch @catch_tool_errors(LEXICAL_CODE_SEARCH) - async def _arun(query: str) -> list: + async def _arun(query: str) -> str: workflow_state = ctx_state.get() code_index_path = workflow_state.code_index_path full_text_search = FullTextSearch(cache_path=code_index_path) diff --git a/src/vuln_analysis/utils/full_text_search.py b/src/vuln_analysis/utils/full_text_search.py index 3f09349b..47f6108c 100644 --- a/src/vuln_analysis/utils/full_text_search.py +++ b/src/vuln_analysis/utils/full_text_search.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json import os import re from pathlib import Path @@ -84,6 +85,13 @@ def replace_quoted(match): return input_query +ECOSYSTEM_DEP_DIRS = { + "c": "rpm_libs/", + "go": "vendor/", + "javascript": "node_modules/", + "python": "transitive_env/", + "java": "dependencies-sources/", +} class FullTextSearch: INDEX_TYPE = "tantivy" @@ -120,7 +128,7 @@ def add_documents(self, documents: Iterable): def is_empty(self): return self.index.searcher().num_docs == 0 - def search_index(self, query: str, top_k: int = 10) -> list[dict] | str: + def search_index(self, query: str, top_k: int = 10) -> str: self.index.reload() try: @@ -129,10 +137,37 @@ def search_index(self, query: str, top_k: int = 10) -> list[dict] | str: query = clean_query(query) query, _ = self.index.parse_query_lenient(query) searcher = self.index.searcher() - results = searcher.search(query, limit=top_k).hits - return [{ - "source": searcher.doc(doc_id)["file_path"][0], "content": searcher.doc(doc_id)["content"][0] - } for _, doc_id in results] + results = searcher.search(query, limit=50).hits + app_docs = [] + dep_docs = [] + + + vendors_list = list(ECOSYSTEM_DEP_DIRS.values()) + for _, doc_id in results: + raw = searcher.doc(doc_id) + doc = {"source": raw["file_path"][0], "content": raw["content"][0]} + if any(doc["source"].startswith(vendor) for vendor in vendors_list): + dep_docs.append(doc) + else: + app_docs.append(doc) + + total_app = len(app_docs) + total_dep = len(dep_docs) + app_docs = app_docs[:top_k] + remaining = top_k - len(app_docs) + dep_docs = dep_docs[:max(remaining, 0)] + + app_header = f"Main application ({len(app_docs)} of {total_app} results)" + dep_header = f"Application library dependencies ({len(dep_docs)} of {total_dep} results)" + + parts = [app_header] + if app_docs: + parts.append(json.dumps(app_docs)) + parts.append(dep_header) + if dep_docs: + parts.append(json.dumps(dep_docs)) + + return "\n".join(parts) except Exception as e: logger.exception(e) diff --git 
a/src/vuln_analysis/utils/prompt_factory.py b/src/vuln_analysis/utils/prompt_factory.py index d17c1577..853fa5b0 100644 --- a/src/vuln_analysis/utils/prompt_factory.py +++ b/src/vuln_analysis/utils/prompt_factory.py @@ -156,32 +156,52 @@ "(e.g. urllib.parse, PIL.Image). " "Pay attention to whether results appear in the main application code " "or in package dependencies (transitive_env/). " + "Use Function Locator to validate package and function names " + "(e.g. urllib,parse). " "Use Code Semantic Search to understand how a component is used; " "Docs Semantic Search for architecture questions. " "Use CVE Web Search for additional vulnerability context." ), "go": ( "Use Code Keyword Search for import paths and function names. " - "Pay attention to whether results appear in the main application code or in vendor/ dependencies. " - "Use Code Semantic Search to understand usage patterns; Docs Semantic Search for architecture. " + "Pay attention to whether results appear in the main application code " + "or in vendor/ dependencies. " + "Use Function Locator to validate package paths and function names " + "(e.g. github.com/pkg/errors,New). " + "Use Code Semantic Search to understand usage patterns; " + "Docs Semantic Search for architecture. " "Use CVE Web Search for additional vulnerability context." ), "java": ( - "Use Code Keyword Search for import statements and class/method patterns (e.g. import com., javax.servlet). " - "Pay attention to whether results appear in the main application code or in library dependencies. " - "Use Docs Semantic Search for Spring/servlet architecture; Code Semantic Search for usage patterns. " + "Use Code Keyword Search for import statements and class/method patterns " + "(e.g. import com., javax.servlet). " + "Pay attention to whether results appear in the main application code " + "or in library dependencies. " + "Use Function Locator with maven GAV format to validate names " + "(e.g. group:artifact:version,ClassName.methodName). 
" + "Use Docs Semantic Search for Spring/servlet architecture; " + "Code Semantic Search for usage patterns. " "Use CVE Web Search for additional vulnerability context." ), "javascript": ( "Use Code Keyword Search first for require(, import {, and package patterns. " - "Pay attention to whether results appear in the main application code or in node_modules/ dependencies. " - "Use Code/Docs Semantic Search for middleware, API, and npm package usage patterns. " + "Pay attention to whether results appear in the main application code " + "or in node_modules/ dependencies. " + "Use Function Locator to validate package and function names " + "(e.g. lodash,defaultsDeep). " + "Use Code/Docs Semantic Search for middleware, API, and npm package " + "usage patterns. " "Use CVE Web Search for additional vulnerability context." ), "c": ( - "Use Code Keyword Search first for exact function name lookups (e.g. PQescapeLiteral, EVP_EncryptInit_ex2). " - "Pay attention to whether results appear in the main application code or in rpm_libs/ dependencies. " - "Use Code Semantic Search to understand usage patterns; Docs Semantic Search for library documentation. " + "Use Code Keyword Search first for exact function name lookups " + "(e.g. PQescapeLiteral, EVP_EncryptInit_ex2). " + "Pay attention to whether results appear in the main application code " + "or in rpm_libs/ dependencies. " + "Use Function Locator to validate library and function names " + "(e.g. libpq,PQescapeLiteral). " + "Use Code Semantic Search to understand usage patterns; " + "Docs Semantic Search for library documentation. " "Use CVE Web Search for additional vulnerability context." 
), } From 821676bc934e4b55c10bc828a9d80ab401e4dea4 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Fri, 6 Mar 2026 12:41:12 +0200 Subject: [PATCH 41/60] fixed edge case where inteli for go is not review -> no vull array need to fetch them --- ci/scripts/analyze_traces.py | 49 ++++++- src/vuln_analysis/functions/cve_agent.py | 104 ++----------- src/vuln_analysis/utils/intel_utils.py | 177 +++++++++++++++++++++++ 3 files changed, 232 insertions(+), 98 deletions(-) diff --git a/ci/scripts/analyze_traces.py b/ci/scripts/analyze_traces.py index b9ea3281..6213f7ab 100644 --- a/ci/scripts/analyze_traces.py +++ b/ci/scripts/analyze_traces.py @@ -263,19 +263,23 @@ def _build_flow_from_spans( for i, span in enumerate(group_spans): kind = span.get("span_kind", "") + span_name = span.get("name", "") role = "" mode = "" if kind == "LLM": role = _classify_llm_span(span) if role == "thought": mode = _extract_thought_mode(span.get("output_value", "")) + elif kind == "FUNCTION" and span_name == "thought node": + role = "thought" + mode = _extract_thought_mode(span.get("output_value", "")) has_error, error_detail = _check_span_error(span) step = InvestigationStep( step_num=i + 1, span_kind=kind, - name=span.get("name", ""), + name=span_name, input_preview=_preview(span.get("input_value", "")), output_preview=_preview(span.get("output_value", "")), token_prompt=span.get("token_prompt", 0), @@ -297,10 +301,14 @@ def _build_flow_from_spans( flow.cve_id = cve_id flow.ecosystem = eco flow.question = _extract_question(span.get("input_value", "")) + elif kind == "FUNCTION" and span_name == "thought node": + flow.llm_call_count += 1 + if step.token_prompt > 0: + flow.token_progression.append(step.token_prompt) elif kind == "TOOL": flow.tool_call_count += 1 - flow.tools_used.append(span.get("name", "")) - elif kind == "FUNCTION" and span.get("name", "") == "pre_process node": + flow.tools_used.append(span_name) + elif kind == "FUNCTION" and span_name == "pre_process node": try: output = 
json.loads(span.get("output_value", "{}")) if isinstance(output, dict) and "reachability_question" in output: @@ -323,6 +331,20 @@ def build_investigation_flows(trace_doc: dict) -> list[InvestigationFlow]: job_id = trace_doc.get("job_id", "") spans = trace_doc.get("spans", []) + llm_tokens_by_parent: dict[str, dict[str, int]] = {} + for s in spans: + if ( + s.get("span_kind") == "LLM" + and s.get("function_name") in NODE_FUNCTION_NAMES + ): + pid = s.get("parent_span_id", "") + if pid and s.get("token_prompt", 0) > 0: + llm_tokens_by_parent[pid] = { + "token_prompt": s.get("token_prompt", 0), + "token_completion": s.get("token_completion", 0), + "token_total": s.get("token_total", 0), + } + agent_spans = [ s for s in spans if ( @@ -334,6 +356,11 @@ def build_investigation_flows(trace_doc: dict) -> list[InvestigationFlow]: ) ] + for s in agent_spans: + sid = s.get("span_id", "") + if sid in llm_tokens_by_parent: + s.update(llm_tokens_by_parent[sid]) + parent_groups: dict[str, list[dict]] = {} for span in agent_spans: pid = span.get("parent_span_id", "unknown") @@ -429,6 +456,22 @@ def run_quality_checks(flow: InvestigationFlow) -> list[QualityFlag]: detail=f"Function Caller Finder used on {flow.ecosystem} (Go-only tool).", )) + seen_tool_calls: set[tuple[str, str]] = set() + duplicate_count = 0 + for step in flow.steps: + if step.span_kind == "TOOL": + key = (step.name, step.input_preview) + if key in seen_tool_calls: + duplicate_count += 1 + else: + seen_tool_calls.add(key) + if duplicate_count > 0: + flags.append(QualityFlag( + name="DUPLICATE_TOOL_CALL", + severity=Severity.MEDIUM, + detail=f"{duplicate_count} duplicate tool call(s) with identical inputs.", + )) + thought_steps = [s for s in flow.steps if s.llm_role == "thought"] if thought_steps: last_thought = thought_steps[-1] diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 9f900477..8fb81c88 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ 
b/src/vuln_analysis/functions/cve_agent.py @@ -39,6 +39,7 @@ from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, build_package_filter_prompt, AgentState, Thought, Observation, Classification, PackageSelection, SystemRulesTracker, _build_tool_arguments, FORCED_FINISH_PROMPT, OBSERVATION_NODE_PROMPT, AGENT_SYS_PROMPT_NON_REACHABILITY, AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY from vuln_analysis.utils.prompting import build_tool_descriptions from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_SELECTION_STRATEGY_NON_REACHABILITY, TOOL_ECOSYSTEM_REGISTRY, FEW_SHOT_EXAMPLES +from vuln_analysis.utils.intel_utils import build_critical_context, enrich_go_from_osv, filter_context_to_package from langgraph.graph import StateGraph, END, START from langgraph.prebuilt import ToolNode from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, RemoveMessage @@ -172,105 +173,18 @@ def _build_tool_guidance_for_ecosystem(ecosystem: str, available_tools: list, is descriptions_local = "\n".join(list_of_tool_descriptions) return tool_guidance_local, descriptions_local - def _build_critical_context(cve_intel_list) -> tuple[list[str], list[dict]]: - """Extract key facts from all available intel sources into a compact context. - - Returns (critical_context, candidate_packages) where candidate_packages - contains dicts with 'name', 'source', and optional 'ecosystem' keys. 
- """ - critical_context = [] - candidate_packages: list[dict] = [] - seen_packages: set[str] = set() - - for cve_intel in cve_intel_list: - if cve_intel.nvd is not None: - if cve_intel.nvd.cve_description: - critical_context.append(f"CVE Description: {cve_intel.nvd.cve_description[:400]}") - if cve_intel.nvd.cwe_name: - critical_context.append(f"CWE: {cve_intel.nvd.cwe_name}") - - if cve_intel.ghsa is not None: - if cve_intel.ghsa.vulnerabilities: - for v in cve_intel.ghsa.vulnerabilities[:3]: - vuln = v if isinstance(v, dict) else (v.__dict__ if hasattr(v, '__dict__') else {}) - vf = vuln.get('vulnerable_functions', []) if isinstance(vuln, dict) else getattr(v, 'vulnerable_functions', []) - pkg = vuln.get('package', None) if isinstance(vuln, dict) else getattr(v, 'package', None) - - if vf: - critical_context.append(f"Vulnerable functions (GHSA): {', '.join(vf)}") - short_names = [f.rsplit('.', 1)[-1] for f in vf if '.' in f] - if short_names: - critical_context.append(f"Search keywords: {', '.join(short_names)}") - if pkg: - if isinstance(pkg, dict): - pkg_name = pkg.get("name", "") - pkg_eco = pkg.get("ecosystem", "") - if pkg_name: - critical_context.append(f"Vulnerable module ({pkg_eco}): {pkg_name}") - if pkg_name not in seen_packages: - seen_packages.add(pkg_name) - candidate_packages.append({"name": pkg_name, "source": "ghsa", "ecosystem": pkg_eco}) - elif isinstance(pkg, str): - critical_context.append(f"Affected package: {pkg}") - if pkg not in seen_packages: - seen_packages.add(pkg) - candidate_packages.append({"name": pkg, "source": "ghsa"}) - if cve_intel.ghsa.description and not any("CVE Description" in c for c in critical_context): - critical_context.append(f"CVE Description: {cve_intel.ghsa.description[:400]}") - - if cve_intel.rhsa is not None: - if cve_intel.rhsa.statement: - critical_context.append(f"RHSA Statement: {cve_intel.rhsa.statement[:300]}") - if cve_intel.rhsa.package_state: - pkgs = list(set(p.package_name for p in 
cve_intel.rhsa.package_state if p.package_name)) - for p in pkgs: - if p not in seen_packages: - seen_packages.add(p) - candidate_packages.append({"name": p, "source": "rhsa"}) - if len(pkgs) > 5: - critical_context.append( - f"Affected across {len(pkgs)} Red Hat products (sample: {', '.join(pkgs[:5])}). " - "Focus investigation on the vulnerable library/module, not individual products." - ) - elif pkgs: - numbered = ", ".join(f"{i+1}) {p}" for i, p in enumerate(pkgs)) - critical_context.append(f"INVESTIGATE EACH package: {numbered}.") - - if cve_intel.ubuntu is not None: - if cve_intel.ubuntu.ubuntu_description: - critical_context.append(f"Ubuntu note: {cve_intel.ubuntu.ubuntu_description[:200]}") - - if cve_intel.plugin_data: - for pd in cve_intel.plugin_data[:2]: - critical_context.append(f"{pd.label}: {pd.description[:200]}") - - if not critical_context: - critical_context = ["No CVE intel available. Investigate using tools."] - return critical_context, candidate_packages - - def _filter_context_to_package(critical_context: list[str], selected: str, all_candidates: list[dict]) -> list[str]: - """Remove context entries and rejected package name references for non-selected packages.""" - rejected_names = {c["name"] for c in all_candidates if c["name"] != selected} - filtered = [] - for entry in critical_context: - if entry.startswith("INVESTIGATE EACH package:"): - filtered.append(f"Target package: {selected}") - continue - if entry.startswith("Vulnerable module (") or entry.startswith("Affected package:"): - if any(rn in entry for rn in rejected_names): - continue - for rn in rejected_names: - entry = entry.replace(f" {rn} ", " ") - entry = entry.replace(f" {rn},", ",") - filtered.append(entry) - return filtered - async def pre_process_node(state: AgentState) -> AgentState: workflow_state = ctx_state.get() ecosystem = workflow_state.original_input.input.image.ecosystem.value with AGENT_TRACER.push_active_function("pre_process node", 
input_data=f"ecosystem:{ecosystem}") as span: try: - critical_context, candidate_packages = _build_critical_context(workflow_state.cve_intel) + critical_context, candidate_packages = build_critical_context(workflow_state.cve_intel) + + ghsa_has_packages = any(c.get("source") == "ghsa" for c in candidate_packages) + if ecosystem == "go" and not ghsa_has_packages: + cve_intel = workflow_state.cve_intel[0] if workflow_state.cve_intel else None + if cve_intel: + await enrich_go_from_osv(cve_intel, critical_context, candidate_packages) selected_package = None app_package = None @@ -289,7 +203,7 @@ async def pre_process_node(state: AgentState) -> AgentState: app_package = selected_package logger.info("Package filter selected '%s' from %d candidates (reason: %s)", selected_package, len(candidate_packages), selection.reason) - critical_context = _filter_context_to_package(critical_context, selected_package, candidate_packages) + critical_context = filter_context_to_package(critical_context, selected_package, candidate_packages) critical_context.append( "TASK: Investigate usage and reachability of the vulnerable function/module in the container. " diff --git a/src/vuln_analysis/utils/intel_utils.py b/src/vuln_analysis/utils/intel_utils.py index ef7e7efb..8163a5a8 100644 --- a/src/vuln_analysis/utils/intel_utils.py +++ b/src/vuln_analysis/utils/intel_utils.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import re + from packaging.version import InvalidVersion from packaging.version import parse as parse_version from pydpkg import Dpkg @@ -20,6 +22,10 @@ from exploit_iq_commons.data_models.cve_intel import CveIntelNvd +from exploit_iq_commons.logging.loggers_factory import LoggingFactory + +logger = LoggingFactory.get_agent_logger(__name__) + def update_version(incoming_version, current_version, compare): """ @@ -154,3 +160,174 @@ def parse(configurations: list): version_info.append(obj) return version_info + + +# --------------------------------------------------------------------------- +# Critical-context helpers (used by pre_process_node in cve_agent.py) +# --------------------------------------------------------------------------- + +def build_critical_context(cve_intel_list) -> tuple[list[str], list[dict]]: + """Extract key facts from all available intel sources into a compact context. + + Returns (critical_context, candidate_packages) where candidate_packages + contains dicts with 'name', 'source', and optional 'ecosystem' keys. + """ + critical_context = [] + candidate_packages: list[dict] = [] + seen_packages: set[str] = set() + + for cve_intel in cve_intel_list: + if cve_intel.nvd is not None: + if cve_intel.nvd.cve_description: + critical_context.append(f"CVE Description: {cve_intel.nvd.cve_description[:400]}") + if cve_intel.nvd.cwe_name: + critical_context.append(f"CWE: {cve_intel.nvd.cwe_name}") + + if cve_intel.ghsa is not None: + if cve_intel.ghsa.vulnerabilities: + for v in cve_intel.ghsa.vulnerabilities[:3]: + vuln = v if isinstance(v, dict) else (v.__dict__ if hasattr(v, '__dict__') else {}) + vf = vuln.get('vulnerable_functions', []) if isinstance(vuln, dict) else getattr(v, 'vulnerable_functions', []) + pkg = vuln.get('package', None) if isinstance(vuln, dict) else getattr(v, 'package', None) + + if vf: + critical_context.append(f"Vulnerable functions (GHSA): {', '.join(vf)}") + short_names = [f.rsplit('.', 1)[-1] for f in vf if '.' 
in f] + if short_names: + critical_context.append(f"Search keywords: {', '.join(short_names)}") + if pkg: + if isinstance(pkg, dict): + pkg_name = pkg.get("name", "") + pkg_eco = pkg.get("ecosystem", "") + if pkg_name: + critical_context.append(f"Vulnerable module ({pkg_eco}): {pkg_name}") + if pkg_name not in seen_packages: + seen_packages.add(pkg_name) + candidate_packages.append({"name": pkg_name, "source": "ghsa", "ecosystem": pkg_eco}) + elif isinstance(pkg, str): + critical_context.append(f"Affected package: {pkg}") + if pkg not in seen_packages: + seen_packages.add(pkg) + candidate_packages.append({"name": pkg, "source": "ghsa"}) + if cve_intel.ghsa.description and not any("CVE Description" in c for c in critical_context): + critical_context.append(f"CVE Description: {cve_intel.ghsa.description[:400]}") + + if cve_intel.rhsa is not None: + if cve_intel.rhsa.statement: + critical_context.append(f"RHSA Statement: {cve_intel.rhsa.statement[:300]}") + if cve_intel.rhsa.package_state: + pkgs = list(set(p.package_name for p in cve_intel.rhsa.package_state if p.package_name)) + for p in pkgs: + if p not in seen_packages: + seen_packages.add(p) + candidate_packages.append({"name": p, "source": "rhsa"}) + if len(pkgs) > 5: + critical_context.append( + f"Affected across {len(pkgs)} Red Hat products (sample: {', '.join(pkgs[:5])}). " + "Focus investigation on the vulnerable library/module, not individual products." + ) + elif pkgs: + numbered = ", ".join(f"{i+1}) {p}" for i, p in enumerate(pkgs)) + critical_context.append(f"INVESTIGATE EACH package: {numbered}.") + + if cve_intel.ubuntu is not None: + if cve_intel.ubuntu.ubuntu_description: + critical_context.append(f"Ubuntu note: {cve_intel.ubuntu.ubuntu_description[:200]}") + + if cve_intel.plugin_data: + for pd in cve_intel.plugin_data[:2]: + critical_context.append(f"{pd.label}: {pd.description[:200]}") + + if not critical_context: + critical_context = ["No CVE intel available. 
Investigate using tools."] + return critical_context, candidate_packages + + +_GO_VULN_RE = re.compile(r"pkg\.go\.dev/vuln/(GO-\d{4}-\d+)") +_OSV_API_URL = "https://api.osv.dev/v1/vulns/" +_OSV_TIMEOUT_SECONDS = 5 + + +async def enrich_go_from_osv( + cve_intel, + critical_context: list[str], + candidate_packages: list[dict], +) -> None: + """Query the OSV API for Go module paths when GHSA has no package data. + + Looks for a pkg.go.dev/vuln/GO-XXXX-XXXX link in GHSA or NVD + references, fetches the advisory from OSV, and injects module + paths and vulnerable symbols into critical_context and + candidate_packages. Fails silently on any error. + """ + refs: list[str] = [] + if cve_intel.ghsa is not None: + refs.extend(getattr(cve_intel.ghsa, "references", None) or []) + if cve_intel.nvd is not None: + refs.extend(cve_intel.nvd.references or []) + + go_id = None + for ref in refs: + m = _GO_VULN_RE.search(ref if isinstance(ref, str) else "") + if m: + go_id = m.group(1) + break + + if not go_id: + return + + try: + import aiohttp + timeout = aiohttp.ClientTimeout(total=_OSV_TIMEOUT_SECONDS) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(f"{_OSV_API_URL}{go_id}") as resp: + if resp.status != 200: + logger.warning("OSV lookup for %s returned status %d", go_id, resp.status) + return + osv_data = await resp.json() + + seen_paths: set[str] = set() + all_symbols: list[str] = [] + for affected in osv_data.get("affected", []): + eco_specific = affected.get("ecosystem_specific", {}) + for imp in eco_specific.get("imports", []): + path = imp.get("path", "") + if path and path not in seen_paths: + seen_paths.add(path) + candidate_packages.append( + {"name": path, "source": "osv", "ecosystem": "Go"}, + ) + critical_context.append(f"Vulnerable module (Go): {path}") + symbols = imp.get("symbols", []) + all_symbols.extend(symbols) + + if all_symbols: + critical_context.append( + f"Vulnerable functions (Go vuln DB): {', '.join(all_symbols)}" 
+ ) + short_names = [s.rsplit(".", 1)[-1] for s in all_symbols if "." in s] + unique_keywords = list(dict.fromkeys(all_symbols + short_names)) + critical_context.append(f"Search keywords: {', '.join(unique_keywords)}") + + if seen_paths: + logger.info("OSV enrichment for %s added Go modules: %s", go_id, seen_paths) + except Exception: + logger.warning("OSV enrichment failed for %s", go_id, exc_info=True) + + +def filter_context_to_package(critical_context: list[str], selected: str, all_candidates: list[dict]) -> list[str]: + """Remove context entries and rejected package name references for non-selected packages.""" + rejected_names = {c["name"] for c in all_candidates if c["name"] != selected} + filtered = [] + for entry in critical_context: + if entry.startswith("INVESTIGATE EACH package:"): + filtered.append(f"Target package: {selected}") + continue + if entry.startswith("Vulnerable module (") or entry.startswith("Affected package:"): + if any(rn in entry for rn in rejected_names): + continue + for rn in rejected_names: + entry = entry.replace(f" {rn} ", " ") + entry = entry.replace(f" {rn},", ",") + filtered.append(entry) + return filtered From 4daffb76504f2156d05a6624add4a3855acead79 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Fri, 6 Mar 2026 16:36:42 +0200 Subject: [PATCH 42/60] inject reachability prefix in summary when all CCA results are negative --- ci/scripts/analyze_traces.py | 101 +++++++++++++++++- src/vuln_analysis/functions/cve_agent.py | 14 ++- src/vuln_analysis/functions/cve_summarize.py | 26 ++++- .../functions/react_internals.py | 1 + 4 files changed, 137 insertions(+), 5 deletions(-) diff --git a/ci/scripts/analyze_traces.py b/ci/scripts/analyze_traces.py index 6213f7ab..81d51d6a 100644 --- a/ci/scripts/analyze_traces.py +++ b/ci/scripts/analyze_traces.py @@ -319,6 +319,90 @@ def _build_flow_from_spans( return flow +def _build_flows_from_workflow_output( + spans: list[dict], + trace_id: str, + job_id: str, +) -> list[InvestigationFlow]: + 
"""Fallback: build flows from the span's output when agent-level + spans (thought node, pre_process node, etc.) are absent. + + This handles the ``cve_justify`` pipeline variant where the agent's + internal LangGraph spans are not exported individually. + """ + workflow_span = next( + (s for s in spans if s.get("function_name") == "" and s.get("span_kind") == "FUNCTION"), + None, + ) + if not workflow_span: + return [] + + output_raw = workflow_span.get("output_value", "") + try: + output = json.loads(output_raw) + except (json.JSONDecodeError, TypeError): + return [] + + analyses = (output.get("output", {}) or {}).get("analysis", []) + if not analyses: + return [] + + input_raw = workflow_span.get("input_value", "") + ecosystem = "" + try: + inp = json.loads(input_raw) + ecosystem = (inp.get("image", {}) or {}).get("ecosystem", "") + except (json.JSONDecodeError, TypeError): + pass + + llm_spans = [s for s in spans if s.get("span_kind") == "LLM"] + total_prompt = sum(s.get("token_prompt", 0) for s in llm_spans) + total_completion = sum(s.get("token_completion", 0) for s in llm_spans) + + flows: list[InvestigationFlow] = [] + for analysis in analyses: + vuln_id = analysis.get("vuln_id", "") + checklist = analysis.get("checklist", []) + justification = analysis.get("justification", {}) or {} + label = justification.get("label", "") + reason = justification.get("reason", "") + + flow = InvestigationFlow( + trace_id=trace_id, + job_id=job_id, + parent_span_id=workflow_span.get("span_id", ""), + cve_id=vuln_id, + ecosystem=ecosystem, + ) + + for i, item in enumerate(checklist): + question = item.get("input", "") + response = item.get("response", "") + flow.steps.append(InvestigationStep( + step_num=i + 1, + span_kind="SUMMARY", + name="checklist_qa", + input_preview=_preview(question, 200), + output_preview=_preview(response, 200), + )) + + flow.steps.append(InvestigationStep( + step_num=len(checklist) + 1, + span_kind="SUMMARY", + name="justification", + 
input_preview=f"label={label}", + output_preview=_preview(reason, 200), + )) + + flow.llm_call_count = len(llm_spans) + if total_prompt > 0: + flow.token_progression.append(total_prompt) + + flows.append(flow) + + return flows + + def build_investigation_flows(trace_doc: dict) -> list[InvestigationFlow]: """ From a single trace document, reconstruct per-question investigation flows. @@ -381,6 +465,9 @@ def build_investigation_flows(trace_doc: dict) -> list[InvestigationFlow]: flow = _build_flow_from_spans(group_spans, trace_id, job_id, parent_id) flows.append(flow) + if not flows: + flows = _build_flows_from_workflow_output(spans, trace_id, job_id) + return flows @@ -391,6 +478,10 @@ def build_investigation_flows(trace_doc: dict) -> list[InvestigationFlow]: def run_quality_checks(flow: InvestigationFlow) -> list[QualityFlag]: flags = [] + is_summary_only = all(s.span_kind == "SUMMARY" for s in flow.steps) if flow.steps else False + if is_summary_only: + return flags + error_steps = [s for s in flow.steps if s.has_error] for es in error_steps: flags.append(QualityFlag( @@ -508,7 +599,9 @@ def run_quality_checks(flow: InvestigationFlow) -> list[QualityFlag]: # --------------------------------------------------------------------------- def print_flow_report(flow: InvestigationFlow): - header = f"=== Trace: {flow.trace_id} | Parent: {flow.parent_span_id[:12]}... 
===" + is_summary = all(s.span_kind == "SUMMARY" for s in flow.steps) if flow.steps else False + tag = " [SUMMARY]" if is_summary else "" + header = f"=== Trace: {flow.trace_id} | Parent: {flow.parent_span_id[:12]}...{tag} ===" print(header) if flow.cve_id or flow.ecosystem: print(f"CVE: {flow.cve_id or 'unknown'} | Ecosystem: {flow.ecosystem or 'unknown'}") @@ -554,6 +647,10 @@ def print_flow_report(flow: InvestigationFlow): print(f"--- Step {step.step_num}: TOOL ({step.name}){error_marker} ---") print(f" Input: {step.input_preview}") print(f" Output: \"{step.output_preview}\"") + elif step.span_kind == "SUMMARY": + print(f"--- Step {step.step_num}: {step.name} ---") + print(f" Q: {step.input_preview}") + print(f" A: \"{step.output_preview}\"") elif step.span_kind == "FUNCTION": print(f"--- Step {step.step_num}: NODE ({step.name}){error_marker} ---") if step.has_error: @@ -599,12 +696,14 @@ def print_summary(all_flows: list[InvestigationFlow]): def build_json_report(all_flows: list[InvestigationFlow]) -> dict: flows_data = [] for flow in all_flows: + is_summary = all(s.span_kind == "SUMMARY" for s in flow.steps) if flow.steps else False flows_data.append({ "trace_id": flow.trace_id, "job_id": flow.job_id, "cve_id": flow.cve_id, "ecosystem": flow.ecosystem, "question": flow.question, + "summary_only": is_summary, "llm_call_count": flow.llm_call_count, "tool_call_count": flow.tool_call_count, "token_progression": flow.token_progression, diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 8fb81c88..a4a2eb38 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -411,10 +411,19 @@ async def observation_node(state: AgentState) -> AgentState: len(prune_messages), estimated, config.context_window_token_limit, ) span.set_output({"orig_estimated": orig_estimated, "estimated": estimated}) + cca_results = list(state.get("cca_results", [])) + if tool_used == 
ToolNames.CALL_CHAIN_ANALYZER: + stripped = tool_output_for_llm.strip().lstrip("([") + first_token = stripped.split(",", 1)[0].strip().lower() + if first_token == "true": + cca_results.append(True) + elif first_token == "false": + cca_results.append(False) return { "messages": prune_messages, "observation": new_observation, "step": state.get("step", 0), + "cca_results": cca_results, } except Exception as e: logger.exception("observation_node failed") @@ -552,7 +561,7 @@ def _postprocess_results(results: list[list[dict]], replace_exceptions: bool, re # If the agent encounters a parsing error or a server error after retries, replace the error # with default values to prevent the pipeline from crashing outputs[i].append({"input": checklist_questions[i][j], "output": replace_exceptions_value, - "intermediate_steps": None}) + "intermediate_steps": None, "cca_results": []}) if isinstance(answer, ToolRaisedException): tool_raised_exception: ToolRaisedException = answer logger.warning(f"An exception encountered during tool execution, in result [{i}][{j}]. 
for " @@ -580,7 +589,8 @@ def _postprocess_results(results: list[list[dict]], replace_exceptions: bool, re results[i][j]["intermediate_steps"] = None outputs[i].append({"input": answer["input"], "output": answer["output"], - "intermediate_steps": results[i][j]["intermediate_steps"]}) + "intermediate_steps": results[i][j]["intermediate_steps"], + "cca_results": answer.get("cca_results", [])}) return outputs diff --git a/src/vuln_analysis/functions/cve_summarize.py b/src/vuln_analysis/functions/cve_summarize.py index db7fba0b..8702e5eb 100644 --- a/src/vuln_analysis/functions/cve_summarize.py +++ b/src/vuln_analysis/functions/cve_summarize.py @@ -36,6 +36,16 @@ class CVESummarizeToolConfig(FunctionBaseConfig, name="cve_summarize"): llm_name: str = Field(description="The LLM model to use") +def _all_cca_not_reachable(checklist_items: list[dict]) -> bool: + """Return True if all Call Chain Analyzer results are False (not reachable). + Returns False if no CCA calls were made (no signal to gate on). 
+ """ + all_cca = [] + for item in checklist_items: + all_cca.extend(item.get("cca_results", [])) + return len(all_cca) > 0 and not any(all_cca) + + @register_function(config_type=CVESummarizeToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) async def cve_summarize(config: CVESummarizeToolConfig, builder: Builder): @@ -49,10 +59,22 @@ async def cve_summarize(config: CVESummarizeToolConfig, builder: Builder): chain = prompt | llm async def summarize_cve(results): + checklist_items = results[1] response = '\n'.join( - [get_checklist_item_string(idx + 1, checklist_item) for idx, checklist_item in enumerate(results[1])]) - final_summary = await chain.ainvoke({"response": response}) + [get_checklist_item_string(idx + 1, checklist_item) for idx, checklist_item in enumerate(checklist_items)]) + + if _all_cca_not_reachable(checklist_items): + response = ( + "REACHABILITY GATE: Call Chain Analyzer confirmed the vulnerable function " + "is NOT reachable from application code in ALL reachability checks performed. " + "An unreachable function cannot be exploited regardless of other findings " + "(missing mitigations, absent protections, etc. are irrelevant if the code " + "path is never executed). 
The verdict MUST be 'not exploitable'.\n\n" + + response + ) + logger.info("Reachability gate activated: all CCA results are negative") + final_summary = await chain.ainvoke({"response": response}) return final_summary.content async def _arun(state: AgentMorpheusEngineState) -> AgentMorpheusEngineState: diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index eaf8b2dc..f987041c 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -154,6 +154,7 @@ class AgentState(MessagesState): app_package: str | None = None is_reachability: str = "yes" rules_tracker: SystemRulesTracker = SystemRulesTracker() + cca_results: list[bool] = [] ### --- End of REACT Schemas ----# #---- REACT Prompt Templates ----# From 7f608fc1f6333448c66790a4df0ef09d041f01bf Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sat, 7 Mar 2026 11:47:53 +0200 Subject: [PATCH 43/60] add target functions from inteli cve if exist --- src/vuln_analysis/functions/cve_agent.py | 10 +++++-- .../functions/react_internals.py | 29 +++++++++++++++++++ src/vuln_analysis/utils/intel_utils.py | 18 +++++++++--- 3 files changed, 50 insertions(+), 7 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index a4a2eb38..70ae9218 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -178,13 +178,15 @@ async def pre_process_node(state: AgentState) -> AgentState: ecosystem = workflow_state.original_input.input.image.ecosystem.value with AGENT_TRACER.push_active_function("pre_process node", input_data=f"ecosystem:{ecosystem}") as span: try: - critical_context, candidate_packages = build_critical_context(workflow_state.cve_intel) + critical_context, candidate_packages, vulnerable_functions = build_critical_context(workflow_state.cve_intel) + vulnerable_functions_set = set(vulnerable_functions) 
ghsa_has_packages = any(c.get("source") == "ghsa" for c in candidate_packages) - if ecosystem == "go" and not ghsa_has_packages: + if ecosystem == "go" and (not ghsa_has_packages or not vulnerable_functions_set): cve_intel = workflow_state.cve_intel[0] if workflow_state.cve_intel else None if cve_intel: - await enrich_go_from_osv(cve_intel, critical_context, candidate_packages) + await enrich_go_from_osv(cve_intel, critical_context, candidate_packages, vulnerable_functions_set) + vulnerable_functions = sorted(vulnerable_functions_set) selected_package = None app_package = None @@ -243,6 +245,8 @@ async def pre_process_node(state: AgentState) -> AgentState: app_package = app_package if selected_package else None rules_tracker.set_target_package(app_package) rules_tracker.set_allowed_tools(active_tool_names) + if is_reachability == "yes": + rules_tracker.set_target_functions(vulnerable_functions) return { "ecosystem": ecosystem, "runtime_prompt": runtime_prompt, diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index f987041c..afeddfdf 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -71,10 +71,13 @@ def __init__(self): self.action_history = {} self.target_package = None self.allowed_tools = [] + self.target_functions: dict[str, bool] = {} def set_allowed_tools(self, allowed_tools: list[str]): self.allowed_tools = allowed_tools def set_target_package(self, target_package: str): self.target_package = target_package + def set_target_functions(self, functions: list[str]): + self.target_functions = {f: False for f in functions} @staticmethod def _is_empty_result(output) -> bool: @@ -130,6 +133,28 @@ def _rule_use_allowed_tools(self, action: str) -> bool: return True return False + def _rule_number_9(self, action: str, action_input: str) -> tuple[bool, str]: + if not self.target_functions: + return False, "" + if action not in ("Call Chain 
Analyzer", "Function Caller Finder"): + return False, "" + input_function = action_input.split(",", 1)[-1].strip() if "," in action_input else "" + if not input_function: + return False, "" + input_short = input_function.rsplit(".", 1)[-1].lower() + for fn in self.target_functions: + if fn.lower() == input_short: + self.target_functions[fn] = True + return False, "" + pending = [fn for fn, checked in self.target_functions.items() if not checked] + if pending: + return True, ( + f"You are NOT following Rule 9. The CVE lists specific vulnerable functions " + f"that you MUST investigate first: {', '.join(pending)}. " + f"Check these functions before investigating other functions." + ) + return False, "" + def check_thought_behavior(self, action: str, action_input: str, output) -> tuple[bool, str]: if self._rule_number_7(action, action_input, output): return True, ("You are NOT following Rule 7. Your query contains dots and returned " @@ -138,6 +163,9 @@ def check_thought_behavior(self, action: str, action_input: str, output) -> tupl return True, (f"You are NOT following Rule 8. You are using the wrong package name. You MUST use the target package name {self.target_package} see KNOWLEDGE as the package_name before trying alternative packages. Follow the rules.") if self._rule_use_allowed_tools(action): return True, (f"You are NOT following AVAILABLE_TOOLS. You MUST use the allowed tools {self.allowed_tools}. Follow the rules.") + rule9, msg9 = self._rule_number_9(action, action_input) + if rule9: + return True, msg9 self.add_action(action, action_input, output) return False, "" @@ -206,6 +234,7 @@ class AgentState(MessagesState): 6. Do NOT call the same tool with the same input twice. Check KNOWLEDGE for "CALLED:" entries. If already tried, use a DIFFERENT tool or different input. 7. If Code Keyword Search returns no results and the query contains dots (e.g. a.b.ClassName), retry with just the final component (e.g. ClassName). 
This does NOT apply to simple names without dots. 8. When using Function Locator, Call Chain Analyzer, or Function Caller Finder, always start with the TARGET PACKAGE from KNOWLEDGE as the package_name before trying alternative packages. +9. If KNOWLEDGE lists "Vulnerable functions (GHSA)" or "Vulnerable functions (Go vuln DB)", you MUST investigate those specific functions FIRST before checking any other functions in the same package. {{"thought": "Search for the vulnerable function in the codebase first", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "", "tool_input": null, "reason": "Check if vulnerable function is present"}}, "final_answer": null}} diff --git a/src/vuln_analysis/utils/intel_utils.py b/src/vuln_analysis/utils/intel_utils.py index 8163a5a8..3747965e 100644 --- a/src/vuln_analysis/utils/intel_utils.py +++ b/src/vuln_analysis/utils/intel_utils.py @@ -166,15 +166,18 @@ def parse(configurations: list): # Critical-context helpers (used by pre_process_node in cve_agent.py) # --------------------------------------------------------------------------- -def build_critical_context(cve_intel_list) -> tuple[list[str], list[dict]]: +def build_critical_context(cve_intel_list) -> tuple[list[str], list[dict], list[str]]: """Extract key facts from all available intel sources into a compact context. - Returns (critical_context, candidate_packages) where candidate_packages - contains dicts with 'name', 'source', and optional 'ecosystem' keys. + Returns (critical_context, candidate_packages, vulnerable_functions) where + candidate_packages contains dicts with 'name', 'source', and optional + 'ecosystem' keys, and vulnerable_functions is a deduplicated list of short + function names from GHSA. 
""" critical_context = [] candidate_packages: list[dict] = [] seen_packages: set[str] = set() + vulnerable_functions: set[str] = set() for cve_intel in cve_intel_list: if cve_intel.nvd is not None: @@ -195,6 +198,8 @@ def build_critical_context(cve_intel_list) -> tuple[list[str], list[dict]]: short_names = [f.rsplit('.', 1)[-1] for f in vf if '.' in f] if short_names: critical_context.append(f"Search keywords: {', '.join(short_names)}") + for f in vf: + vulnerable_functions.add(f.rsplit('.', 1)[-1]) if pkg: if isinstance(pkg, dict): pkg_name = pkg.get("name", "") @@ -240,7 +245,7 @@ def build_critical_context(cve_intel_list) -> tuple[list[str], list[dict]]: if not critical_context: critical_context = ["No CVE intel available. Investigate using tools."] - return critical_context, candidate_packages + return critical_context, candidate_packages, sorted(vulnerable_functions) _GO_VULN_RE = re.compile(r"pkg\.go\.dev/vuln/(GO-\d{4}-\d+)") @@ -252,6 +257,7 @@ async def enrich_go_from_osv( cve_intel, critical_context: list[str], candidate_packages: list[dict], + vulnerable_functions: set[str] | None = None, ) -> None: """Query the OSV API for Go module paths when GHSA has no package data. @@ -308,6 +314,10 @@ async def enrich_go_from_osv( short_names = [s.rsplit(".", 1)[-1] for s in all_symbols if "." 
in s] unique_keywords = list(dict.fromkeys(all_symbols + short_names)) critical_context.append(f"Search keywords: {', '.join(unique_keywords)}") + if vulnerable_functions is not None: + for s in all_symbols: + vulnerable_functions.add(s.rsplit(".", 1)[-1]) + vulnerable_functions.add(s) if seen_paths: logger.info("OSV enrichment for %s added Go modules: %s", go_id, seen_paths) From 6cdd41d8d2ed15b294250b51c5f6bf969abac222 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sat, 7 Mar 2026 13:33:35 +0200 Subject: [PATCH 44/60] New Gate check that vul package exist --- src/vuln_analysis/functions/cve_agent.py | 16 ++++++++++-- src/vuln_analysis/functions/cve_summarize.py | 25 ++++++++++++++++++- .../functions/react_internals.py | 1 + 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 70ae9218..ad3a1ce5 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -423,11 +423,21 @@ async def observation_node(state: AgentState) -> AgentState: cca_results.append(True) elif first_token == "false": cca_results.append(False) + package_validated = state.get("package_validated") + if tool_used == ToolNames.FUNCTION_LOCATOR and state.get("is_reachability") == "yes": + input_pkg = tool_input_detail.split(",", 1)[0].strip().lower() + target_pkg = (state.get("app_package") or "").strip().lower() + if target_pkg and input_pkg == target_pkg: + if "Package is valid" in tool_output_for_llm: + package_validated = True + elif "Package is not valid" in tool_output_for_llm and package_validated is None: + package_validated = False return { "messages": prune_messages, "observation": new_observation, "step": state.get("step", 0), "cca_results": cca_results, + "package_validated": package_validated, } except Exception as e: logger.exception("observation_node failed") @@ -565,7 +575,8 @@ def _postprocess_results(results: list[list[dict]], 
replace_exceptions: bool, re # If the agent encounters a parsing error or a server error after retries, replace the error # with default values to prevent the pipeline from crashing outputs[i].append({"input": checklist_questions[i][j], "output": replace_exceptions_value, - "intermediate_steps": None, "cca_results": []}) + "intermediate_steps": None, "cca_results": [], + "package_validated": None}) if isinstance(answer, ToolRaisedException): tool_raised_exception: ToolRaisedException = answer logger.warning(f"An exception encountered during tool execution, in result [{i}][{j}]. for " @@ -594,7 +605,8 @@ def _postprocess_results(results: list[list[dict]], replace_exceptions: bool, re outputs[i].append({"input": answer["input"], "output": answer["output"], "intermediate_steps": results[i][j]["intermediate_steps"], - "cca_results": answer.get("cca_results", [])}) + "cca_results": answer.get("cca_results", []), + "package_validated": answer.get("package_validated")}) return outputs diff --git a/src/vuln_analysis/functions/cve_summarize.py b/src/vuln_analysis/functions/cve_summarize.py index 8702e5eb..f3783f9c 100644 --- a/src/vuln_analysis/functions/cve_summarize.py +++ b/src/vuln_analysis/functions/cve_summarize.py @@ -46,6 +46,20 @@ def _all_cca_not_reachable(checklist_items: list[dict]) -> bool: return len(all_cca) > 0 and not any(all_cca) +def _package_not_found(checklist_items: list[dict]) -> bool: + """Return True if Function Locator confirmed the target package is absent. + Fires when any thread has package_validated=False and no thread has True. 
+ """ + has_false = False + for item in checklist_items: + val = item.get("package_validated") + if val is True: + return False + if val is False: + has_false = True + return has_false + + @register_function(config_type=CVESummarizeToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) async def cve_summarize(config: CVESummarizeToolConfig, builder: Builder): @@ -63,7 +77,16 @@ async def summarize_cve(results): response = '\n'.join( [get_checklist_item_string(idx + 1, checklist_item) for idx, checklist_item in enumerate(checklist_items)]) - if _all_cca_not_reachable(checklist_items): + if _package_not_found(checklist_items): + response = ( + "PACKAGE PRESENCE GATE: Function Locator confirmed the vulnerable package " + "is NOT present in this container. A CVE cannot be exploitable if the " + "vulnerable package does not exist. Code matches from other packages are " + "irrelevant. The verdict MUST be 'not exploitable'.\n\n" + + response + ) + logger.info("Package presence gate activated: target package not found") + elif _all_cca_not_reachable(checklist_items): response = ( "REACHABILITY GATE: Call Chain Analyzer confirmed the vulnerable function " "is NOT reachable from application code in ALL reachability checks performed. 
" diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index afeddfdf..2c0c78b7 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -183,6 +183,7 @@ class AgentState(MessagesState): is_reachability: str = "yes" rules_tracker: SystemRulesTracker = SystemRulesTracker() cca_results: list[bool] = [] + package_validated: bool | None = None ### --- End of REACT Schemas ----# #---- REACT Prompt Templates ----# From e8dcf07f8a2fb652aa25d3acd47b564f4b9d07cf Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sat, 7 Mar 2026 13:53:39 +0200 Subject: [PATCH 45/60] Can remove cache file using env variables can be configure from google sheet --- .tekton/on-cm-runner.yaml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.tekton/on-cm-runner.yaml b/.tekton/on-cm-runner.yaml index a4108244..711533ef 100644 --- a/.tekton/on-cm-runner.yaml +++ b/.tekton/on-cm-runner.yaml @@ -225,6 +225,13 @@ spec: echo "Cache link created successfully." ls -ld "${CACHE_DIR_TARGET}" + # Remove cached pickle files if REMOVE_CACHE_FILE is set (via Google Sheet config) + # e.g. REMOVE_CACHE_FILE=https.github.com.postgres.postgres + if [ -n "${REMOVE_CACHE_FILE:-}" ]; then + echo "--- Removing cache files matching: ${CACHE_DIR_TARGET}/pickle/${REMOVE_CACHE_FILE}* ---" + rm -fv "${CACHE_DIR_TARGET}/pickle/${REMOVE_CACHE_FILE}"* || echo " No matching files found." 
+ fi + # Copy with verbose error if it fails echo "--- Exporting Telemetry Config ---" cp -v configs/config-no-tracing.yml $DATA_DIR/config-no-tracing.yml || echo "Failed to copy config" From 54618030c40c21fa71b9f79fab91505fb1b7b7df Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sun, 8 Mar 2026 16:20:53 +0200 Subject: [PATCH 46/60] Handle go multi package options in context --- src/vuln_analysis/functions/cve_agent.py | 78 ++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 6 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index ad3a1ce5..f7ef8871 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -15,6 +15,7 @@ import os import asyncio import json +from pathlib import Path from vuln_analysis.runtime_context import ctx_state import typing from aiq.builder.builder import Builder @@ -40,6 +41,8 @@ from vuln_analysis.utils.prompting import build_tool_descriptions from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_SELECTION_STRATEGY_NON_REACHABILITY, TOOL_ECOSYSTEM_REGISTRY, FEW_SHOT_EXAMPLES from vuln_analysis.utils.intel_utils import build_critical_context, enrich_go_from_osv, filter_context_to_package +from exploit_iq_commons.utils.git_utils import sanitize_git_url_for_path +from exploit_iq_commons.utils.data_utils import DEFAULT_GIT_DIRECTORY from langgraph.graph import StateGraph, END, START from langgraph.prebuilt import ToolNode from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, RemoveMessage @@ -109,6 +112,62 @@ async def common_build_tools(config: CVEAgentExecutorToolConfig, builder: Builde tool_descriptions = build_tool_descriptions(enabled_tool_names) return tools, tool_descriptions, tool_descriptions_list +def _validate_go_vendor_packages( + source_info: list, + candidate_packages: list[dict], +) -> tuple[list[dict], list[str]]: + """Check which Go candidate packages actually exist in 
the vendor directory.""" + code_si = next((si for si in source_info if si.type == "code"), None) + if code_si is None: + return candidate_packages, [] + + repo_path = Path(DEFAULT_GIT_DIRECTORY) / sanitize_git_url_for_path(code_si.git_repo) + vendor_path = repo_path / "vendor" + if not vendor_path.is_dir(): + return candidate_packages, [] + + validated = [] + removed = [] + for pkg in candidate_packages: + pkg_name = pkg.get("name", "") + if (vendor_path / pkg_name).is_dir(): + validated.append(pkg) + else: + removed.append(pkg_name) + + if validated: + return validated, removed + return candidate_packages, [] + + +async def _enrich_go_candidates( + cve_intel: list, + source_info: list, + critical_context: list[str], + candidate_packages: list[dict], + vulnerable_functions_set: set[str], +) -> tuple[list[dict], list[str]]: + """Enrich Go candidates via OSV and validate against vendor directory.""" + ghsa_has_packages = any(c.get("source") == "ghsa" for c in candidate_packages) + if not ghsa_has_packages or not vulnerable_functions_set: + intel = cve_intel[0] if cve_intel else None + if intel: + await enrich_go_from_osv(intel, critical_context, candidate_packages, vulnerable_functions_set) + + if candidate_packages: + candidate_packages, removed_pkgs = _validate_go_vendor_packages( + source_info, candidate_packages + ) + if removed_pkgs: + logger.info("Go vendor validation removed packages not in vendor/: %s", removed_pkgs) + critical_context.append( + f"VENDOR VALIDATION: The following packages were NOT found in the container's " + f"vendor directory and are excluded: {', '.join(removed_pkgs)}" + ) + + return candidate_packages, sorted(vulnerable_functions_set) + + async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Builder, state: AgentMorpheusEngineState): tools, tool_guidance_list, tool_descriptions_list = await common_build_tools(config, builder, state) @@ -181,12 +240,14 @@ async def pre_process_node(state: AgentState) -> AgentState: 
critical_context, candidate_packages, vulnerable_functions = build_critical_context(workflow_state.cve_intel) vulnerable_functions_set = set(vulnerable_functions) - ghsa_has_packages = any(c.get("source") == "ghsa" for c in candidate_packages) - if ecosystem == "go" and (not ghsa_has_packages or not vulnerable_functions_set): - cve_intel = workflow_state.cve_intel[0] if workflow_state.cve_intel else None - if cve_intel: - await enrich_go_from_osv(cve_intel, critical_context, candidate_packages, vulnerable_functions_set) - vulnerable_functions = sorted(vulnerable_functions_set) + if ecosystem == "go": + candidate_packages, vulnerable_functions = await _enrich_go_candidates( + workflow_state.cve_intel, + workflow_state.original_input.input.image.source_info, + critical_context, + candidate_packages, + vulnerable_functions_set, + ) selected_package = None app_package = None @@ -206,6 +267,11 @@ async def pre_process_node(state: AgentState) -> AgentState: logger.info("Package filter selected '%s' from %d candidates (reason: %s)", selected_package, len(candidate_packages), selection.reason) critical_context = filter_context_to_package(critical_context, selected_package, candidate_packages) + elif len(candidate_packages) == 1: + selected_package = candidate_packages[0].get("name") + app_package = selected_package + logger.info("Single candidate package after validation: '%s'", selected_package) + critical_context = filter_context_to_package(critical_context, selected_package, candidate_packages) critical_context.append( "TASK: Investigate usage and reachability of the vulnerable function/module in the container. 
" From b78eb83055b5bedd022c6fe7457a8a4a37c8e7ed Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sun, 8 Mar 2026 18:50:57 +0200 Subject: [PATCH 47/60] clear java cache --- .tekton/on-pull-request.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.tekton/on-pull-request.yaml b/.tekton/on-pull-request.yaml index 575461d4..a9978240 100644 --- a/.tekton/on-pull-request.yaml +++ b/.tekton/on-pull-request.yaml @@ -246,8 +246,8 @@ spec: make lint-pr TARGET_BRANCH=$TARGET_BRANCH_NAME #clean the java cache - #rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* - #rm -rf ./.cache/am_cache/git/https.github.com.cryostatio.cryostat + rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* + rm -rf ./.cache/am_cache/git/https.github.com.cryostatio.cryostat #rm -rf ./.cache/am_cache/pickle/https.github.com.TamarW0.node-example-project-* #rm -rf ./.cache/am_cache/pickle/https.github.com.postgres.postgres-* print_banner "RUNNING UNIT TESTS" From d967b2e3c1cb5248a9c3ad82246b3a7c876c28d3 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sun, 8 Mar 2026 21:10:44 +0200 Subject: [PATCH 48/60] remove information not needed taking tokens --- src/vuln_analysis/functions/cve_agent.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index f7ef8871..3f6f8946 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -159,11 +159,7 @@ async def _enrich_go_candidates( source_info, candidate_packages ) if removed_pkgs: - logger.info("Go vendor validation removed packages not in vendor/: %s", removed_pkgs) - critical_context.append( - f"VENDOR VALIDATION: The following packages were NOT found in the container's " - f"vendor directory and are excluded: {', '.join(removed_pkgs)}" - ) + logger.info("Go vendor validation removed %d packages not in vendor/: %s", len(removed_pkgs), removed_pkgs) 
return candidate_packages, sorted(vulnerable_functions_set) From b1b658981a9700ffb92d6a7093cd29ef6ec79fe1 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sun, 8 Mar 2026 21:11:47 +0200 Subject: [PATCH 49/60] disable java clear cache --- .tekton/on-pull-request.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.tekton/on-pull-request.yaml b/.tekton/on-pull-request.yaml index a9978240..575461d4 100644 --- a/.tekton/on-pull-request.yaml +++ b/.tekton/on-pull-request.yaml @@ -246,8 +246,8 @@ spec: make lint-pr TARGET_BRANCH=$TARGET_BRANCH_NAME #clean the java cache - rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* - rm -rf ./.cache/am_cache/git/https.github.com.cryostatio.cryostat + #rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* + #rm -rf ./.cache/am_cache/git/https.github.com.cryostatio.cryostat #rm -rf ./.cache/am_cache/pickle/https.github.com.TamarW0.node-example-project-* #rm -rf ./.cache/am_cache/pickle/https.github.com.postgres.postgres-* print_banner "RUNNING UNIT TESTS" From 348e6f1a7849a12f8f724602c414e6ccfc8faf4a Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sun, 8 Mar 2026 21:48:49 +0200 Subject: [PATCH 50/60] fix regression issue --- src/vuln_analysis/utils/full_text_search.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/vuln_analysis/utils/full_text_search.py b/src/vuln_analysis/utils/full_text_search.py index 47f6108c..0a02ab99 100644 --- a/src/vuln_analysis/utils/full_text_search.py +++ b/src/vuln_analysis/utils/full_text_search.py @@ -153,6 +153,10 @@ def search_index(self, query: str, top_k: int = 10) -> str: total_app = len(app_docs) total_dep = len(dep_docs) + + if total_app == 0 and total_dep == 0: + return "[]" + app_docs = app_docs[:top_k] remaining = top_k - len(app_docs) dep_docs = dep_docs[:max(remaining, 0)] From 7c9de3c3bc2ed41021a50368004dc07996612231 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Mon, 9 Mar 2026 14:35:21 +0000 Subject: [PATCH 
51/60] Add Code understanding into the observation prompt --- .../functions/react_internals.py | 39 ++++++++++++++++--- src/vuln_analysis/utils/intel_utils.py | 5 +++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 2c0c78b7..85b02713 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -43,10 +43,12 @@ class Thought(BaseModel): ) class Observation(BaseModel): - results: list[str] = Field(description="Bullet points of the most important new technical facts found") + results: list[str] = Field( + description="3-5 key technical facts from this tool output. Each fact must describe what the code DOES and how it relates to the investigation goal, not just that it was found." + ) memory: list[str] = Field( - description="A list of cumulative factual findings. Each item is a single, discrete technical fact." + description="Cumulative factual findings. Each item is a single, discrete technical fact that records functional behavior and context (e.g. what a function does, which package it belongs to, whether it is semantically relevant to the goal)." ) @@ -207,7 +209,14 @@ class AgentState(MessagesState): "STOPPING RULES:\n" "- POSITIVE reachability (Call Chain Analyzer returns True): you MAY conclude exploitable and finish.\n" "- NEGATIVE reachability (Call Chain Analyzer returns False): record the result. " - "You may only conclude 'not exploitable' after Call Chain Analyzer has confirmed the function is not reachable." + "You may only conclude 'not exploitable' after Call Chain Analyzer has confirmed the function is not reachable.\n" + "ANSWER QUALITY:\n" + "- Answer the SPECIFIC question asked with evidence. Do NOT just report what tools found.\n" + "- Never give bare assertions (e.g. 'not exploitable'). 
Always state: WHAT you checked, WHAT you found, and WHY it leads to your conclusion.\n" + "- Distinguish between code being PRESENT (exists in container), REACHABLE (called from application code), and EXPLOITABLE (attacker-controlled input can trigger it). Do not conflate these.\n" + "- If tool results conflict with each other, state the conflict explicitly rather than silently picking one side.\n" + "- Finding that a security check is ABSENT is potential evidence of vulnerability, not evidence of safety.\n" + "- When citing evidence, explain HOW it relates to the question -- do not just state that something was found." ) # Update LANGGRAPH_SYSTEM_PROMPT_TEMPLATE in react_internals.py @@ -269,7 +278,14 @@ class AgentState(MessagesState): "- Base conclusions ONLY on tool results, not assumptions.\n" "- If a search returns no results, that is evidence the code is absent.\n" "- Do NOT claim a function is used unless a tool confirmed it.\n" - "- Use CVE Web Search to gather additional vulnerability context if needed." + "- Use CVE Web Search to gather additional vulnerability context if needed.\n" + "ANSWER QUALITY:\n" + "- Answer the SPECIFIC question asked with evidence. Do NOT just report what tools found.\n" + "- Never give bare assertions (e.g. 'not exploitable'). Always state: WHAT you checked, WHAT you found, and WHY it leads to your conclusion.\n" + "- Distinguish between code being PRESENT (exists in container) and EXPLOITABLE (attacker-controlled input can trigger it). Do not conflate these.\n" + "- If tool results conflict with each other, state the conflict explicitly rather than silently picking one side.\n" + "- Finding that a security check is ABSENT is potential evidence of vulnerability, not evidence of safety.\n" + "- When citing evidence, explain HOW it relates to the question -- do not just state that something was found." 
) AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY = """ @@ -330,6 +346,18 @@ class AgentState(MessagesState): NEW OUTPUT: {tool_output} +CODE COMPREHENSION (apply BEFORE recording any finding): +1. READ the actual code snippets in NEW OUTPUT. Do NOT just check whether something was "found" or "not found." +2. For each code snippet or function returned, determine: + - What does this code actually DO? (its functional purpose, not just its name) + - What data does it operate on and in what context? + - Is its purpose semantically related to the GOAL, or is it merely a keyword match? +3. DISTINGUISH same-named entities: a function with the same name in a different package, module, or namespace is a DIFFERENT function. Always record the specific package/module context. +4. CHECK technology fit: if a finding comes from a different programming language, framework, or platform than the target under investigation, note the mismatch explicitly. Do NOT treat it as equivalent. +5. RECORD functional behavior, not just location. Memory entries must describe what the code DOES (e.g., "rate_limit() in X throttles I/O bandwidth for read/write operations"), not just where it was found (e.g., "rate_limit found in X"). +6. SEPARATE keyword presence from investigation relevance: a security-related function existing in the codebase is NOT automatically a mitigation or evidence for the vulnerability being investigated. Explain HOW the finding relates to the specific GOAL. +7. GROUND TO VENDOR MITIGATIONS: If PREVIOUS MEMORY contains "KNOWN MITIGATIONS:", compare the code and behavior in NEW OUTPUT against those mitigations. Record whether the codebase implements, partially implements, or contradicts the vendor's recommended mitigations (e.g. configuration settings, patches, workarounds). Do not treat unrelated security-related code as evidence of mitigation for this CVE. + RULES: - memory: Append new facts from OUTPUT. Record absence if nothing found. No duplicates. Factual only. 
- If NEW OUTPUT is empty, contains an error, or indicates tool failure, record in memory: "FAILED: {tool_used} [{tool_input_detail}] - [reason]". Do NOT infer or fabricate positive findings from failed tool output. @@ -340,7 +368,8 @@ class AgentState(MessagesState): - For Function Locator results, record: "VALIDATED: [package],[function] exists" -- this is NOT reachability. - When Code Keyword Search results are grouped into "Main application" and "Application library dependencies", prioritize findings from the main application group and use its package name as package_name for subsequent tool calls. - Always record in memory: "CALLED: [tool] with [input] -> [outcome]" so future steps know what was already tried. -- results: 3-5 key technical facts from this OUTPUT only. +- results: 3-5 key technical facts from this OUTPUT only. Each fact must reflect code comprehension (what the code does), not just keyword presence. +- If PREVIOUS MEMORY contains "KNOWN MITIGATIONS:" and the tool output is relevant (e.g. config, code path, or version), add to memory whether the codebase aligns with or contradicts those mitigations (e.g. "Mitigation check: follow_symlinks=False not set" or "Vendor mitigation (patch X) applied in version Y"). - Keep only CVE-exploitability-relevant information. 
RESPONSE: {{""" diff --git a/src/vuln_analysis/utils/intel_utils.py b/src/vuln_analysis/utils/intel_utils.py index 3747965e..4c205230 100644 --- a/src/vuln_analysis/utils/intel_utils.py +++ b/src/vuln_analysis/utils/intel_utils.py @@ -220,6 +220,11 @@ def build_critical_context(cve_intel_list) -> tuple[list[str], list[dict], list[ if cve_intel.rhsa is not None: if cve_intel.rhsa.statement: critical_context.append(f"RHSA Statement: {cve_intel.rhsa.statement[:300]}") + mitigation = getattr(cve_intel.rhsa, 'mitigation', None) + if mitigation: + mit_text = mitigation.get('value', '') if isinstance(mitigation, dict) else str(mitigation) + if mit_text: + critical_context.append(f"KNOWN MITIGATIONS: {mit_text[:500]}") if cve_intel.rhsa.package_state: pkgs = list(set(p.package_name for p in cve_intel.rhsa.package_state if p.package_name)) for p in pkgs: From f2d38c716d09cd6a3bac6bc9a9892c477921bca5 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Tue, 10 Mar 2026 11:59:04 +0000 Subject: [PATCH 52/60] improve function locator return results when pkg not found using sbom information --- .../utils/function_name_locator.py | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/src/vuln_analysis/utils/function_name_locator.py b/src/vuln_analysis/utils/function_name_locator.py index 55e866ab..efd44f45 100644 --- a/src/vuln_analysis/utils/function_name_locator.py +++ b/src/vuln_analysis/utils/function_name_locator.py @@ -22,6 +22,9 @@ from exploit_iq_commons.utils.standard_library_cache import StandardLibraryCache from vuln_analysis.utils.serp_api_wrapper import MorpheusSerpAPIWrapper +from exploit_iq_commons.utils.source_rpm_downloader import RPMDependencyManager +from exploit_iq_commons.data_models.input import SBOMPackage + logger = LoggingFactory.get_agent_logger(f"morpheus.{__name__}") @@ -39,8 +42,27 @@ def __init__(self, coc_retriever: ChainOfCallsRetrieverBase): self.stdlib_cache = StandardLibraryCache.get_instance() 
self.short_go_package_name = {} if self.coc_retriever.ecosystem.value == Ecosystem.GO.value: - self.short_go_package_name = self.build_short_go_package_name() + self.short_go_package_name = self.build_short_go_package_name() + self.sbom_dict = self.build_sbom_dict() + def build_sbom_dict(self) -> dict: + sbom_dict = {} + for package in RPMDependencyManager.get_instance().sbom: + sbom_dict[package.name] = package.version + return sbom_dict + def check_package_in_sbom(self, package: str) -> tuple[bool, str]: + if package in self.sbom_dict: + return True, self.sbom_dict[package] + else: + return False, None + + def check_fuzzy_match_in_sbom(self, package: str) -> tuple[bool, str]: + fuzzy_matches = difflib.get_close_matches(package, self.sbom_dict.keys(), n=3, cutoff=0.6) + if fuzzy_matches: + return True, ", ".join(fuzzy_matches) + else: + return False, None + def build_short_go_package_name(self) -> dict: short_go_package_name = {} for package in self.coc_retriever.supported_packages: @@ -193,6 +215,29 @@ async def locate_functions(self, query: str) -> list[str]: # First validate package name exists in supported packages if (not self.search_in_third_party_packages(package)): logger.info("Package '%s' not found in supported packages", package) + + in_sbom, sbom_version = self.check_package_in_sbom(package) + if in_sbom: + self.is_package_valid = True + logger.info("Package '%s' (version %s) found in SBOM but no source code available", package, sbom_version) + return [ + ( + f"INFO: Package '{package}' (version {sbom_version}) is present in the container SBOM, " + f"however no source code is available so package code content and function names cannot be checked." + ) + ] + + fuzzy_in_sbom, fuzzy_matches = self.check_fuzzy_match_in_sbom(package) + if fuzzy_in_sbom: + logger.info("Package '%s' not found exactly in SBOM, but fuzzy matches found: %s", package, fuzzy_matches) + return [ + ( + f"WARNING: Package '{package}' was NOT found in the container. 
" + f"Similar SBOM packages exist: {fuzzy_matches}. The function name could NOT be verified. Do NOT treat this as a confirmed match." + + ) + ] + is_standard_lib = self.stdlib_cache.is_standard_library(package, self.coc_retriever.ecosystem) if is_standard_lib: self.is_std_package = True From 91dd8ed98b4f934eeb4856b9f49bebb01d4b80cf Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 11 Mar 2026 12:18:04 +0000 Subject: [PATCH 53/60] Split Observation task to two llm calls --- src/vuln_analysis/functions/cve_agent.py | 31 +++++++-- .../functions/react_internals.py | 66 ++++++++++++++----- 2 files changed, 77 insertions(+), 20 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 3f6f8946..7f68fe8a 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -37,7 +37,7 @@ from vuln_analysis.utils.prompting import get_agent_prompt from exploit_iq_commons.logging.loggers_factory import LoggingFactory, trace_id -from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, build_package_filter_prompt, AgentState, Thought, Observation, Classification, PackageSelection, SystemRulesTracker, _build_tool_arguments, FORCED_FINISH_PROMPT, OBSERVATION_NODE_PROMPT, AGENT_SYS_PROMPT_NON_REACHABILITY, AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY +from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, build_package_filter_prompt, AgentState, Thought, Observation, Classification, PackageSelection, CodeFindings, SystemRulesTracker, _build_tool_arguments, FORCED_FINISH_PROMPT, COMPREHENSION_PROMPT, MEMORY_UPDATE_PROMPT, AGENT_SYS_PROMPT_NON_REACHABILITY, AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY from vuln_analysis.utils.prompting import build_tool_descriptions from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_SELECTION_STRATEGY_NON_REACHABILITY, 
TOOL_ECOSYSTEM_REGISTRY, FEW_SHOT_EXAMPLES from vuln_analysis.utils.intel_utils import build_critical_context, enrich_go_from_osv, filter_context_to_package @@ -169,6 +169,7 @@ async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Build tools, tool_guidance_list, tool_descriptions_list = await common_build_tools(config, builder, state) llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) thought_llm = llm.with_structured_output(Thought) + comprehension_llm = llm.with_structured_output(CodeFindings) observation_llm = llm.with_structured_output(Observation) reachability_llm = llm.with_structured_output(Classification) package_filter_llm = llm.with_structured_output(PackageSelection) @@ -314,6 +315,7 @@ async def pre_process_node(state: AgentState) -> AgentState: "runtime_prompt": runtime_prompt, "is_reachability": is_reachability, "observation": Observation(memory=critical_context, results=[]), + "critical_context": critical_context, "app_package": app_package if selected_package else None, } except Exception as e: @@ -447,17 +449,31 @@ async def observation_node(state: AgentState) -> AgentState: span.set_output({"rule_error": error_message}) return {"messages": [HumanMessage(content=error_message)]} - prompt = OBSERVATION_NODE_PROMPT.format( + # Step 1: Comprehension -- reads raw tool output, produces compact findings + ctx_lines = state.get("critical_context", []) + critical_context_text = "\n".join(ctx_lines) if ctx_lines else "N/A" + comp_prompt = COMPREHENSION_PROMPT.format( goal=state.get('input'), selected_package=state.get('app_package') or "N/A", - previous_memory=previous_memory, + critical_context=critical_context_text, tool_used=tool_used, tool_input_detail=tool_input_detail, last_thought_text=last_thought_text, tool_output=tool_output_for_llm, ) + code_findings: CodeFindings = await comprehension_llm.ainvoke([SystemMessage(content=comp_prompt)]) - new_observation: Observation = await 
observation_llm.ainvoke([SystemMessage(content=prompt)]) + findings_text = "\n".join(f"- {f}" for f in code_findings.findings) + + # Step 2: Memory update -- merges compressed findings into cumulative memory + mem_prompt = MEMORY_UPDATE_PROMPT.format( + goal=state.get('input'), + selected_package=state.get('app_package') or "N/A", + previous_memory=previous_memory, + findings=findings_text, + tool_outcome=code_findings.tool_outcome, + ) + new_observation: Observation = await observation_llm.ainvoke([SystemMessage(content=mem_prompt)]) messages = state["messages"] active_prompt = state.get("runtime_prompt") or default_system_prompt @@ -476,7 +492,12 @@ async def observation_node(state: AgentState) -> AgentState: "Context pruning: removed %d messages, estimated tokens now ~%d (limit %d)", len(prune_messages), estimated, config.context_window_token_limit, ) - span.set_output({"orig_estimated": orig_estimated, "estimated": estimated}) + span.set_output({ + "orig_estimated": orig_estimated, + "estimated": estimated, + "comprehension_findings": code_findings.findings, + "tool_outcome": code_findings.tool_outcome, + }) cca_results = list(state.get("cca_results", [])) if tool_used == ToolNames.CALL_CHAIN_ANALYZER: stripped = tool_output_for_llm.strip().lstrip("([") diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 85b02713..40729924 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -42,6 +42,17 @@ class Thought(BaseModel): max_length=3000, ) +class CodeFindings(BaseModel): + """Compressed code comprehension output from raw tool results.""" + findings: list[str] = Field( + description="3-5 key technical facts from tool output. Each describes what the code DOES " + "and how it relates to the investigation goal, not just that it was found." 
+ ) + tool_outcome: str = Field( + description="One-line summary: CALLED: [tool] with [input] -> [brief outcome]" + ) + + class Observation(BaseModel): results: list[str] = Field( description="3-5 key technical facts from this tool output. Each fact must describe what the code DOES and how it relates to the investigation goal, not just that it was found." @@ -184,6 +195,7 @@ class AgentState(MessagesState): app_package: str | None = None is_reachability: str = "yes" rules_tracker: SystemRulesTracker = SystemRulesTracker() + critical_context: list[str] = [] cca_results: list[bool] = [] package_validated: bool | None = None @@ -336,17 +348,18 @@ class AgentState(MessagesState): RESPONSE: {{""" -OBSERVATION_NODE_PROMPT = """Update the investigation memory based on new tool output. +COMPREHENSION_PROMPT = """Analyze the tool output and extract key technical findings. GOAL: {goal} TARGET PACKAGE (vulnerability): {selected_package} -PREVIOUS MEMORY: {previous_memory} +VULNERABILITY CONTEXT: +{critical_context} TOOL USED: {tool_used} TOOL INPUT: {tool_input_detail} THOUGHT: {last_thought_text} NEW OUTPUT: {tool_output} -CODE COMPREHENSION (apply BEFORE recording any finding): +CODE COMPREHENSION RULES: 1. READ the actual code snippets in NEW OUTPUT. Do NOT just check whether something was "found" or "not found." 2. For each code snippet or function returned, determine: - What does this code actually DO? (its functional purpose, not just its name) @@ -354,26 +367,49 @@ class AgentState(MessagesState): - Is its purpose semantically related to the GOAL, or is it merely a keyword match? 3. DISTINGUISH same-named entities: a function with the same name in a different package, module, or namespace is a DIFFERENT function. Always record the specific package/module context. 4. CHECK technology fit: if a finding comes from a different programming language, framework, or platform than the target under investigation, note the mismatch explicitly. Do NOT treat it as equivalent. -5. 
RECORD functional behavior, not just location. Memory entries must describe what the code DOES (e.g., "rate_limit() in X throttles I/O bandwidth for read/write operations"), not just where it was found (e.g., "rate_limit found in X"). +5. RECORD functional behavior, not just location. Findings must describe what the code DOES (e.g., "rate_limit() in X throttles I/O bandwidth for read/write operations"), not just where it was found (e.g., "rate_limit found in X"). 6. SEPARATE keyword presence from investigation relevance: a security-related function existing in the codebase is NOT automatically a mitigation or evidence for the vulnerability being investigated. Explain HOW the finding relates to the specific GOAL. -7. GROUND TO VENDOR MITIGATIONS: If PREVIOUS MEMORY contains "KNOWN MITIGATIONS:", compare the code and behavior in NEW OUTPUT against those mitigations. Record whether the codebase implements, partially implements, or contradicts the vendor's recommended mitigations (e.g. configuration settings, patches, workarounds). Do not treat unrelated security-related code as evidence of mitigation for this CVE. +7. GROUND TO VENDOR MITIGATIONS: If VULNERABILITY CONTEXT contains "KNOWN MITIGATIONS:", compare the code and behavior in NEW OUTPUT against those mitigations. Record whether the codebase implements, partially implements, or contradicts the vendor's recommended mitigations (e.g. configuration settings, patches, workarounds). Do not treat unrelated security-related code as evidence of mitigation for this CVE. + +TOOL-SPECIFIC RULES: +- If NEW OUTPUT is empty, contains an error, or indicates tool failure, findings must only state: "FAILED: {tool_used} [{tool_input_detail}] - [reason]". Do NOT infer or fabricate positive findings. +- Reachability tags apply ONLY to Call Chain Analyzer results. Code Keyword Search and Function Locator do NOT determine reachability. +- If Call Chain Analyzer returned NEGATIVE (False), include: "NOT reachable via [package]." 
+- If Call Chain Analyzer returned POSITIVE (True), include: "REACHABLE via [package] - sufficient evidence." +- For Function Locator results, include: "VALIDATED: [package],[function] exists" -- this is NOT reachability. +- When Code Keyword Search results are grouped into "Main application" and "Application library dependencies", prioritize findings from the main application group. +- findings: 3-5 key technical facts from this OUTPUT only. Each fact must reflect code comprehension (what the code does), not just keyword presence. +- tool_outcome: a single line "CALLED: [tool] with [input] -> [brief outcome]". +- Keep only CVE-exploitability-relevant information. +RESPONSE: +{{""" + +MEMORY_UPDATE_PROMPT = """Merge new findings into the investigation memory. +GOAL: {goal} +TARGET PACKAGE (vulnerability): {selected_package} +PREVIOUS MEMORY: {previous_memory} +NEW FINDINGS (from tool analysis): +{findings} +TOOL CALL RECORD: {tool_outcome} RULES: -- memory: Append new facts from OUTPUT. Record absence if nothing found. No duplicates. Factual only. -- If NEW OUTPUT is empty, contains an error, or indicates tool failure, record in memory: "FAILED: {tool_used} [{tool_input_detail}] - [reason]". Do NOT infer or fabricate positive findings from failed tool output. -- results from a failed tool call must only state the failure, not speculate about what the tool might have found. +- memory: Start from PREVIOUS MEMORY. Append new facts from NEW FINDINGS. Record absence if nothing was found. No duplicates. Factual only. +- Add TOOL CALL RECORD verbatim to memory so future steps know what was already tried. +- If NEW FINDINGS report a failure (starts with "FAILED:"), add the failure to memory. Do NOT infer or fabricate positive findings. - Reachability tags apply ONLY to Call Chain Analyzer results. Code Keyword Search and Function Locator do NOT determine reachability. -- If Call Chain Analyzer returned NEGATIVE (False), add to memory: "NOT reachable via [package]." 
-- If Call Chain Analyzer returned POSITIVE (True), add: "REACHABLE via [package] - sufficient evidence." -- For Function Locator results, record: "VALIDATED: [package],[function] exists" -- this is NOT reachability. -- When Code Keyword Search results are grouped into "Main application" and "Application library dependencies", prioritize findings from the main application group and use its package name as package_name for subsequent tool calls. -- Always record in memory: "CALLED: [tool] with [input] -> [outcome]" so future steps know what was already tried. -- results: 3-5 key technical facts from this OUTPUT only. Each fact must reflect code comprehension (what the code does), not just keyword presence. -- If PREVIOUS MEMORY contains "KNOWN MITIGATIONS:" and the tool output is relevant (e.g. config, code path, or version), add to memory whether the codebase aligns with or contradicts those mitigations (e.g. "Mitigation check: follow_symlinks=False not set" or "Vendor mitigation (patch X) applied in version Y"). +- If NEW FINDINGS mention "NOT reachable", add to memory: "NOT reachable via [package]." +- If NEW FINDINGS mention "REACHABLE", add: "REACHABLE via [package] - sufficient evidence." +- For Function Locator validation, record: "VALIDATED: [package],[function] exists" -- this is NOT reachability. +- When findings mention "Main application" vs "Application library dependencies", preserve this distinction in memory. +- If PREVIOUS MEMORY contains "KNOWN MITIGATIONS:" and findings are relevant (e.g. config, code path, or version), add to memory whether the codebase aligns with or contradicts those mitigations (e.g. "Mitigation check: follow_symlinks=False not set" or "Vendor mitigation (patch X) applied in version Y"). +- results: copy the NEW FINDINGS as-is. - Keep only CVE-exploitability-relevant information. 
RESPONSE: {{""" +# Legacy prompt kept for backwards compatibility with older traces +OBSERVATION_NODE_PROMPT = COMPREHENSION_PROMPT + ### --- End of REACT Prompt Templates ----# def build_system_prompt( tool_descriptions: str, From 7030fbde71d3696384a02c6b918728e098964c1f Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 11 Mar 2026 16:26:36 +0200 Subject: [PATCH 54/60] fix for go full fqdn search --- .../utils/function_name_locator.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/vuln_analysis/utils/function_name_locator.py b/src/vuln_analysis/utils/function_name_locator.py index efd44f45..4ec6c6f6 100644 --- a/src/vuln_analysis/utils/function_name_locator.py +++ b/src/vuln_analysis/utils/function_name_locator.py @@ -70,6 +70,16 @@ def build_short_go_package_name(self) -> dict: short_go_package_name[short_name] = package return short_go_package_name + def _resolve_go_fqdn_to_module(self, package: str) -> str | None: + """Find the longest supported_package that is a prefix of the input package.""" + best_match = None + for supported_package in self.coc_retriever.supported_packages: + if package.lower().startswith(supported_package.lower() + "/") or \ + package.lower() == supported_package.lower(): + if best_match is None or len(supported_package) > len(best_match): + best_match = supported_package + return best_match + def handle_package_not_in_supported_packages(self, package: str) -> list[str]: logger.info("package %s is not in supported packages", package) if self.coc_retriever.supported_packages is None: @@ -177,6 +187,10 @@ def is_same_package(package_name_from_input, package_name_from_tree): return True if short_path: return True + if self.coc_retriever.ecosystem.value == Ecosystem.GO.value: + fqdn_match = self._resolve_go_fqdn_to_module(package) + if fqdn_match: + return True return False async def locate_functions(self, query: str) -> list[str]: @@ -293,6 +307,10 @@ async def locate_functions(self, query: str) -> list[str]: 
# package_docs = self.coc_retriever.documents_of_functions if package in self.short_go_package_name: package = self.short_go_package_name[package] +elif self.coc_retriever.ecosystem.value == Ecosystem.GO.value: + resolved = self._resolve_go_fqdn_to_module(package) + if resolved: + package = resolved package_docs = [ doc for doc in (self.coc_retriever.documents_of_functions or []) From 90acf5763ba2c08cd93123fad9dc73baacd3fdef Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 11 Mar 2026 14:53:03 +0000 Subject: [PATCH 55/60] fix go exception in parser --- .../utils/functions_parsers/golang_functions_parsers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/exploit_iq_commons/utils/functions_parsers/golang_functions_parsers.py b/src/exploit_iq_commons/utils/functions_parsers/golang_functions_parsers.py index fad0ee8f..47c2e201 100644 --- a/src/exploit_iq_commons/utils/functions_parsers/golang_functions_parsers.py +++ b/src/exploit_iq_commons/utils/functions_parsers/golang_functions_parsers.py @@ -265,6 +265,9 @@ def parse_all_type_struct_class_to_fields(self, types: list[Document], type_inhe current_line_stripped = current_line_stripped[next_eol + 1:].lstrip() elif -1 < next_struct < next_eol and next_eol > -1: right_curly_bracket_ind = current_line_stripped.find("}") + if right_curly_bracket_ind == -1: + pos = size_of_type_block + break self.parse_one_type( Document(page_content=f"type {current_line_stripped[:right_curly_bracket_ind + 1]}", metadata={"source": the_type.metadata['source']}), From b02caf152ec0690ae43f5c5ec0a90a4b2edc61bb Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 11 Mar 2026 19:50:35 +0200 Subject: [PATCH 56/60] remove cache java --- .tekton/on-pull-request.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.tekton/on-pull-request.yaml b/.tekton/on-pull-request.yaml index 575461d4..0da82def 100644 --- a/.tekton/on-pull-request.yaml +++ b/.tekton/on-pull-request.yaml @@ -246,7 +246,7 @@ spec: make lint-pr
TARGET_BRANCH=$TARGET_BRANCH_NAME #clean the java cache - #rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* + rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* #rm -rf ./.cache/am_cache/git/https.github.com.cryostatio.cryostat #rm -rf ./.cache/am_cache/pickle/https.github.com.TamarW0.node-example-project-* #rm -rf ./.cache/am_cache/pickle/https.github.com.postgres.postgres-* From 364315b6d3e2ccc84d6658242d51eb719006c7f4 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Thu, 12 Mar 2026 12:09:51 +0200 Subject: [PATCH 57/60] call the function call finder tool --- src/vuln_analysis/functions/cve_agent.py | 5 ++-- .../functions/react_internals.py | 28 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 7f68fe8a..adbbacff 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -37,7 +37,7 @@ from vuln_analysis.utils.prompting import get_agent_prompt from exploit_iq_commons.logging.loggers_factory import LoggingFactory, trace_id -from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, build_package_filter_prompt, AgentState, Thought, Observation, Classification, PackageSelection, CodeFindings, SystemRulesTracker, _build_tool_arguments, FORCED_FINISH_PROMPT, COMPREHENSION_PROMPT, MEMORY_UPDATE_PROMPT, AGENT_SYS_PROMPT_NON_REACHABILITY, AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY +from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, build_package_filter_prompt, AgentState, Thought, Observation, Classification, PackageSelection, CodeFindings, SystemRulesTracker, _build_tool_arguments, FORCED_FINISH_PROMPT, COMPREHENSION_PROMPT, MEMORY_UPDATE_PROMPT, AGENT_SYS_PROMPT_NON_REACHABILITY, AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY, AGENT_THOUGHT_INSTRUCTIONS_GO from 
vuln_analysis.utils.prompting import build_tool_descriptions from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_SELECTION_STRATEGY_NON_REACHABILITY, TOOL_ECOSYSTEM_REGISTRY, FEW_SHOT_EXAMPLES from vuln_analysis.utils.intel_utils import build_critical_context, enrich_go_from_osv, filter_context_to_package @@ -291,7 +291,8 @@ async def pre_process_node(state: AgentState) -> AgentState: if is_reachability == "yes": tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) - runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local) + go_instructions = {"instructions": AGENT_THOUGHT_INSTRUCTIONS_GO} if ecosystem == "go" else {} + runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local, **go_instructions) active_tool_names = [t.name for t in tools] else: reachability_tool_names = { ToolNames.CALL_CHAIN_ANALYZER, ToolNames.FUNCTION_CALLER_FINDER} diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 40729924..545dfa88 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -208,6 +208,7 @@ class AgentState(MessagesState): "2. SEARCH for its presence using Code Keyword Search.\n" "3. TRACE reachability using Call Chain Analyzer. " " - Use the Function Locator to verify the package name and find the function name." + " - For Go: use Function Caller Finder to identify which application functions call the vulnerable library function, BEFORE running Call Chain Analyzer." " - Keyword search alone is NOT sufficient -- you must trace the call chain.\n" "4. ASSESS: only after completing reachability checks, determine exploitability.\n" "GENERAL RULES:\n" @@ -267,6 +268,33 @@ class AgentState(MessagesState): {{"thought": "Function Locator confirmed the package. 
Now trace reachability with Call Chain Analyzer", "mode": "act", "actions": {{"tool": "Call Chain Analyzer", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Check if function is reachable from application code"}}, "final_answer": null}} """ + +AGENT_THOUGHT_INSTRUCTIONS_GO = """ +1. Output valid JSON only. thought < 100 words. final_answer < 150 words. +2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer. +3. Function Locator, Function Caller Finder, Call Chain Analyzer: MUST set package_name AND function_name. Do NOT use query. +4. Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search: use query field. +5. Code Keyword Search and Function Locator are NOT reachability proof -- only Call Chain Analyzer is. +6. Do NOT call the same tool with the same input twice. Check KNOWLEDGE for "CALLED:" entries. If already tried, use a DIFFERENT tool or different input. +7. If Code Keyword Search returns no results and the query contains dots (e.g. a.b.ClassName), retry with just the final component (e.g. ClassName). This does NOT apply to simple names without dots. +8. When using Function Locator, Call Chain Analyzer, or Function Caller Finder, always start with the TARGET PACKAGE from KNOWLEDGE as the package_name before trying alternative packages. +9. If KNOWLEDGE lists "Vulnerable functions (GHSA)" or "Vulnerable functions (Go vuln DB)", you MUST investigate those specific functions FIRST before checking any other functions in the same package. +10. After Function Locator validates the package, use Function Caller Finder to identify callers. If FCF finds callers: the function IS reachable -- you MAY conclude and finish. If FCF returns empty: this does NOT mean unreachable. You MUST proceed to Call Chain Analyzer. Do NOT finish or conclude without calling Call Chain Analyzer when FCF returned empty. 
+ + +{{"thought": "Search for the vulnerable function in the codebase first", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "", "tool_input": null, "reason": "Check if vulnerable function is present"}}, "final_answer": null}} + + +{{"thought": "Found the function. Now use Function Locator to verify the package name and function", "mode": "act", "actions": {{"tool": "Function Locator", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Validate package and function name before tracing callers"}}, "final_answer": null}} + + +{{"thought": "Function Locator confirmed the package. Now find which app functions call this library function", "mode": "act", "actions": {{"tool": "Function Caller Finder", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Identify application functions that call the vulnerable library function"}}, "final_answer": null}} + + + +{{"thought": "Function Caller Finder returned no callers, but this does not prove unreachable. 
MUST call Call Chain Analyzer to confirm reachability", "mode": "act", "actions": {{"tool": "Call Chain Analyzer", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Check if function is reachable from application code"}}, "final_answer": null}} +""" + AGENT_SYS_PROMPT_NON_REACHABILITY = ( "You are a security analyst investigating CVE exploitability in container images.\n" "This is NOT a reachability question -- do NOT trace call chains.\n" From ff134bd07afa41c2d941a5a217684a56e7e0bca7 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sun, 15 Mar 2026 08:56:06 +0000 Subject: [PATCH 58/60] fixed arguments for tool call function finder go --- .tekton/on-pull-request.yaml | 2 +- src/vuln_analysis/functions/react_internals.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.tekton/on-pull-request.yaml b/.tekton/on-pull-request.yaml index 0da82def..575461d4 100644 --- a/.tekton/on-pull-request.yaml +++ b/.tekton/on-pull-request.yaml @@ -246,7 +246,7 @@ spec: make lint-pr TARGET_BRANCH=$TARGET_BRANCH_NAME #clean the java cache - rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* + #rm -rf ./.cache/am_cache/pickle/https.github.com.cryostatio.cryostat-* #rm -rf ./.cache/am_cache/git/https.github.com.cryostatio.cryostat #rm -rf ./.cache/am_cache/pickle/https.github.com.TamarW0.node-example-project-* #rm -rf ./.cache/am_cache/pickle/https.github.com.postgres.postgres-* diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 545dfa88..df17c2c7 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -288,7 +288,7 @@ class AgentState(MessagesState): {{"thought": "Found the function. 
Now use Function Locator to verify the package name and function", "mode": "act", "actions": {{"tool": "Function Locator", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Validate package and function name before tracing callers"}}, "final_answer": null}} -{{"thought": "Function Locator confirmed the package. Now find which app functions call this library function", "mode": "act", "actions": {{"tool": "Function Caller Finder", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Identify application functions that call the vulnerable library function"}}, "final_answer": null}} +{{"thought": "Function Locator confirmed the package. Now find which app functions call this library function", "mode": "act", "actions": {{"tool": "Function Caller Finder", "package_name": "", "function_name": ".()", "query": null, "tool_input": null, "reason": "Identify application functions that call the vulnerable library function"}}, "final_answer": null}} From ca81036bb6250f9e30de814f077fc70e89058d33 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Sun, 15 Mar 2026 14:46:44 +0000 Subject: [PATCH 59/60] fix threading in CCA --- .../utils/chain_of_calls_retriever.py | 162 ++++++++---------- 1 file changed, 74 insertions(+), 88 deletions(-) diff --git a/src/exploit_iq_commons/utils/chain_of_calls_retriever.py b/src/exploit_iq_commons/utils/chain_of_calls_retriever.py index f14935a2..2d3c6e12 100644 --- a/src/exploit_iq_commons/utils/chain_of_calls_retriever.py +++ b/src/exploit_iq_commons/utils/chain_of_calls_retriever.py @@ -1,9 +1,10 @@ import re import time from collections import defaultdict, deque +from itertools import chain from pathlib import Path from typing import List - +import copy from langchain_core.documents import Document from exploit_iq_commons.utils.chain_of_calls_retriever_base import ChainOfCallsRetrieverBase, PARENTS_INDEX, \ @@ -50,6 +51,14 @@ def is_likely_macro_block(segment: str) -> bool: 
return False +class SessionContext: + """Per-call mutable state, isolating concurrent get_relevant_documents calls.""" + def __init__(self, tree_dict: dict[str, list[str]]): + self.last_visited = dict() + self.tree_dict = copy.deepcopy(tree_dict) + self.found_path = False + + class ChainOfCallsRetriever(ChainOfCallsRetrieverBase): """A ChainOfCall retriever that Knows how to perform a deep search, looking for a function usage, whether it's being called from the application code base or not. @@ -125,12 +134,8 @@ def __init__(self, documents: List[Document], ecosystem: Ecosystem, manifest_pat logger.debug(f"self.documents len : {len(self.documents)}") logger.debug("Chain of Calls Retriever - retaining only types/classes docs " "documents_of_types len %d", len(self.documents_of_types)) - # boolean attribute that indicates whether a path was found or not, initially set to False. - self.found_path = False logger.debug("Chain of Calls Retriever - after documents_of_full_sources") - self.last_visited_parent_package_indexes = dict() - self.last_visited = dict() # Constructing a map of types and classes to their attributes/members/fields self.types_classes_fields_mapping = self.language_parser.parse_all_type_struct_class_to_fields(self.documents_of_types) # Create a data structure containing dict of key=(function_name@source_file),value = dict of @@ -155,36 +160,35 @@ def __group_docs_by_pkg(self) -> dict[str, list[Document]]: logger.debug("PROFILE: sort all docs (%d) elaps %.3f", len(self.documents), t1) return sort_docs - def __find_caller_function_dfs(self, document_function: Document, function_package: str) -> Document: + def __find_caller_function_dfs(self, document_function: Document, function_package: str, + session_context: SessionContext) -> Document: """ This method gets function and package as arguments, search and return a caller function of a package, if exists :param document_function: the document containing the function code and signature :param 
function_package: the package name containing the function + :param session_context: per-call session state for thread safety :return: a single document of a function that is calling document_function, or None if not found """ package_names = self.language_parser.get_package_names(document_function) direct_parents = list() # gets list of all direct parents of function for package_name in package_names: - list_of_packages = self.tree_dict.get(package_name) + list_of_packages = session_context.tree_dict.get(package_name) if list_of_packages: direct_parents.extend(list_of_packages[PARENTS_INDEX]) - # Add same package itself to search path. + # Add same package itself to search path. # direct_parents.extend([function_package]) - # gets list of documents to search in only from parents of function' package. + # gets list of documents to search in only from parents of function' package. function_name_to_search = self.language_parser.get_function_name(document_function) if function_name_to_search == self.language_parser.get_constructor_method_name(): function_name_to_search = self.language_parser.get_class_name_from_class_function(document_function) function_file_name = document_function.metadata.get('source') relevant_docs_to_search_in = list() - last_visited_package_index = (self.last_visited_parent_package_indexes - .get(calculate_hashable_string_for_function(function_file_name, - function_name_to_search), 0)) - package_exclusions = self.tree_dict.get(function_package)[EXCLUSIONS_INDEX] # Search for caller functions only at parents according to dependency tree. 
- for package in direct_parents[last_visited_package_index:]: + package_exclusions = session_context.tree_dict.get(function_package)[EXCLUSIONS_INDEX] + for package in direct_parents: sources_location_packages = True - if self.tree_dict.get(package)[PARENTS_INDEX][0] == ROOT_LEVEL_SENTINEL: + if session_context.tree_dict.get(package)[PARENTS_INDEX][0] == ROOT_LEVEL_SENTINEL: sources_location_packages = False possible_docs = self.get_possible_docs(function_name_to_search, package, @@ -192,7 +196,6 @@ def __find_caller_function_dfs(self, document_function: Document, function_packa sources_location_packages, frozenset(), dict()) - # Collect all potential caller functions for doc in self.get_functions_for_package(package_name=package, documents=possible_docs, @@ -200,7 +203,7 @@ def __find_caller_function_dfs(self, document_function: Document, function_packa function_to_search=function_name_to_search, callee_function_file_name=function_file_name): relevant_docs_to_search_in.append(doc) - # Perform the search only on the subset of potential caller functions + # Perform the search only on the subset of potential caller functions for doc in relevant_docs_to_search_in: function_is_being_called = self.language_parser.search_for_called_function(caller_function=doc, callee_function_name= @@ -225,9 +228,6 @@ def __find_caller_function_dfs(self, document_function: Document, function_packa # match, and add it to exclusions so it will not consider it when backtracking in order to prevent cycles. if function_is_being_called: package_exclusions.append(doc) - # update index of last scanned package for backtracking - # hashed_value = calculate_hashable_string_for_function(function_file_name, function_name_to_search) - # self.last_visited_parent_package_indexes[hashed_value] = last_visited_package_index + package_index return doc # If didn't find a matching caller function document, returns None. 
@@ -267,24 +267,23 @@ def get_possible_docs(self, function_name_to_search: str, package: str, exclusio return [doc for doc in filter_1 if doc.page_content.__contains__(f"{function_name_to_search}(")] - def __find_caller_functions_bfs(self, document_function: Document, function_package: str) -> List[Document]: + def __find_caller_functions_bfs(self, document_function: Document, function_package: str, + session_context: SessionContext) -> List[Document]: """ This method gets function and package as arguments, search and return a caller function of a package, if exists :param document_function: the document containing the function code and signature :param function_package: the package name containing the function + :param session_context: per-call session state for thread safety :return: a list of documents of functions that are calling document_function """ total_start = time.time() direct_parents = list() - # gets list of all direct parents of function - list_of_packages = self.tree_dict.get(function_package) + list_of_packages = session_context.tree_dict.get(function_package) if list_of_packages is not None: direct_parents.extend(list_of_packages[0]) - # Add same package itself to search path. - # direct_parents.extend([function_package]) - # gets list of documents to search in only from parents of function' package. + # gets list of documents to search in only from parents of function' package. 
function_name_to_search = self.language_parser.get_function_name(document_function) function_file_name = document_function.metadata.get('source') relevant_docs_to_search_in = list() @@ -295,15 +294,6 @@ def __find_caller_functions_bfs(self, document_function: Document, function_pack for package in direct_parents: pkg_docs = self.sort_docs[package] for doc in pkg_docs: - # for doc in self.documents: - # is_doc_in_pkg = False - # for package in direct_parents: - # if self.language_parser.is_a_package(package, doc): - # is_doc_in_pkg = True - # break - # if not is_doc_in_pkg: - # continue - file_name = doc.metadata.get('source') if doc.metadata.get('state') == "invalid": continue @@ -311,9 +301,6 @@ def __find_caller_functions_bfs(self, document_function: Document, function_pack # check for same doc if (function_name_to_search == func_name) and (file_name == function_file_name): continue - # same function name different files ? - # if (function_name_to_search == func_name): - # logger.debug(f"same func name {function_name_to_search}") if func_name == "main": file_path = Path(function_file_name) @@ -322,7 +309,7 @@ def __find_caller_functions_bfs(self, document_function: Document, function_pack if self.language_parser.dir_name_for_3rd_party_packages() in path_parts: continue if doc.page_content.__contains__(f"{function_name_to_search}("): - last_visited = (self.last_visited.get( + last_visited = (session_context.last_visited.get( calculate_hashable_string_for_function(file_name, func_name), 0)) if last_visited == 0: found = self.language_parser.search_for_called_function( @@ -366,11 +353,10 @@ def __find_caller_functions_bfs(self, document_function: Document, function_pack def _breadth_first_search(self, matching_documents: List[Document], target_function_doc: Document, - current_package_name: str) -> tuple[List[Document], bool]: - # main loop. 
+ current_package_name: str, + session_context: SessionContext) -> tuple[List[Document], bool]: file_counter = 0 q = deque() - self.last_visited.clear() q.append(target_function_doc) loop_start = time.time() while q: @@ -384,10 +370,10 @@ def _breadth_first_search(self, matching_documents: List[Document], target_funct function_file = target_doc.metadata.get('source') hashed_value = calculate_hashable_string_for_function(function_file, function_name) - if hashed_value in self.last_visited: + if hashed_value in session_context.last_visited: continue - self.last_visited[hashed_value] = 1 + session_context.last_visited[hashed_value] = 1 logger.debug("%d:file:%s, func_name : %s , pkg:%s queue len %d", file_counter, target_doc.metadata['source'], function_name, target_pkg, len(q)) @@ -395,7 +381,8 @@ def _breadth_first_search(self, matching_documents: List[Document], target_funct # Find a caller function and containing package in the dependency tree according to hierarchy sub_start = time.time() found_documents = self.__find_caller_functions_bfs(document_function=target_doc, - function_package=target_pkg) + function_package=target_pkg, + session_context=session_context) sub_elapsed = time.time() - sub_start logger.debug(f"[PROFILE] __find_caller_functions took {sub_elapsed:.3f} seconds") # If found, then add it to path @@ -404,7 +391,7 @@ def _breadth_first_search(self, matching_documents: List[Document], target_funct # If the function is in the application ( root package), then we finished and found such a path. 
if self.language_parser.is_root_package(doc_candidate): matching_documents.append(doc_candidate) - self.found_path = True + session_context.found_path = True break # Otherwise, we continue to search for callers for the current found function, in order to extend # the chain of calls and potentially find a path from application to the vulnerable @@ -412,40 +399,41 @@ def _breadth_first_search(self, matching_documents: List[Document], target_funct else: q.append(doc_candidate) - if self.found_path: + if session_context.found_path: break loop_elapsed = time.time() - loop_start logger.debug(f"[PROFILE] Main loop in get_relevant_C_documents took {loop_elapsed:.3f} seconds") # When the loop is finished, return list of documents ( path) and boolean indicating whether a path was # found or not. - logger.debug("get_relevant_documents: result %s", self.found_path) + logger.debug("get_relevant_documents: result %s", session_context.found_path) logger.debug("get_relevant_documents: docs %s", matching_documents) - return matching_documents, self.found_path + return matching_documents, session_context.found_path def _depth_first_search(self, matching_documents: List[Document], target_function_doc: Document, - current_package_name: str) -> tuple[List[Document], bool]: + current_package_name: str, + session_context: SessionContext) -> tuple[List[Document], bool]: """Execute depth-first search with backtracking strategy.""" end_loop = False - # main loop. while not end_loop: # Find a caller function and containing package in the dependency tree according to hierarchy found_document = self.__find_caller_function_dfs(document_function=target_function_doc, - function_package=current_package_name) + function_package=current_package_name, + session_context=session_context) # If found, then add it to path if found_document: matching_documents.append(found_document) # If the function is in the application ( root package), then we finished and found such a path. 
if self.language_parser.is_root_package(found_document): end_loop = True - self.found_path = True - # Otherwise, we continue to search for callers for the current found function, in order to extend - # the chain of calls and potentially find a path from application to the vulnerable - # function in input package + session_context.found_path = True + # Otherwise, we continue to search for callers for the current found function, in order to extend + # the chain of calls and potentially find a path from application to the vulnerable + # function in input package else: target_function_doc = found_document # extract package name from function document - current_package_name = self.__determine_doc_package_name(target_function_doc) + current_package_name = self.__determine_doc_package_name(target_function_doc, session_context) else: # end loop because didn't find a caller for initial function if len(matching_documents) == 1: @@ -455,12 +443,12 @@ def _depth_first_search(self, matching_documents: List[Document], target_functio else: dead_end_node = matching_documents.pop() # Excludes dead end function node from future searches, as it led to nowhere. - self.tree_dict.get(current_package_name)[EXCLUSIONS_INDEX].append(dead_end_node) + session_context.tree_dict.get(current_package_name)[EXCLUSIONS_INDEX].append(dead_end_node) target_function_doc = matching_documents[-1] - current_package_name = self.__determine_doc_package_name(target_function_doc) - # When the loop is finished, return list of documents ( path) and boolean indicating whether a path was - # found or not. - return matching_documents, self.found_path + current_package_name = self.__determine_doc_package_name(target_function_doc, session_context) + # When the loop is finished, return list of documents ( path) and boolean indicating whether a path was + # found or not. 
+ return matching_documents, session_context.found_path # This method is the entry point for the chain_of_calls_retriever transitive search, it gets a query, # in the form of "package_name, function", and returns a 2-tuple of (list_of_documents_in_path, bool_result). @@ -468,7 +456,6 @@ def _depth_first_search(self, matching_documents: List[Document], target_functio # calls from application to input function in the input package. def get_relevant_documents(self, query: str) -> tuple[List[Document], bool]: """Sync implementations for retriever.""" - self.found_path = False query = query.splitlines()[0].replace('"', '').replace("'", "").replace("`", "").strip() (package_name, function) = tuple(query.split(",")) class_name = None @@ -477,13 +464,15 @@ def get_relevant_documents(self, query: str) -> tuple[List[Document], bool]: class_name, function = function.split(splitters[0]) found_package = False matching_documents = [] - for dependency in self.tree_dict.values(): + session_context = SessionContext(self.tree_dict) + + for dependency in session_context.tree_dict.values(): dependency[EXCLUSIONS_INDEX] = [] standard_libs_cache = StandardLibraryCache.get_instance() # If it's a standard library package, then skip checking the package in dependency tree. 
if not standard_libs_cache.is_standard_library(package_name, self.ecosystem): # Check if input package is in dependency tree - for package in self.tree_dict: + for package in session_context.tree_dict: if self.language_parser.is_same_package(package_name, package): package_name = package found_package = True @@ -492,12 +481,14 @@ def get_relevant_documents(self, query: str) -> tuple[List[Document], bool]: if found_package: target_function_doc = self.__find_initial_function(function, package_name=package_name, documents=self.documents, - class_name=class_name) + class_name=class_name, + session_context=session_context) if not target_function_doc and self.language_parser.get_constructor_method_name(): target_function_doc = self.__find_initial_function(function_name=self.language_parser.get_constructor_method_name(), package_name=package_name, documents=self.documents, - class_name=function) + class_name=function, + session_context=session_context) # If not, there is a chance that the package is some standard library in the ecosystem. else: @@ -505,7 +496,7 @@ def get_relevant_documents(self, query: str) -> tuple[List[Document], bool]: # vulnerable function in standard lib package of language. 
page_content = self.language_parser.get_dummy_function(function) if page_content is None: - return matching_documents, self.found_path + return matching_documents, session_context.found_path if class_name: page_content = page_content + f'\n{self.language_parser.get_comment_line_notation()}(class: {class_name})' target_function_doc = Document(page_content=page_content @@ -514,18 +505,17 @@ def get_relevant_documents(self, query: str) -> tuple[List[Document], bool]: importing_docs = self.language_parser.document_imports_package(self.documents_of_full_sources, re.escape(package_name)) - root_package = [key for (key, value) in self.tree_dict.items() if ROOT_LEVEL_SENTINEL in value[0]] + root_package = [key for (key, value) in session_context.tree_dict.items() if ROOT_LEVEL_SENTINEL in value[0]] prefix_of_3rd_parties_libs = self.language_parser.dir_name_for_3rd_party_packages() - # find all parents ( all importing packages) of the ibput package so we'll have candidate pkgs to search in. parents = list({self.language_parser.get_package_names(doc)[1] for doc in importing_docs if doc.metadata['source'].startswith( prefix_of_3rd_parties_libs) and self.language_parser.get_package_names(doc)[1] - in self.tree_dict.keys()}) + in session_context.tree_dict.keys()}) for doc in importing_docs: if not doc.metadata.get('source').startswith(prefix_of_3rd_parties_libs): parents.append(root_package[0]) break - self.tree_dict[package_name] = [parents, []] + session_context.tree_dict[package_name] = [parents, []] end_loop = False current_package_name = package_name # If an initial document (that represents the vulnerable input function in the input package) was created @@ -538,26 +528,26 @@ def get_relevant_documents(self, query: str) -> tuple[List[Document], bool]: end_loop = True logger.error("Cannot find initial function=%s, in package=%s", function, package_name) if end_loop: - return matching_documents, self.found_path + return matching_documents, session_context.found_path if 
self.language_parser.is_search_algo_dfs(): - matching_documents, self.found_path = self._depth_first_search( - matching_documents, target_function_doc, current_package_name) + matching_documents, session_context.found_path = self._depth_first_search( + matching_documents, target_function_doc, current_package_name, session_context) else: - matching_documents, self.found_path = self._breadth_first_search( - matching_documents, target_function_doc, current_package_name) + matching_documents, session_context.found_path = self._breadth_first_search( + matching_documents, target_function_doc, current_package_name, session_context) - # When the loop is finished, return list of documents ( path) and boolean indicating whether a path was - # found or not. - return matching_documents, self.found_path + return matching_documents, session_context.found_path - def __determine_doc_package_name(self, target_function_doc): + def __determine_doc_package_name(self, target_function_doc, + session_context: SessionContext): return [package_name for package_name in self.language_parser.get_package_names(target_function_doc) - if self.tree_dict.get(package_name, None) is not None][0] + if session_context.tree_dict.get(package_name, None) is not None][0] def __find_initial_function(self, function_name: str, package_name: str, documents: list[Document], - class_name: str = None) -> Document: + class_name: str = None, *, + session_context: SessionContext) -> Document: if self.language_parser.is_search_algo_dfs(): pkg_docs = documents @@ -569,16 +559,12 @@ def __find_initial_function(self, function_name: str, package_name: str, documen relevant_docs = [doc for doc in relevant_docs if doc.page_content.endswith( f'{self.language_parser.get_comment_line_notation()}(class: {class_name})')] - package_exclusions = self.tree_dict.get(package_name)[EXCLUSIONS_INDEX] - #for index, document in enumerate(get_functions_for_package(package_name, relevant_docs, language_parser)): - from itertools import 
chain + package_exclusions = session_context.tree_dict.get(package_name)[EXCLUSIONS_INDEX] for document in chain( self.get_functions_for_package(package_name, relevant_docs, sources_location_packages=True), self.get_functions_for_package(package_name, relevant_docs, sources_location_packages=False), ): - # document_function_calls_input_function = True if function_name.lower() == self.language_parser.get_function_name(document).lower(): - # if language_parser.search_for_called_function(document, callee_function=function_name): package_exclusions.append(document) return document From d4c3e5029e72739464eee8e5fdb9d435d85bcc47 Mon Sep 17 00:00:00 2001 From: Shimon Tanny Date: Wed, 25 Mar 2026 14:42:02 +0200 Subject: [PATCH 60/60] set tracing of tokens only when needed --- src/vuln_analysis/functions/cve_agent.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index adbbacff..04f795a9 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -482,6 +482,8 @@ async def observation_node(state: AgentState) -> AgentState: prune_messages = [] orig_estimated = estimated + span_trace_dict = {"comprehension_findings": code_findings.findings, "tool_outcome": code_findings.tool_outcome} + if estimated > config.context_window_token_limit and len(messages) > 3: prunable = messages[1:-2] for msg in prunable: @@ -493,12 +495,9 @@ async def observation_node(state: AgentState) -> AgentState: "Context pruning: removed %d messages, estimated tokens now ~%d (limit %d)", len(prune_messages), estimated, config.context_window_token_limit, ) - span.set_output({ - "orig_estimated": orig_estimated, - "estimated": estimated, - "comprehension_findings": code_findings.findings, - "tool_outcome": code_findings.tool_outcome, - }) + span_trace_dict["orig_estimated"] = orig_estimated + span_trace_dict["estimated"] = estimated + 
span.set_output(span_trace_dict) cca_results = list(state.get("cca_results", [])) if tool_used == ToolNames.CALL_CHAIN_ANALYZER: stripped = tool_output_for_llm.strip().lstrip("([")