Skip to content

Commit 06ab5c4

Browse files
hsubox76, dlarocque, rachelsaunders
authored
feat(ai): Add prefer_in_cloud option for inference mode (#9236)
Co-authored-by: Daniel La Rocque <[email protected]>
Co-authored-by: rachelsaunders <[email protected]>
1 parent 55f3f83 commit 06ab5c4

File tree

10 files changed

+650
-47
lines changed

10 files changed

+650
-47
lines changed

.changeset/feat-prefer-in-cloud.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
"@firebase/ai": minor
3+
"firebase": minor
4+
---
5+
6+
Added a new `InferenceMode` option for the hybrid on-device capability: `prefer_in_cloud`. When this mode is selected, the SDK will attempt to use a cloud-hosted model first. If the call to the cloud-hosted model fails with a network-related error, the SDK will fall back to the on-device model, if it's available.

common/api-review/ai.api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,7 @@ export const InferenceMode: {
727727
readonly PREFER_ON_DEVICE: "prefer_on_device";
728728
readonly ONLY_ON_DEVICE: "only_on_device";
729729
readonly ONLY_IN_CLOUD: "only_in_cloud";
730+
readonly PREFER_IN_CLOUD: "prefer_in_cloud";
730731
};
731732

732733
// @public

docs-devsite/ai.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,13 +624,16 @@ ImagenSafetyFilterLevel: {
624624

625625
<b>(EXPERIMENTAL)</b> Determines whether inference happens on-device or in-cloud.
626626

627+
<b>PREFER\_ON\_DEVICE:</b> Attempt to make inference calls using an on-device model. If on-device inference is not available, the SDK will fall back to using a cloud-hosted model. <br/> <b>ONLY\_ON\_DEVICE:</b> Only attempt to make inference calls using an on-device model. The SDK will not fall back to a cloud-hosted model. If on-device inference is not available, inference methods will throw. <br/> <b>ONLY\_IN\_CLOUD:</b> Only attempt to make inference calls using a cloud-hosted model. The SDK will not fall back to an on-device model. <br/> <b>PREFER\_IN\_CLOUD:</b> Attempt to make inference calls using a cloud-hosted model. If the call fails with a network-related error, the SDK will fall back to the on-device model, if it's available.
628+
627629
<b>Signature:</b>
628630

629631
```typescript
630632
InferenceMode: {
631633
readonly PREFER_ON_DEVICE: "prefer_on_device";
632634
readonly ONLY_ON_DEVICE: "only_on_device";
633635
readonly ONLY_IN_CLOUD: "only_in_cloud";
636+
readonly PREFER_IN_CLOUD: "prefer_in_cloud";
634637
}
635638
```
636639

packages/ai/src/methods/count-tokens.test.ts

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -196,24 +196,16 @@ describe('countTokens()', () => {
196196
);
197197
});
198198
});
199-
it('on-device', async () => {
200-
const chromeAdapter = fakeChromeAdapter;
201-
const isAvailableStub = stub(chromeAdapter, 'isAvailable').resolves(true);
202-
const mockResponse = getMockResponse(
203-
'vertexAI',
204-
'unary-success-total-tokens.json'
205-
);
206-
const countTokensStub = stub(chromeAdapter, 'countTokens').resolves(
207-
mockResponse as Response
199+
it('throws if mode is ONLY_ON_DEVICE', async () => {
200+
const chromeAdapter = new ChromeAdapterImpl(
201+
// @ts-expect-error
202+
undefined,
203+
InferenceMode.ONLY_ON_DEVICE
208204
);
209-
const result = await countTokens(
210-
fakeApiSettings,
211-
'model',
212-
fakeRequestParams,
213-
chromeAdapter
205+
await expect(
206+
countTokens(fakeApiSettings, 'model', fakeRequestParams, chromeAdapter)
207+
).to.be.rejectedWith(
208+
/countTokens\(\) is not supported for on-device models/
214209
);
215-
expect(result.totalTokens).eq(6);
216-
expect(isAvailableStub).to.be.called;
217-
expect(countTokensStub).to.be.calledWith(fakeRequestParams);
218210
});
219211
});

packages/ai/src/methods/count-tokens.ts

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,20 @@
1515
* limitations under the License.
1616
*/
1717

18+
import { AIError } from '../errors';
1819
import {
1920
CountTokensRequest,
2021
CountTokensResponse,
21-
RequestOptions
22+
InferenceMode,
23+
RequestOptions,
24+
AIErrorCode
2225
} from '../types';
2326
import { Task, makeRequest } from '../requests/request';
2427
import { ApiSettings } from '../types/internal';
2528
import * as GoogleAIMapper from '../googleai-mappers';
2629
import { BackendType } from '../public-types';
2730
import { ChromeAdapter } from '../types/chrome-adapter';
31+
import { ChromeAdapterImpl } from './chrome-adapter';
2832

2933
export async function countTokensOnCloud(
3034
apiSettings: ApiSettings,
@@ -57,9 +61,13 @@ export async function countTokens(
5761
chromeAdapter?: ChromeAdapter,
5862
requestOptions?: RequestOptions
5963
): Promise<CountTokensResponse> {
60-
if (chromeAdapter && (await chromeAdapter.isAvailable(params))) {
61-
return (await chromeAdapter.countTokens(params)).json();
64+
if (
65+
(chromeAdapter as ChromeAdapterImpl)?.mode === InferenceMode.ONLY_ON_DEVICE
66+
) {
67+
throw new AIError(
68+
AIErrorCode.UNSUPPORTED,
69+
'countTokens() is not supported for on-device models.'
70+
);
6271
}
63-
6472
return countTokensOnCloud(apiSettings, model, params, requestOptions);
6573
}

packages/ai/src/methods/generate-content.ts

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import { ApiSettings } from '../types/internal';
2929
import * as GoogleAIMapper from '../googleai-mappers';
3030
import { BackendType } from '../public-types';
3131
import { ChromeAdapter } from '../types/chrome-adapter';
32+
import { callCloudOrDevice } from '../requests/hybrid-helpers';
3233

3334
async function generateContentStreamOnCloud(
3435
apiSettings: ApiSettings,
@@ -56,17 +57,13 @@ export async function generateContentStream(
5657
chromeAdapter?: ChromeAdapter,
5758
requestOptions?: RequestOptions
5859
): Promise<GenerateContentStreamResult> {
59-
let response;
60-
if (chromeAdapter && (await chromeAdapter.isAvailable(params))) {
61-
response = await chromeAdapter.generateContentStream(params);
62-
} else {
63-
response = await generateContentStreamOnCloud(
64-
apiSettings,
65-
model,
66-
params,
67-
requestOptions
68-
);
69-
}
60+
const response = await callCloudOrDevice(
61+
params,
62+
chromeAdapter,
63+
() => chromeAdapter!.generateContentStream(params),
64+
() =>
65+
generateContentStreamOnCloud(apiSettings, model, params, requestOptions)
66+
);
7067
return processStream(response, apiSettings); // TODO: Map streaming responses
7168
}
7269

@@ -96,17 +93,12 @@ export async function generateContent(
9693
chromeAdapter?: ChromeAdapter,
9794
requestOptions?: RequestOptions
9895
): Promise<GenerateContentResult> {
99-
let response;
100-
if (chromeAdapter && (await chromeAdapter.isAvailable(params))) {
101-
response = await chromeAdapter.generateContent(params);
102-
} else {
103-
response = await generateContentOnCloud(
104-
apiSettings,
105-
model,
106-
params,
107-
requestOptions
108-
);
109-
}
96+
const response = await callCloudOrDevice(
97+
params,
98+
chromeAdapter,
99+
() => chromeAdapter!.generateContent(params),
100+
() => generateContentOnCloud(apiSettings, model, params, requestOptions)
101+
);
110102
const generateContentResponse = await processGenerateContentResponse(
111103
response,
112104
apiSettings

0 commit comments

Comments (0)