Skip to content

Commit 204e178

Browse files
russellwheatleydlarocquehsubox76tanzimfhjsoref
authored andcommitted
feat(ai): feature parity for AI package with Web SDK
---- Co-authored-by: Daniel La Rocque <[email protected]> Co-authored-by: hsubox76 <[email protected]> Co-authored-by: Tanzim Hossain <[email protected]> Co-authored-by: Josh Soref <[email protected]> Co-authored-by: DellaBitta <[email protected]>
1 parent f45b420 commit 204e178

File tree

13 files changed

+583
-139
lines changed

13 files changed

+583
-139
lines changed

packages/ai/__tests__/chat-session-helpers.test.ts

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,43 @@ describe('chat-session-helpers', () => {
139139
],
140140
isValid: false,
141141
},
142+
{
143+
history: [
144+
{ role: 'user', parts: [{ text: 'hi' }] },
145+
{
146+
role: 'model',
147+
parts: [
148+
{ text: 'hi' },
149+
{
150+
text: 'thought about hi',
151+
thought: true,
152+
thoughtSignature: 'thought signature',
153+
},
154+
],
155+
},
156+
],
157+
isValid: true,
158+
},
159+
{
160+
history: [
161+
{
162+
role: 'user',
163+
parts: [{ text: 'hi', thought: true, thoughtSignature: 'sig' }],
164+
},
165+
{
166+
role: 'model',
167+
parts: [
168+
{ text: 'hi' },
169+
{
170+
text: 'thought about hi',
171+
thought: true,
172+
thoughtSignature: 'thought signature',
173+
},
174+
],
175+
},
176+
],
177+
isValid: false,
178+
},
142179
];
143180

