Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 87 additions & 9 deletions app/services/grok/processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
OpenAI 响应格式处理器
"""
import re
import time
import uuid
import random
Expand Down Expand Up @@ -37,9 +38,64 @@ def _build_video_poster_preview(video_url: str, thumbnail_url: str = "") -> str:
</a>'''


# Matches one complete <grok:render>...</grok:render> citation placeholder,
# including any directly nested <argument>...</argument> children, so the
# whole tag can be stripped from model output in a single substitution.
# NOTE(review): assumes <argument> children are adjacent (no text or
# newlines between them) and that the tag contains nothing else —
# confirm against real Grok NDJSON frames.
_GROK_RENDER_TAG_RE = re.compile(
    r"<grok:render\b[^>]*>(?:<argument\b[^>]*>[^<]*</argument>)*\s*</grok:render>"
)


def _extract_web_sources(mr: dict) -> list[dict]:
"""Extract web search sources from modelResponse fields."""
sources: list[dict] = []
seen: set[str] = set()

for item in mr.get("citedWebSearchResults", mr.get("webSearchResults", [])):
if isinstance(item, dict):
url = (item.get("url") or "").strip()
if url and url not in seen:
seen.add(url)
sources.append({
"title": (item.get("title") or "").strip(),
"url": url,
})

if not sources:
for raw in mr.get("cardAttachmentsJson", []):
try:
card = orjson.loads(raw) if isinstance(raw, (str, bytes)) else raw
except Exception:
continue
if not isinstance(card, dict):
continue
url = (card.get("url") or "").strip()
if url and url not in seen:
seen.add(url)
sources.append({
"title": (card.get("title") or "").strip(),
"url": url,
})

return sources


def _strip_grok_render_tags(text: str) -> str:
    """Drop <grok:render> citation placeholder tags, keeping all other text."""
    cleaned = _GROK_RENDER_TAG_RE.sub("", text)
    return cleaned


def _format_sources_as_references(sources: list[dict]) -> str:
"""Format sources as a Markdown references section."""
if not sources:
return ""
lines = ["\n\n## References\n"]
for i, s in enumerate(sources, 1):
title = s.get("title") or s["url"]
lines.append(f"{i}. [{title}]({s['url']})")
return "\n".join(lines)


class BaseProcessor:
"""基础处理器"""

def __init__(self, model: str, token: str = ""):
self.model = model
self.token = token
Expand Down Expand Up @@ -109,7 +165,7 @@ def _sse(self, content: str = "", role: str = None, finish: str = None) -> str:

class StreamProcessor(BaseProcessor):
"""流式响应处理器"""

def __init__(self, model: str, token: str = "", think: bool = None):
super().__init__(model, token)
self.response_id: Optional[str] = None
Expand All @@ -118,7 +174,8 @@ def __init__(self, model: str, token: str = "", think: bool = None):
self.role_sent: bool = False
self.filter_tags = get_config("grok.filter_tags", [])
self.image_format = get_config("app.image_format", "url")

self.web_sources: list[dict] = []

if think is None:
self.show_think = get_config("grok.thinking", False)
else:
Expand Down Expand Up @@ -166,7 +223,11 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, N
yield self._sse(msg + "\n")
yield self._sse("</think>\n")
self.think_opened = False


# 提取 web search sources
if sources := _extract_web_sources(mr):
self.web_sources = sources

# 处理生成的图片
for url in mr.get("generatedImageUrls", []):
parts = url.split("/")
Expand All @@ -190,11 +251,19 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, N

# 普通 token
if (token := resp.get("token")) is not None:
if token and not (self.filter_tags and any(t in token for t in self.filter_tags)):
if token:
# 剥离 grok:render 标签而非丢弃整个 token
if self.filter_tags and any(t in token for t in self.filter_tags):
token = _strip_grok_render_tags(token)
if not token.strip():
continue
yield self._sse(token)

if self.think_opened:
yield self._sse("</think>\n")
# 输出 web search references
if self.web_sources:
yield self._sse(_format_sources_as_references(self.web_sources))
yield self._sse(finish="stop")
yield "data: [DONE]\n\n"
except Exception as e:
Expand All @@ -216,7 +285,8 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:
response_id = ""
fingerprint = ""
content = ""

web_sources: list[dict] = []

try:
async for line in response:
if not line:
Expand All @@ -234,7 +304,10 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:
if mr := resp.get("modelResponse"):
response_id = mr.get("responseId", "")
content = mr.get("message", "")


# 提取 web search sources
web_sources = _extract_web_sources(mr)

if urls := mr.get("generatedImageUrls"):
content += "\n"
for url in urls:
Expand All @@ -255,11 +328,16 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:

if (meta := mr.get("metadata", {})).get("llm_info", {}).get("modelHash"):
fingerprint = meta["llm_info"]["modelHash"]

except Exception as e:
logger.error(f"Collect processing error: {e}", extra={"model": self.model})
finally:
await self.close()

# 清理 grok:render 标签并附加 web search references
content = _strip_grok_render_tags(content)
if web_sources:
content += _format_sources_as_references(web_sources)

return {
"id": response_id,
Expand Down
57 changes: 56 additions & 1 deletion src/grok/processor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ export function createOpenAiStreamFromGrokNdjson(
let thinkingFinished = false;
let videoProgressStarted = false;
let lastVideoProgress = -1;
const collectedSources = new Map<string, { title: string; url: string }>();

let buffer = "";

Expand Down Expand Up @@ -253,6 +254,17 @@ export function createOpenAiStreamFromGrokNdjson(
const userRespModel = grok.userResponse?.model;
if (typeof userRespModel === "string" && userRespModel.trim()) currentModel = userRespModel.trim();

// Collect web search sources early (before rawToken checks that may skip this frame)
if (grok.webSearchResults?.results && Array.isArray(grok.webSearchResults.results)) {
for (const r of grok.webSearchResults.results) {
const url = typeof r.url === "string" ? r.url.trim() : "";
if (url && !collectedSources.has(url)) {
const title = typeof r.title === "string" ? r.title.trim() : url;
collectedSources.set(url, { title, url });
}
}
}

// Video generation stream
const videoResp = grok.streamingVideoGenerationResponse;
if (videoResp) {
Expand Down Expand Up @@ -337,7 +349,15 @@ export function createOpenAiStreamFromGrokNdjson(
if (typeof rawToken !== "string" || !rawToken) continue;
let token = rawToken;

if (filteredTags.some((t) => token.includes(t))) continue;
// Strip filtered tags from token instead of dropping entire token
for (const t of filteredTags) {
if (token.includes(t)) {
token = token.replace(new RegExp(`<[^>]*${t.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")}[^>]*>(?:<[^>]*>[^<]*<\/[^>]*>)*\\s*<\/[^>]*>`, "g"), "");
token = token.replace(new RegExp(`<[^>]*${t.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")}[^>]*\\/?>`, "g"), "");
}
}
token = token.trim();
if (!token) continue;

const currentIsThinking = Boolean(grok.isThinking);
const messageTag = grok.messageTag;
Expand Down Expand Up @@ -382,6 +402,17 @@ export function createOpenAiStreamFromGrokNdjson(
}
}

// Append collected web search sources as References section
if (collectedSources.size > 0) {
const refLines = ["\n\n## References\n"];
let idx = 1;
for (const s of collectedSources.values()) {
refLines.push(`${idx}. [${s.title}](${s.url})`);
idx++;
}
controller.enqueue(encoder.encode(makeChunk(id, created, currentModel, refLines.join("\n"))));
}

controller.enqueue(encoder.encode(makeChunk(id, created, currentModel, "", "stop")));
controller.enqueue(encoder.encode(makeDone()));
if (opts.onFinish) await opts.onFinish({ status: finalStatus, duration: (Date.now() - startTime) / 1000 });
Expand Down Expand Up @@ -417,6 +448,8 @@ export async function parseOpenAiFromGrokNdjson(

let content = "";
let model = requestedModel;
const collectedSources = new Map<string, { title: string; url: string }>();

for (const line of lines) {
let data: GrokNdjson;
try {
Expand All @@ -431,6 +464,17 @@ export async function parseOpenAiFromGrokNdjson(
const grok = (data as any).result?.response;
if (!grok) continue;

// Collect web search sources from all frames
if (grok.webSearchResults?.results && Array.isArray(grok.webSearchResults.results)) {
for (const r of grok.webSearchResults.results) {
const url = typeof r.url === "string" ? r.url.trim() : "";
if (url && !collectedSources.has(url)) {
const title = typeof r.title === "string" ? r.title.trim() : url;
collectedSources.set(url, { title, url });
}
}
}

const videoResp = grok.streamingVideoGenerationResponse;
if (videoResp?.videoUrl && typeof videoResp.videoUrl === "string") {
const videoPath = encodeAssetPath(videoResp.videoUrl);
Expand Down Expand Up @@ -476,6 +520,17 @@ export async function parseOpenAiFromGrokNdjson(
break;
}

// Append collected web search sources as References section
if (collectedSources.size > 0) {
const refLines = ["\n\n## References\n"];
let idx = 1;
for (const s of collectedSources.values()) {
refLines.push(`${idx}. [${s.title}](${s.url})`);
idx++;
}
content += refLines.join("\n");
}

return {
id: `chatcmpl-${crypto.randomUUID()}`,
object: "chat.completion",
Expand Down