Merged
Changes from 8 commits
10 changes: 10 additions & 0 deletions .changeset/feat-prefer-in-cloud.md
@@ -0,0 +1,10 @@
---
"@firebase/ai": minor
"firebase": minor
---

feat: Add `prefer_in_cloud` option for inference mode

This change introduces a new `InferenceMode` option, `prefer_in_cloud`. When this mode is selected, the SDK will attempt to use the cloud backend first. If the cloud call fails with a network-related error, it will fall back to the on-device model if available.

This also includes a refactoring of the logic for dispatching requests to either the on-device or cloud backends to improve clarity and remove duplication.
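
For context, here is a minimal sketch of how an application might opt into the new mode, assuming the existing hybrid-inference options accepted by `getGenerativeModel`. The `mode` value is what this PR adds; the `inCloudParams` field, the model name, and the app config are illustrative assumptions based on the current hybrid API, not part of this diff.

```ts
import { initializeApp } from 'firebase/app';
import { getAI, getGenerativeModel, InferenceMode } from 'firebase/ai';

// Placeholder config for illustration only.
const app = initializeApp({ apiKey: '...', projectId: '...' });
const ai = getAI(app);

// PREFER_IN_CLOUD: try the cloud backend first, and fall back to the
// on-device model only when the cloud call fails with a network error.
const model = getGenerativeModel(ai, {
  mode: InferenceMode.PREFER_IN_CLOUD,
  inCloudParams: { model: 'gemini-2.0-flash' }
});

async function run(): Promise<void> {
  const result = await model.generateContent('Write a haiku about fallbacks.');
  console.log(result.response.text());
}
```
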
1 change: 1 addition & 0 deletions common/api-review/ai.api.md
@@ -727,6 +727,7 @@ export const InferenceMode: {
     readonly PREFER_ON_DEVICE: "prefer_on_device";
     readonly ONLY_ON_DEVICE: "only_on_device";
     readonly ONLY_IN_CLOUD: "only_in_cloud";
+    readonly PREFER_IN_CLOUD: "prefer_in_cloud";
 };
 
 // @public

1 change: 1 addition & 0 deletions docs-devsite/ai.md
@@ -631,6 +631,7 @@ InferenceMode: {
     readonly PREFER_ON_DEVICE: "prefer_on_device";
     readonly ONLY_ON_DEVICE: "only_on_device";
     readonly ONLY_IN_CLOUD: "only_in_cloud";
+    readonly PREFER_IN_CLOUD: "prefer_in_cloud";
 }
 ```

26 changes: 9 additions & 17 deletions packages/ai/src/methods/count-tokens.test.ts
@@ -196,24 +196,16 @@ describe('countTokens()', () => {
       );
     });
   });
-  it('on-device', async () => {
-    const chromeAdapter = fakeChromeAdapter;
-    const isAvailableStub = stub(chromeAdapter, 'isAvailable').resolves(true);
-    const mockResponse = getMockResponse(
-      'vertexAI',
-      'unary-success-total-tokens.json'
-    );
-    const countTokensStub = stub(chromeAdapter, 'countTokens').resolves(
-      mockResponse as Response
+  it('throws if mode is ONLY_ON_DEVICE', async () => {
+    const chromeAdapter = new ChromeAdapterImpl(
+      // @ts-expect-error
+      undefined,
+      InferenceMode.ONLY_ON_DEVICE
     );
-    const result = await countTokens(
-      fakeApiSettings,
-      'model',
-      fakeRequestParams,
-      chromeAdapter
+    await expect(
+      countTokens(fakeApiSettings, 'model', fakeRequestParams, chromeAdapter)
+    ).to.be.rejectedWith(
+      /countTokens\(\) is not supported for on-device models/
     );
-    expect(result.totalTokens).eq(6);
-    expect(isAvailableStub).to.be.called;
-    expect(countTokensStub).to.be.calledWith(fakeRequestParams);
   });
 });

16 changes: 12 additions & 4 deletions packages/ai/src/methods/count-tokens.ts
@@ -15,16 +15,20 @@
  * limitations under the License.
  */
 
+import { AIError } from '../errors';
 import {
   CountTokensRequest,
   CountTokensResponse,
-  RequestOptions
+  InferenceMode,
+  RequestOptions,
+  AIErrorCode
 } from '../types';
 import { Task, makeRequest } from '../requests/request';
 import { ApiSettings } from '../types/internal';
 import * as GoogleAIMapper from '../googleai-mappers';
 import { BackendType } from '../public-types';
 import { ChromeAdapter } from '../types/chrome-adapter';
+import { ChromeAdapterImpl } from './chrome-adapter';
 
 export async function countTokensOnCloud(
   apiSettings: ApiSettings,
@@ -57,9 +61,13 @@ export async function countTokens(
   chromeAdapter?: ChromeAdapter,
   requestOptions?: RequestOptions
 ): Promise<CountTokensResponse> {
-  if (chromeAdapter && (await chromeAdapter.isAvailable(params))) {
-    return (await chromeAdapter.countTokens(params)).json();
+  if (
+    (chromeAdapter as ChromeAdapterImpl)?.mode === InferenceMode.ONLY_ON_DEVICE
+  ) {
+    throw new AIError(
+      AIErrorCode.UNSUPPORTED,
+      'countTokens() is not supported for on-device models.'
+    );
   }
 
   return countTokensOnCloud(apiSettings, model, params, requestOptions);
 }

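At the public surface, this guard means `countTokens()` now rejects instead of silently routing to the device when the model was created in `only_on_device` mode. A hedged sketch of the resulting behavior — the hybrid options shape and the ability to omit cloud params are assumptions; the error message comes from the diff above:

```ts
import { getAI, getGenerativeModel, InferenceMode, AIError } from 'firebase/ai';

const ai = getAI(); // uses the default app

// Hypothetical on-device-only model for illustration.
const model = getGenerativeModel(ai, { mode: InferenceMode.ONLY_ON_DEVICE });

async function demo(): Promise<void> {
  try {
    await model.countTokens('How many tokens am I?');
  } catch (e) {
    if (e instanceof AIError) {
      // Expected: 'countTokens() is not supported for on-device models.'
      console.warn(e.message);
    }
  }
}
```
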
36 changes: 14 additions & 22 deletions packages/ai/src/methods/generate-content.ts
@@ -29,6 +29,7 @@ import { ApiSettings } from '../types/internal';
 import * as GoogleAIMapper from '../googleai-mappers';
 import { BackendType } from '../public-types';
 import { ChromeAdapter } from '../types/chrome-adapter';
+import { callCloudOrDevice } from '../requests/hybrid-helpers';
 
 async function generateContentStreamOnCloud(
   apiSettings: ApiSettings,
@@ -56,17 +57,13 @@ export async function generateContentStream(
   chromeAdapter?: ChromeAdapter,
   requestOptions?: RequestOptions
 ): Promise<GenerateContentStreamResult> {
-  let response;
-  if (chromeAdapter && (await chromeAdapter.isAvailable(params))) {
-    response = await chromeAdapter.generateContentStream(params);
-  } else {
-    response = await generateContentStreamOnCloud(
-      apiSettings,
-      model,
-      params,
-      requestOptions
-    );
-  }
+  const response = await callCloudOrDevice(
+    params,
+    chromeAdapter,
+    () => chromeAdapter!.generateContentStream(params),
+    () =>
+      generateContentStreamOnCloud(apiSettings, model, params, requestOptions)
+  );
   return processStream(response, apiSettings); // TODO: Map streaming responses
 }
 
@@ -96,17 +93,12 @@ export async function generateContent(
   chromeAdapter?: ChromeAdapter,
   requestOptions?: RequestOptions
 ): Promise<GenerateContentResult> {
-  let response;
-  if (chromeAdapter && (await chromeAdapter.isAvailable(params))) {
-    response = await chromeAdapter.generateContent(params);
-  } else {
-    response = await generateContentOnCloud(
-      apiSettings,
-      model,
-      params,
-      requestOptions
-    );
-  }
+  const response = await callCloudOrDevice(
+    params,
+    chromeAdapter,
+    () => chromeAdapter!.generateContent(params),
+    () => generateContentOnCloud(apiSettings, model, params, requestOptions)
+  );
   const generateContentResponse = await processGenerateContentResponse(
     response,
     apiSettings

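The `callCloudOrDevice` helper lives in `packages/ai/src/requests/hybrid-helpers.ts`, which is not part of the excerpt above. Below is a rough sketch of the dispatch logic it is expected to centralize, inferred from the call sites and the changeset description; the signature mirrors the calls above, while the error classification and the per-mode handling are assumptions, not the actual implementation.

```ts
import { AIError } from '../errors';
import { AIErrorCode, GenerateContentRequest, InferenceMode } from '../types';
import { ChromeAdapter } from '../types/chrome-adapter';

/**
 * Sketch: routes a request to the on-device or in-cloud backend based on
 * the adapter's inference mode.
 */
export async function callCloudOrDevice(
  request: GenerateContentRequest,
  chromeAdapter: ChromeAdapter | undefined,
  onDeviceCall: () => Promise<Response>,
  inCloudCall: () => Promise<Response>
): Promise<Response> {
  // No adapter means hybrid inference was not configured; always use the cloud.
  if (!chromeAdapter) {
    return inCloudCall();
  }
  const mode = (chromeAdapter as { mode?: InferenceMode }).mode;
  switch (mode) {
    case InferenceMode.ONLY_ON_DEVICE:
      return onDeviceCall();
    case InferenceMode.ONLY_IN_CLOUD:
      return inCloudCall();
    case InferenceMode.PREFER_IN_CLOUD:
      // Cloud-first: fall back to the device only on a network-related error.
      try {
        return await inCloudCall();
      } catch (e) {
        // Assumption: network failures surface as an AIError fetch error.
        if (e instanceof AIError && e.code === AIErrorCode.FETCH_ERROR) {
          return onDeviceCall();
        }
        throw e;
      }
    case InferenceMode.PREFER_ON_DEVICE:
    default:
      // Device-first: use the on-device model whenever it is available.
      if (await chromeAdapter.isAvailable(request)) {
        return onDeviceCall();
      }
      return inCloudCall();
  }
}
```

Centralizing the branch in one helper is what lets `generateContent` and `generateContentStream` above collapse their duplicated if/else blocks into a single call.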