144181
TCS.forEach(tc => {

packages/ai/__tests__/response-helpers.test.ts

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,21 @@
1515
* limitations under the License.
1616
*/
1717
import { describe, expect, it, jest, afterEach } from '@jest/globals';
18-
import { addHelpers, formatBlockErrorMessage } from '../lib/requests/response-helpers';
18+
import {
19+
addHelpers,
20+
formatBlockErrorMessage,
21+
handlePredictResponse,
22+
} from '../lib/requests/response-helpers';
1923

20-
import { BlockReason, Content, FinishReason, GenerateContentResponse } from '../lib/types';
24+
import {
25+
BlockReason,
26+
Content,
27+
FinishReason,
28+
GenerateContentResponse,
29+
ImagenInlineImage,
30+
ImagenGCSImage,
31+
} from '../lib/types';
32+
import { getMockResponse, BackendName } from './test-utils/mock-response';
2133

2234
const fakeResponseText: GenerateContentResponse = {
2335
candidates: [
@@ -31,6 +43,18 @@ const fakeResponseText: GenerateContentResponse = {
3143
],
3244
};
3345

46+
const fakeResponseThoughts: GenerateContentResponse = {
47+
candidates: [
48+
{
49+
index: 0,
50+
content: {
51+
role: 'model',
52+
parts: [{ text: 'Some text' }, { text: 'and some thoughts', thought: true }],
53+
},
54+
},
55+
],
56+
};
57+
3458
const functionCallPart1 = {
3559
functionCall: {
3660
name: 'find_theaters',
@@ -129,12 +153,14 @@ describe('response-helpers methods', () => {
129153
const enhancedResponse = addHelpers(fakeResponseText);
130154
expect(enhancedResponse.text()).toBe('Some text and some more text');
131155
expect(enhancedResponse.functionCalls()).toBeUndefined();
156+
expect(enhancedResponse.thoughtSummary()).toBeUndefined();
132157
});
133158

134159
it('good response functionCall', () => {
135160
const enhancedResponse = addHelpers(fakeResponseFunctionCall);
136161
expect(enhancedResponse.text()).toBe('');
137162
expect(enhancedResponse.functionCalls()).toEqual([functionCallPart1.functionCall]);
163+
expect(enhancedResponse.thoughtSummary()).toBeUndefined();
138164
});
139165

140166
it('good response functionCalls', () => {
@@ -144,29 +170,41 @@ describe('response-helpers methods', () => {
144170
functionCallPart1.functionCall,
145171
functionCallPart2.functionCall,
146172
]);
173+
expect(enhancedResponse.thoughtSummary()).toBeUndefined();
147174
});
148175

149176
it('good response text/functionCall', () => {
150177
const enhancedResponse = addHelpers(fakeResponseMixed1);
151178
expect(enhancedResponse.functionCalls()).toEqual([functionCallPart2.functionCall]);
152179
expect(enhancedResponse.text()).toBe('some text');
180+
expect(enhancedResponse.thoughtSummary()).toBeUndefined();
153181
});
154182

155183
it('good response functionCall/text', () => {
156184
const enhancedResponse = addHelpers(fakeResponseMixed2);
157185
expect(enhancedResponse.functionCalls()).toEqual([functionCallPart1.functionCall]);
158186
expect(enhancedResponse.text()).toBe('some text');
187+
expect(enhancedResponse.thoughtSummary()).toBeUndefined();
159188
});
160189

161190
it('good response text/functionCall/text', () => {
162191
const enhancedResponse = addHelpers(fakeResponseMixed3);
163192
expect(enhancedResponse.functionCalls()).toEqual([functionCallPart1.functionCall]);
164193
expect(enhancedResponse.text()).toBe('some text and more text');
194+
expect(enhancedResponse.thoughtSummary()).toBeUndefined();
195+
});
196+
197+
it('good response text/thought', () => {
198+
const enhancedResponse = addHelpers(fakeResponseThoughts);
199+
expect(enhancedResponse.text()).toBe('Some text');
200+
expect(enhancedResponse.thoughtSummary()).toBe('and some thoughts');
201+
expect(enhancedResponse.functionCalls()).toBeUndefined();
165202
});
166203

167204
it('bad response safety', () => {
168205
const enhancedResponse = addHelpers(badFakeResponse);
169206
expect(() => enhancedResponse.text()).toThrow('SAFETY');
207+
expect(() => enhancedResponse.thoughtSummary()).toThrow('SAFETY');
170208
});
171209
});
172210

@@ -233,4 +271,80 @@ describe('response-helpers methods', () => {
233271
expect(message).toContain('Candidate was blocked due to SAFETY: unsafe candidate');
234272
});
235273
});
274+
275+
describe('handlePredictResponse', () => {
276+
it('returns base64 images', async () => {
277+
const mockResponse = getMockResponse(
278+
BackendName.VertexAI,
279+
'unary-success-generate-images-base64.json',
280+
) as Response;
281+
const res = await handlePredictResponse<ImagenInlineImage>(mockResponse);
282+
expect(res.filteredReason).toBeUndefined();
283+
expect(res.images.length).toBe(4);
284+
res.images.forEach(image => {
285+
expect(image.mimeType).toBe('image/png');
286+
expect(image.bytesBase64Encoded.length).toBeGreaterThan(0);
287+
});
288+
});
289+
290+
it('returns GCS images', async () => {
291+
const mockResponse = getMockResponse(
292+
BackendName.VertexAI,
293+
'unary-success-generate-images-gcs.json',
294+
) as Response;
295+
const res = await handlePredictResponse<ImagenGCSImage>(mockResponse);
296+
expect(res.filteredReason).toBeUndefined();
297+
expect(res.images.length).toBe(4);
298+
res.images.forEach((image, i) => {
299+
expect(image.mimeType).toBe('image/jpeg');
300+
expect(image.gcsURI).toBe(
301+
`gs://test-project-id-1234.firebasestorage.app/images/1234567890123/sample_${i}.jpg`,
302+
);
303+
});
304+
});
305+
306+
it('has filtered reason and no images if all images were filtered', async () => {
307+
const mockResponse = getMockResponse(
308+
BackendName.VertexAI,
309+
'unary-failure-generate-images-all-filtered.json',
310+
) as Response;
311+
const res = await handlePredictResponse<ImagenInlineImage>(mockResponse);
312+
expect(res.filteredReason).toBe(
313+
"Unable to show generated images. All images were filtered out because they violated Vertex AI's usage guidelines. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback. Support codes: 39322892, 29310472",
314+
);
315+
expect(res.images.length).toBe(0);
316+
});
317+
318+
it('has filtered reason and no images if all base64 images were filtered', async () => {
319+
const mockResponse = getMockResponse(
320+
BackendName.VertexAI,
321+
'unary-failure-generate-images-base64-some-filtered.json',
322+
) as Response;
323+
const res = await handlePredictResponse<ImagenInlineImage>(mockResponse);
324+
expect(res.filteredReason).toBe(
325+
'Your current safety filter threshold filtered out 2 generated images. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback.',
326+
);
327+
expect(res.images.length).toBe(2);
328+
res.images.forEach(image => {
329+
expect(image.mimeType).toBe('image/png');
330+
expect(image.bytesBase64Encoded.length).toBeGreaterThan(0);
331+
});
332+
});
333+
334+
it('has filtered reason and no images if all GCS images were filtered', async () => {
335+
const mockResponse = getMockResponse(
336+
BackendName.VertexAI,
337+
'unary-failure-generate-images-gcs-some-filtered.json',
338+
) as Response;
339+
const res = await handlePredictResponse<ImagenGCSImage>(mockResponse);
340+
expect(res.filteredReason).toBe(
341+
'Your current safety filter threshold filtered out 2 generated images. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback.',
342+
);
343+
expect(res.images.length).toBe(2);
344+
res.images.forEach(image => {
345+
expect(image.mimeType).toBe('image/jpeg');
346+
expect(image.gcsURI.length).toBeGreaterThan(0);
347+
});
348+
});
349+
});
236350
});

