
Commit bc0bbac

Add tests for ChatBedrockConverse invocation parameters
1 parent 72f62b8 commit bc0bbac

File tree

1 file changed: +308 -0 lines changed

Lines changed: 308 additions & 0 deletions
@@ -0,0 +1,308 @@
import { describe, expect, test, vi } from "vitest";
import { HumanMessage, SystemMessage } from "@langchain/core/messages";
import { ChatBedrockConverse } from "../chat_models.js";
import type {
  ConverseCommandInput,
  ConverseStreamCommandInput,
} from "@aws-sdk/client-bedrock-runtime";

// Mock the Bedrock runtime module: each command class records the last input it
// was constructed with, and the client returns canned responses instead of
// making network calls.
vi.mock("@aws-sdk/client-bedrock-runtime", () => {
  class ConverseCommand {
    input: ConverseCommandInput;

    static lastInput: ConverseCommandInput;

    constructor(input: ConverseCommandInput) {
      this.input = input;
      ConverseCommand.lastInput = input;
    }
  }

  class ConverseStreamCommand {
    input: ConverseStreamCommandInput;

    static lastInput: ConverseStreamCommandInput;

    constructor(input: ConverseStreamCommandInput) {
      this.input = input;
      ConverseStreamCommand.lastInput = input;
    }
  }

  class BedrockRuntimeClient {
    async send(command: unknown) {
      // Non-stream path
      if (
        (command as { constructor?: unknown })?.constructor === ConverseCommand
      ) {
        return {
          output: {
            message: {
              role: "assistant",
              content: [{ text: "Response" }],
            },
          },
          usage: {
            inputTokens: 10,
            outputTokens: 5,
            totalTokens: 15,
          },
        };
      }
      // Stream path
      if (
        (command as { constructor?: unknown })?.constructor ===
        ConverseStreamCommand
      ) {
        return {
          stream: (async function* () {
            yield {
              contentBlockDelta: {
                contentBlockIndex: 0,
                delta: { text: "Response" },
              },
            };
            yield {
              metadata: {
                usage: {
                  inputTokens: 10,
                  outputTokens: 5,
                  totalTokens: 15,
                },
              },
            };
          })(),
        };
      }
      return {};
    }
  }

  return {
    BedrockRuntimeClient,
    ConverseCommand,
    ConverseStreamCommand,
  };
});

// vi.mock is hoisted, so this import resolves to the mocked classes above.
import {
  ConverseCommand,
  ConverseStreamCommand,
} from "@aws-sdk/client-bedrock-runtime";

describe("ChatBedrockConverse invocationParams", () => {
  const baseConstructorArgs = {
    region: "us-east-1",
    credentials: {
      secretAccessKey: "test-secret",
      accessKeyId: "test-key",
    },
    model: "anthropic.claude-3-sonnet-20240229-v1:0",
  };

  describe("inferenceConfig conditional logic", () => {
    test("covers all inferenceConfig scenarios compactly", () => {
      const cases: Array<{
        name: string;
        ctor?: Partial<ConstructorParameters<typeof ChatBedrockConverse>[0]>;
        opts?: Parameters<ChatBedrockConverse["invocationParams"]>[0];
        expectDefined: boolean;
        expectValues?: Partial<{
          maxTokens: number;
          temperature: number;
          topP: number;
          stopSequences: string[];
        }>;
        expectUndefinedKeys?: Array<
          "maxTokens" | "temperature" | "topP" | "stopSequences"
        >;
      }> = [
        {
          name: "undefined when no inference values are set",
          expectDefined: false,
        },
        {
          name: "includes only maxTokens when set",
          ctor: { maxTokens: 100 },
          expectDefined: true,
          expectValues: { maxTokens: 100 },
          expectUndefinedKeys: ["temperature", "topP", "stopSequences"],
        },
        {
          name: "includes only temperature when set",
          ctor: { temperature: 0.7 },
          expectDefined: true,
          expectValues: { temperature: 0.7 },
          expectUndefinedKeys: ["maxTokens", "topP", "stopSequences"],
        },
        {
          name: "includes only topP when set",
          ctor: { topP: 0.9 },
          expectDefined: true,
          expectValues: { topP: 0.9 },
          expectUndefinedKeys: ["maxTokens", "temperature", "stopSequences"],
        },
        {
          name: "includes stopSequences when provided",
          opts: { stop: ["END", "STOP"] },
          expectDefined: true,
          expectValues: { stopSequences: ["END", "STOP"] },
          expectUndefinedKeys: ["maxTokens", "temperature", "topP"],
        },
        {
          name: "includes all values when all are set",
          ctor: { maxTokens: 200, temperature: 0.5, topP: 0.95 },
          opts: { stop: ["END"] },
          expectDefined: true,
          expectValues: {
            maxTokens: 200,
            temperature: 0.5,
            topP: 0.95,
            stopSequences: ["END"],
          },
        },
        {
          name: "undefined when stop sequences is empty array",
          opts: { stop: [] },
          expectDefined: false,
        },
      ];

      for (const c of cases) {
        const model = new ChatBedrockConverse({
          ...baseConstructorArgs,
          ...(c.ctor ?? {}),
        });
        const params = model.invocationParams(c.opts);
        if (!c.expectDefined) {
          expect(params.inferenceConfig).toBeUndefined();
        } else {
          expect(params.inferenceConfig).toBeDefined();
          if (c.expectValues?.maxTokens !== undefined) {
            expect(params.inferenceConfig?.maxTokens).toBe(
              c.expectValues.maxTokens
            );
          }
          if (c.expectValues?.temperature !== undefined) {
            expect(params.inferenceConfig?.temperature).toBe(
              c.expectValues.temperature
            );
          }
          if (c.expectValues?.topP !== undefined) {
            expect(params.inferenceConfig?.topP).toBe(c.expectValues.topP);
          }
          if (c.expectValues?.stopSequences !== undefined) {
            expect(params.inferenceConfig?.stopSequences).toEqual(
              c.expectValues.stopSequences
            );
          }
          const ic = params.inferenceConfig as Record<
            "maxTokens" | "temperature" | "topP" | "stopSequences",
            unknown
          >;
          for (const k of c.expectUndefinedKeys ?? []) {
            expect(ic?.[k]).toBeUndefined();
          }
        }
      }
    });
  });

  describe("system parameter conditional logic", () => {
    test.each([
      {
        name: "no system messages",
        messages: [new HumanMessage("Hello")],
        expectedSystem: { present: false, length: 0, texts: [] as string[] },
      },
      {
        name: "one system message",
        messages: [
          new SystemMessage("You are a helpful assistant."),
          new HumanMessage("Hello"),
        ],
        expectedSystem: {
          present: true,
          length: 1,
          texts: ["You are a helpful assistant."],
        },
      },
      {
        name: "multiple system messages",
        messages: [
          new SystemMessage("You are a helpful assistant."),
          new SystemMessage("Be concise in your responses."),
          new HumanMessage("Hello"),
        ],
        expectedSystem: {
          present: true,
          length: 2,
          texts: [
            "You are a helpful assistant.",
            "Be concise in your responses.",
          ],
        },
      },
    ])(
      "invoke should handle system parameter: $name",
      async ({ messages, expectedSystem }) => {
        const model = new ChatBedrockConverse(baseConstructorArgs);
        await model.invoke(messages);
        // Inspect the input captured by the mocked ConverseCommand constructor.
        const input = Reflect.get(
          ConverseCommand,
          "lastInput"
        ) as ConverseCommandInput;
        if (expectedSystem.present) {
          expect(input).toHaveProperty("system");
          const system = input.system as NonNullable<typeof input.system>;
          expect(system).toHaveLength(expectedSystem.length);
          expectedSystem.texts.forEach((t, i) => {
            expect(system[i]).toHaveProperty("text", t);
          });
        } else {
          expect(input).not.toHaveProperty("system");
        }
      }
    );
  });

  describe("stream method system parameter logic", () => {
    test.each([
      {
        name: "no system messages",
        messages: [new HumanMessage("Hello")],
        expectedPresent: false,
        expectedLength: 0,
        expectedTexts: [] as string[],
      },
      {
        name: "one system message",
        messages: [
          new SystemMessage("You are a helpful assistant."),
          new HumanMessage("Hello"),
        ],
        expectedPresent: true,
        expectedLength: 1,
        expectedTexts: ["You are a helpful assistant."],
      },
    ])(
      "stream should handle system parameter: $name",
      async ({ messages, expectedPresent, expectedLength, expectedTexts }) => {
        const model = new ChatBedrockConverse(baseConstructorArgs);
        const stream = await model.stream(messages);
        const chunks = [];
        for await (const chunk of stream) {
          chunks.push(chunk);
        }
        expect(chunks.length).toBeGreaterThan(0);
        // Inspect the input captured by the mocked ConverseStreamCommand constructor.
        const input = Reflect.get(
          ConverseStreamCommand,
          "lastInput"
        ) as ConverseStreamCommandInput;
        if (expectedPresent) {
          expect(input).toHaveProperty("system");
          const system = input.system as NonNullable<typeof input.system>;
          expect(system).toHaveLength(expectedLength);
          expectedTexts.forEach((t: string, i: number) => {
            expect(system[i]).toHaveProperty("text", t);
          });
        } else {
          expect(input).not.toHaveProperty("system");
        }
      }
    );
  });
});
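
For context, a minimal sketch of the invocationParams behavior these tests exercise. The import path and expected output shape are assumptions for illustration (the tests themselves import from ../chat_models.js); the constructor values mirror the test fixtures above.

// NOTE: package entry point is assumed; the tests import from ../chat_models.js.
import { ChatBedrockConverse } from "@langchain/aws";

// Constructor arguments mirror baseConstructorArgs plus sampling options.
const model = new ChatBedrockConverse({
  region: "us-east-1",
  credentials: { accessKeyId: "test-key", secretAccessKey: "test-secret" },
  model: "anthropic.claude-3-sonnet-20240229-v1:0",
  maxTokens: 200,
  temperature: 0.5,
});

// Stop sequences come in as call options rather than constructor fields, so they
// only reach inferenceConfig when supplied at invocation time.
const params = model.invocationParams({ stop: ["END"] });
// Per the "includes all values when all are set" case, roughly:
// params.inferenceConfig -> { maxTokens: 200, temperature: 0.5, stopSequences: ["END"] }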