packages/ai/lib/constants.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,12 @@ export const AI_TYPE = 'AI';
2121

2222
export const DEFAULT_LOCATION = 'us-central1';
2323

24-
export const DEFAULT_BASE_URL = 'https://firebasevertexai.googleapis.com';
24+
export const DEFAULT_DOMAIN = 'firebasevertexai.googleapis.com';
2525

26-
// This is the default API version for the VertexAI API. At some point, should be able to change when the feature becomes available.
27-
// `v1beta` & `stable` available: https://cloud.google.com/vertex-ai/docs/reference#versions
2826
export const DEFAULT_API_VERSION = 'v1beta';
2927

3028
export const PACKAGE_VERSION = version;
3129

3230
export const LANGUAGE_TAG = 'gl-rn';
3331

34-
// Timeout is 180s by default
3532
export const DEFAULT_FETCH_TIMEOUT_MS = 180 * 1000;

packages/ai/lib/index.ts

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import './polyfills';
1919
import { getApp, ReactNativeFirebase } from '@react-native-firebase/app';
20-
import { GoogleAIBackend, VertexAIBackend } from './backend';
20+
import { Backend, GoogleAIBackend, VertexAIBackend } from './backend';
2121
import { AIErrorCode, ModelParams, RequestOptions } from './types';
2222
import { AI, AIOptions, ImagenModelParams } from './public-types';
2323
import { AIError } from './errors';
@@ -27,8 +27,9 @@ import { AIModel, ImagenModel } from './models';
2727
export * from './public-types';
2828
export { ChatSession } from './methods/chat-session';
2929
export * from './requests/schema-builder';
30-
export { GoogleAIBackend, VertexAIBackend } from './backend';
31-
export { GenerativeModel, AIError, AIModel };
30+
export { ImagenImageFormat } from './requests/imagen-image-format';
31+
export { Backend, GoogleAIBackend, VertexAIBackend } from './backend';
32+
export { GenerativeModel, AIError, AIModel, ImagenModel };
3233

3334
/**
3435
* Returns the default {@link AI} instance that is associated with the provided
@@ -58,16 +59,22 @@ export { GenerativeModel, AIError, AIModel };
5859
*
5960
* @public
6061
*/
61-
export function getAI(
62-
app: ReactNativeFirebase.FirebaseApp = getApp(),
63-
options: AIOptions = { backend: new GoogleAIBackend() },
64-
): AI {
62+
export function getAI(app: ReactNativeFirebase.FirebaseApp = getApp(), options?: AIOptions): AI {
63+
const backend: Backend = options?.backend ?? new GoogleAIBackend();
64+
65+
const finalOptions: Omit<AIOptions, 'backend'> = {
66+
useLimitedUseAppCheckTokens: options?.useLimitedUseAppCheckTokens ?? false,
67+
appCheck: options?.appCheck || null,
68+
auth: options?.auth || null,
69+
};
70+
6571
return {
6672
app,
67-
backend: options.backend,
68-
location: (options.backend as VertexAIBackend)?.location || '',
69-
appCheck: options.appCheck || null,
70-
auth: options.auth || null,
73+
backend,
74+
options: finalOptions,
75+
location: (backend as VertexAIBackend)?.location || '',
76+
appCheck: options?.appCheck || null,
77+
auth: options?.auth || null,
7178
} as AI;
7279
}
7380

packages/ai/lib/methods/chat-session-helpers.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@ const VALID_PART_FIELDS: Array<keyof Part> = [
2525
'inlineData',
2626
'functionCall',
2727
'functionResponse',
28+
'thought',
29+
'thoughtSignature',
2830
];
2931

3032
const VALID_PARTS_PER_ROLE: { [key in Role]: Array<keyof Part> } = {
3133
user: ['text', 'inlineData'],
3234
function: ['functionResponse'],
33-
model: ['text', 'functionCall'],
35+
model: ['text', 'functionCall', 'thought', 'thoughtSignature'],
3436
// System instructions shouldn't be in history anyway.
3537
system: ['text'],
3638
};
@@ -78,6 +80,8 @@ export function validateChatHistory(history: Content[]): void {
7880
inlineData: 0,
7981
functionCall: 0,
8082
functionResponse: 0,
83+
thought: 0,
84+
thoughtSignature: 0,
8185
};
8286

8387
for (const part of parts) {

0 commit comments

Comments
 (0)