diff --git a/.changeset/feat-prefer-in-cloud.md b/.changeset/feat-prefer-in-cloud.md new file mode 100644 index 00000000000..90f859c2da8 --- /dev/null +++ b/.changeset/feat-prefer-in-cloud.md @@ -0,0 +1,6 @@ +--- +"@firebase/ai": minor +"firebase": minor +--- + +Added a new `InferenceMode` option for the hybrid on-device capability: `prefer_in_cloud`. When this mode is selected, the SDK will attempt to use a cloud-hosted model first. If the call to the cloud-hosted model fails with a network-related error, the SDK will fall back to the on-device model, if it's available. diff --git a/common/api-review/ai.api.md b/common/api-review/ai.api.md index 5a8e5df6ab9..43a10169cfa 100644 --- a/common/api-review/ai.api.md +++ b/common/api-review/ai.api.md @@ -727,6 +727,7 @@ export const InferenceMode: { readonly PREFER_ON_DEVICE: "prefer_on_device"; readonly ONLY_ON_DEVICE: "only_on_device"; readonly ONLY_IN_CLOUD: "only_in_cloud"; + readonly PREFER_IN_CLOUD: "prefer_in_cloud"; }; // @public diff --git a/docs-devsite/ai.md b/docs-devsite/ai.md index d70b381d6fe..d94c31155de 100644 --- a/docs-devsite/ai.md +++ b/docs-devsite/ai.md @@ -624,6 +624,8 @@ ImagenSafetyFilterLevel: { (EXPERIMENTAL) Determines whether inference happens on-device or in-cloud. +PREFER\_ON\_DEVICE: Attempt to make inference calls using an on-device model. If on-device inference is not available, the SDK will fall back to using a cloud-hosted model.
ONLY\_ON\_DEVICE: Only attempt to make inference calls using an on-device model. The SDK will not fall back to a cloud-hosted model. If on-device inference is not available, inference methods will throw.
ONLY\_IN\_CLOUD: Only attempt to make inference calls using a cloud-hosted model. The SDK will not fall back to an on-device model.
PREFER\_IN\_CLOUD: Attempt to make inference calls to a cloud-hosted model. If not available, the SDK will fall back to an on-device model. + Signature: ```typescript @@ -631,6 +633,7 @@ InferenceMode: { readonly PREFER_ON_DEVICE: "prefer_on_device"; readonly ONLY_ON_DEVICE: "only_on_device"; readonly ONLY_IN_CLOUD: "only_in_cloud"; + readonly PREFER_IN_CLOUD: "prefer_in_cloud"; } ``` diff --git a/packages/ai/src/methods/count-tokens.test.ts b/packages/ai/src/methods/count-tokens.test.ts index 56985b4d54e..aabf06a841a 100644 --- a/packages/ai/src/methods/count-tokens.test.ts +++ b/packages/ai/src/methods/count-tokens.test.ts @@ -196,24 +196,16 @@ describe('countTokens()', () => { ); }); }); - it('on-device', async () => { - const chromeAdapter = fakeChromeAdapter; - const isAvailableStub = stub(chromeAdapter, 'isAvailable').resolves(true); - const mockResponse = getMockResponse( - 'vertexAI', - 'unary-success-total-tokens.json' - ); - const countTokensStub = stub(chromeAdapter, 'countTokens').resolves( - mockResponse as Response + it('throws if mode is ONLY_ON_DEVICE', async () => { + const chromeAdapter = new ChromeAdapterImpl( + // @ts-expect-error + undefined, + InferenceMode.ONLY_ON_DEVICE ); - const result = await countTokens( - fakeApiSettings, - 'model', - fakeRequestParams, - chromeAdapter + await expect( + countTokens(fakeApiSettings, 'model', fakeRequestParams, chromeAdapter) + ).to.be.rejectedWith( + /countTokens\(\) is not supported for on-device models/ ); - expect(result.totalTokens).eq(6); - expect(isAvailableStub).to.be.called; - expect(countTokensStub).to.be.calledWith(fakeRequestParams); }); }); diff --git a/packages/ai/src/methods/count-tokens.ts b/packages/ai/src/methods/count-tokens.ts index 00dde84ab48..ecd86a82912 100644 --- a/packages/ai/src/methods/count-tokens.ts +++ b/packages/ai/src/methods/count-tokens.ts @@ -15,16 +15,20 @@ * limitations under the License. 
*/ +import { AIError } from '../errors'; import { CountTokensRequest, CountTokensResponse, - RequestOptions + InferenceMode, + RequestOptions, + AIErrorCode } from '../types'; import { Task, makeRequest } from '../requests/request'; import { ApiSettings } from '../types/internal'; import * as GoogleAIMapper from '../googleai-mappers'; import { BackendType } from '../public-types'; import { ChromeAdapter } from '../types/chrome-adapter'; +import { ChromeAdapterImpl } from './chrome-adapter'; export async function countTokensOnCloud( apiSettings: ApiSettings, @@ -57,9 +61,13 @@ export async function countTokens( chromeAdapter?: ChromeAdapter, requestOptions?: RequestOptions ): Promise { - if (chromeAdapter && (await chromeAdapter.isAvailable(params))) { - return (await chromeAdapter.countTokens(params)).json(); + if ( + (chromeAdapter as ChromeAdapterImpl)?.mode === InferenceMode.ONLY_ON_DEVICE + ) { + throw new AIError( + AIErrorCode.UNSUPPORTED, + 'countTokens() is not supported for on-device models.' 
+ ); } - return countTokensOnCloud(apiSettings, model, params, requestOptions); } diff --git a/packages/ai/src/methods/generate-content.ts b/packages/ai/src/methods/generate-content.ts index 2c1c383641f..0e65b479343 100644 --- a/packages/ai/src/methods/generate-content.ts +++ b/packages/ai/src/methods/generate-content.ts @@ -29,6 +29,7 @@ import { ApiSettings } from '../types/internal'; import * as GoogleAIMapper from '../googleai-mappers'; import { BackendType } from '../public-types'; import { ChromeAdapter } from '../types/chrome-adapter'; +import { callCloudOrDevice } from '../requests/hybrid-helpers'; async function generateContentStreamOnCloud( apiSettings: ApiSettings, @@ -56,17 +57,13 @@ export async function generateContentStream( chromeAdapter?: ChromeAdapter, requestOptions?: RequestOptions ): Promise { - let response; - if (chromeAdapter && (await chromeAdapter.isAvailable(params))) { - response = await chromeAdapter.generateContentStream(params); - } else { - response = await generateContentStreamOnCloud( - apiSettings, - model, - params, - requestOptions - ); - } + const response = await callCloudOrDevice( + params, + chromeAdapter, + () => chromeAdapter!.generateContentStream(params), + () => + generateContentStreamOnCloud(apiSettings, model, params, requestOptions) + ); return processStream(response, apiSettings); // TODO: Map streaming responses } @@ -96,17 +93,12 @@ export async function generateContent( chromeAdapter?: ChromeAdapter, requestOptions?: RequestOptions ): Promise { - let response; - if (chromeAdapter && (await chromeAdapter.isAvailable(params))) { - response = await chromeAdapter.generateContent(params); - } else { - response = await generateContentOnCloud( - apiSettings, - model, - params, - requestOptions - ); - } + const response = await callCloudOrDevice( + params, + chromeAdapter, + () => chromeAdapter!.generateContent(params), + () => generateContentOnCloud(apiSettings, model, params, requestOptions) + ); const 
generateContentResponse = await processGenerateContentResponse( response, apiSettings diff --git a/packages/ai/src/models/generative-model.test.ts b/packages/ai/src/models/generative-model.test.ts index 68f1565b26a..9c5074b8eb3 100644 --- a/packages/ai/src/models/generative-model.test.ts +++ b/packages/ai/src/models/generative-model.test.ts @@ -16,15 +16,26 @@ */ import { use, expect } from 'chai'; import { GenerativeModel } from './generative-model'; -import { FunctionCallingMode, AI, InferenceMode } from '../public-types'; +import { + FunctionCallingMode, + AI, + InferenceMode, + AIErrorCode +} from '../public-types'; import * as request from '../requests/request'; -import { match, restore, stub } from 'sinon'; -import { getMockResponse } from '../../test-utils/mock-response'; +import { SinonStub, match, restore, stub } from 'sinon'; +import { + getMockResponse, + getMockResponseStreaming +} from '../../test-utils/mock-response'; import sinonChai from 'sinon-chai'; import { VertexAIBackend } from '../backend'; import { ChromeAdapterImpl } from '../methods/chrome-adapter'; +import { AIError } from '../errors'; +import chaiAsPromised from 'chai-as-promised'; use(sinonChai); +use(chaiAsPromised); const fakeAI: AI = { app: { @@ -406,3 +417,299 @@ describe('GenerativeModel', () => { restore(); }); }); + +describe('GenerativeModel dispatch logic', () => { + let makeRequestStub: SinonStub; + let mockChromeAdapter: ChromeAdapterImpl; + + function stubMakeRequest(stream?: boolean): void { + if (stream) { + makeRequestStub = stub(request, 'makeRequest').resolves( + getMockResponseStreaming( + 'vertexAI', + 'unary-success-basic-reply-short.json' + ) as Response + ); + } else { + makeRequestStub = stub(request, 'makeRequest').resolves( + getMockResponse( + 'vertexAI', + 'unary-success-basic-reply-short.json' + ) as Response + ); + } + } + + beforeEach(() => { + // @ts-ignore + mockChromeAdapter = { + isAvailable: stub(), + generateContent: stub().resolves(new 
Response(JSON.stringify({}))), + generateContentStream: stub().resolves( + new Response(new ReadableStream()) + ), + countTokens: stub().resolves(new Response(JSON.stringify({}))), + mode: InferenceMode.PREFER_ON_DEVICE + }; + }); + + afterEach(() => { + restore(); + }); + + describe('PREFER_ON_DEVICE', () => { + beforeEach(() => { + mockChromeAdapter.mode = InferenceMode.PREFER_ON_DEVICE; + }); + it('should use on-device for generateContent when available', async () => { + stubMakeRequest(); + (mockChromeAdapter.isAvailable as SinonStub).resolves(true); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await model.generateContent('hello'); + expect(mockChromeAdapter.generateContent).to.have.been.calledOnce; + expect(makeRequestStub).to.not.have.been.called; + }); + it('should use cloud for generateContent when on-device is not available', async () => { + stubMakeRequest(); + (mockChromeAdapter.isAvailable as SinonStub).resolves(false); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await model.generateContent('hello'); + expect(mockChromeAdapter.generateContent).to.not.have.been.called; + expect(makeRequestStub).to.have.been.calledOnce; + }); + it('should use on-device for generateContentStream when available', async () => { + stubMakeRequest(true); + (mockChromeAdapter.isAvailable as SinonStub).resolves(true); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await model.generateContentStream('hello'); + expect(mockChromeAdapter.generateContentStream).to.have.been.calledOnce; + expect(makeRequestStub).to.not.have.been.called; + }); + it('should use cloud for generateContentStream when on-device is not available', async () => { + stubMakeRequest(true); + (mockChromeAdapter.isAvailable as SinonStub).resolves(false); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter 
+ ); + await model.generateContentStream('hello'); + expect(mockChromeAdapter.generateContentStream).to.not.have.been.called; + expect(makeRequestStub).to.have.been.calledOnce; + }); + it('should use cloud for countTokens', async () => { + stubMakeRequest(); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await model.countTokens('hello'); + expect(makeRequestStub).to.have.been.calledOnce; + }); + }); + + describe('ONLY_ON_DEVICE', () => { + beforeEach(() => { + mockChromeAdapter.mode = InferenceMode.ONLY_ON_DEVICE; + }); + it('should use on-device for generateContent when available', async () => { + stubMakeRequest(); + (mockChromeAdapter.isAvailable as SinonStub).resolves(true); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await model.generateContent('hello'); + expect(mockChromeAdapter.generateContent).to.have.been.calledOnce; + expect(makeRequestStub).to.not.have.been.called; + }); + it('generateContent should throw when on-device is not available', async () => { + stubMakeRequest(); + (mockChromeAdapter.isAvailable as SinonStub).resolves(false); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await expect(model.generateContent('hello')).to.be.rejectedWith( + /on-device model is not available/ + ); + expect(mockChromeAdapter.generateContent).to.not.have.been.called; + expect(makeRequestStub).to.not.have.been.called; + }); + it('should use on-device for generateContentStream when available', async () => { + stubMakeRequest(true); + (mockChromeAdapter.isAvailable as SinonStub).resolves(true); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await model.generateContentStream('hello'); + expect(mockChromeAdapter.generateContentStream).to.have.been.calledOnce; + expect(makeRequestStub).to.not.have.been.called; + }); + it('generateContentStream should 
throw when on-device is not available', async () => { + stubMakeRequest(true); + (mockChromeAdapter.isAvailable as SinonStub).resolves(false); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await expect(model.generateContentStream('hello')).to.be.rejectedWith( + /on-device model is not available/ + ); + expect(mockChromeAdapter.generateContent).to.not.have.been.called; + expect(makeRequestStub).to.not.have.been.called; + }); + it('should always throw for countTokens', async () => { + stubMakeRequest(); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await expect(model.countTokens('hello')).to.be.rejectedWith(AIError); + expect(makeRequestStub).to.not.have.been.called; + }); + }); + + describe('ONLY_IN_CLOUD', () => { + beforeEach(() => { + mockChromeAdapter.mode = InferenceMode.ONLY_IN_CLOUD; + }); + it('should use cloud for generateContent even when on-device is available', async () => { + stubMakeRequest(); + (mockChromeAdapter.isAvailable as SinonStub).resolves(true); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await model.generateContent('hello'); + expect(makeRequestStub).to.have.been.calledOnce; + expect(mockChromeAdapter.generateContent).to.not.have.been.called; + }); + it('should use cloud for generateContentStream even when on-device is available', async () => { + stubMakeRequest(true); + (mockChromeAdapter.isAvailable as SinonStub).resolves(true); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await model.generateContentStream('hello'); + expect(makeRequestStub).to.have.been.calledOnce; + expect(mockChromeAdapter.generateContentStream).to.not.have.been.called; + }); + it('should always use cloud for countTokens', async () => { + stubMakeRequest(); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + 
mockChromeAdapter + ); + await model.countTokens('hello'); + expect(makeRequestStub).to.have.been.calledOnce; + }); + }); + + describe('PREFER_IN_CLOUD', () => { + beforeEach(() => { + mockChromeAdapter.mode = InferenceMode.PREFER_IN_CLOUD; + }); + it('should use cloud for generateContent when available', async () => { + stubMakeRequest(); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await model.generateContent('hello'); + expect(makeRequestStub).to.have.been.calledOnce; + expect(mockChromeAdapter.generateContent).to.not.have.been.called; + }); + it('should fall back to on-device for generateContent if cloud fails', async () => { + makeRequestStub.rejects( + new AIError(AIErrorCode.FETCH_ERROR, 'Network error') + ); + (mockChromeAdapter.isAvailable as SinonStub).resolves(true); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await model.generateContent('hello'); + expect(makeRequestStub).to.have.been.calledOnce; + expect(mockChromeAdapter.generateContent).to.have.been.calledOnce; + }); + it('should use cloud for generateContentStream when available', async () => { + stubMakeRequest(true); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await model.generateContentStream('hello'); + expect(makeRequestStub).to.have.been.calledOnce; + expect(mockChromeAdapter.generateContentStream).to.not.have.been.called; + }); + it('should fall back to on-device for generateContentStream if cloud fails', async () => { + makeRequestStub.rejects( + new AIError(AIErrorCode.FETCH_ERROR, 'Network error') + ); + (mockChromeAdapter.isAvailable as SinonStub).resolves(true); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await model.generateContentStream('hello'); + expect(makeRequestStub).to.have.been.calledOnce; + 
expect(mockChromeAdapter.generateContentStream).to.have.been.calledOnce; + }); + it('should use cloud for countTokens', async () => { + stubMakeRequest(); + const model = new GenerativeModel( + fakeAI, + { model: 'model' }, + {}, + mockChromeAdapter + ); + await model.countTokens('hello'); + expect(makeRequestStub).to.have.been.calledOnce; + }); + }); +}); diff --git a/packages/ai/src/requests/hybrid-helpers.test.ts b/packages/ai/src/requests/hybrid-helpers.test.ts new file mode 100644 index 00000000000..a758f34ad21 --- /dev/null +++ b/packages/ai/src/requests/hybrid-helpers.test.ts @@ -0,0 +1,187 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import { use, expect } from 'chai'; +import { SinonStub, SinonStubbedInstance, restore, stub } from 'sinon'; +import { callCloudOrDevice } from './hybrid-helpers'; +import { GenerateContentRequest, InferenceMode, AIErrorCode } from '../types'; +import { AIError } from '../errors'; +import sinonChai from 'sinon-chai'; +import chaiAsPromised from 'chai-as-promised'; +import { ChromeAdapterImpl } from '../methods/chrome-adapter'; + +use(sinonChai); +use(chaiAsPromised); + +describe('callCloudOrDevice', () => { + let chromeAdapter: SinonStubbedInstance; + let onDeviceCall: SinonStub; + let inCloudCall: SinonStub; + let request: GenerateContentRequest; + + beforeEach(() => { + // @ts-ignore + chromeAdapter = { + mode: InferenceMode.PREFER_ON_DEVICE, + isAvailable: stub(), + generateContent: stub(), + generateContentStream: stub(), + countTokens: stub() + }; + onDeviceCall = stub().resolves('on-device-response'); + inCloudCall = stub().resolves('in-cloud-response'); + request = { contents: [] }; + }); + + afterEach(() => { + restore(); + }); + + it('should call inCloudCall if chromeAdapter is undefined', async () => { + const result = await callCloudOrDevice( + request, + undefined, + onDeviceCall, + inCloudCall + ); + expect(result).to.equal('in-cloud-response'); + expect(inCloudCall).to.have.been.calledOnce; + expect(onDeviceCall).to.not.have.been.called; + }); + + describe('PREFER_ON_DEVICE mode', () => { + beforeEach(() => { + chromeAdapter.mode = InferenceMode.PREFER_ON_DEVICE; + }); + + it('should call onDeviceCall if available', async () => { + chromeAdapter.isAvailable.resolves(true); + const result = await callCloudOrDevice( + request, + chromeAdapter, + onDeviceCall, + inCloudCall + ); + expect(result).to.equal('on-device-response'); + expect(onDeviceCall).to.have.been.calledOnce; + expect(inCloudCall).to.not.have.been.called; + }); + + it('should call inCloudCall if not available', async () => { + chromeAdapter.isAvailable.resolves(false); + const 
result = await callCloudOrDevice( + request, + chromeAdapter, + onDeviceCall, + inCloudCall + ); + expect(result).to.equal('in-cloud-response'); + expect(inCloudCall).to.have.been.calledOnce; + expect(onDeviceCall).to.not.have.been.called; + }); + }); + + describe('ONLY_ON_DEVICE mode', () => { + beforeEach(() => { + chromeAdapter.mode = InferenceMode.ONLY_ON_DEVICE; + }); + + it('should call onDeviceCall if available', async () => { + chromeAdapter.isAvailable.resolves(true); + const result = await callCloudOrDevice( + request, + chromeAdapter, + onDeviceCall, + inCloudCall + ); + expect(result).to.equal('on-device-response'); + expect(onDeviceCall).to.have.been.calledOnce; + expect(inCloudCall).to.not.have.been.called; + }); + + it('should throw if not available', async () => { + chromeAdapter.isAvailable.resolves(false); + await expect( + callCloudOrDevice(request, chromeAdapter, onDeviceCall, inCloudCall) + ).to.be.rejectedWith(/on-device model is not available/); + expect(inCloudCall).to.not.have.been.called; + expect(onDeviceCall).to.not.have.been.called; + }); + }); + + describe('ONLY_IN_CLOUD mode', () => { + beforeEach(() => { + chromeAdapter.mode = InferenceMode.ONLY_IN_CLOUD; + }); + + it('should call inCloudCall even if on-device is available', async () => { + chromeAdapter.isAvailable.resolves(true); + const result = await callCloudOrDevice( + request, + chromeAdapter, + onDeviceCall, + inCloudCall + ); + expect(result).to.equal('in-cloud-response'); + expect(inCloudCall).to.have.been.calledOnce; + expect(onDeviceCall).to.not.have.been.called; + }); + }); + + describe('PREFER_IN_CLOUD mode', () => { + beforeEach(() => { + chromeAdapter.mode = InferenceMode.PREFER_IN_CLOUD; + }); + + it('should call inCloudCall first', async () => { + const result = await callCloudOrDevice( + request, + chromeAdapter, + onDeviceCall, + inCloudCall + ); + expect(result).to.equal('in-cloud-response'); + expect(inCloudCall).to.have.been.calledOnce; + 
expect(onDeviceCall).to.not.have.been.called; + }); + + it('should fall back to onDeviceCall if inCloudCall fails with AIErrorCode.FETCH_ERROR', async () => { + inCloudCall.rejects( + new AIError(AIErrorCode.FETCH_ERROR, 'Network error') + ); + const result = await callCloudOrDevice( + request, + chromeAdapter, + onDeviceCall, + inCloudCall + ); + expect(result).to.equal('on-device-response'); + expect(inCloudCall).to.have.been.calledOnce; + expect(onDeviceCall).to.have.been.calledOnce; + }); + + it('should re-throw other errors from inCloudCall', async () => { + const error = new AIError(AIErrorCode.RESPONSE_ERROR, 'safety problem'); + inCloudCall.rejects(error); + await expect( + callCloudOrDevice(request, chromeAdapter, onDeviceCall, inCloudCall) + ).to.be.rejectedWith(error); + expect(inCloudCall).to.have.been.calledOnce; + expect(onDeviceCall).to.not.have.been.called; + }); + }); +}); diff --git a/packages/ai/src/requests/hybrid-helpers.ts b/packages/ai/src/requests/hybrid-helpers.ts new file mode 100644 index 00000000000..3140594c00e --- /dev/null +++ b/packages/ai/src/requests/hybrid-helpers.ts @@ -0,0 +1,88 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import { AIError } from '../errors'; +import { + GenerateContentRequest, + InferenceMode, + AIErrorCode, + ChromeAdapter +} from '../types'; +import { ChromeAdapterImpl } from '../methods/chrome-adapter'; + +const errorsCausingFallback: AIErrorCode[] = [ + // most network errors + AIErrorCode.FETCH_ERROR, + // fallback code for all other errors in makeRequest + AIErrorCode.ERROR, + // error due to API not being enabled in project + AIErrorCode.API_NOT_ENABLED +]; + +/** + * Dispatches a request to the appropriate backend (on-device or in-cloud) + * based on the inference mode. + * + * @param request - The request to be sent. + * @param chromeAdapter - The on-device model adapter. + * @param onDeviceCall - The function to call for on-device inference. + * @param inCloudCall - The function to call for in-cloud inference. + * @returns The response from the backend. + */ +export async function callCloudOrDevice( + request: GenerateContentRequest, + chromeAdapter: ChromeAdapter | undefined, + onDeviceCall: () => Promise, + inCloudCall: () => Promise +): Promise { + if (!chromeAdapter) { + return inCloudCall(); + } + switch ((chromeAdapter as ChromeAdapterImpl).mode) { + case InferenceMode.ONLY_ON_DEVICE: + if (await chromeAdapter.isAvailable(request)) { + return onDeviceCall(); + } + throw new AIError( + AIErrorCode.UNSUPPORTED, + 'Inference mode is ONLY_ON_DEVICE, but an on-device model is not available.' 
+      ); + case InferenceMode.ONLY_IN_CLOUD: + return inCloudCall(); + case InferenceMode.PREFER_IN_CLOUD: + try { + return await inCloudCall(); + } catch (e) { + if (e instanceof AIError && errorsCausingFallback.includes(e.code)) { + return onDeviceCall(); + } + throw e; + } + case InferenceMode.PREFER_ON_DEVICE: + if (await chromeAdapter.isAvailable(request)) { + return onDeviceCall(); + } + return inCloudCall(); + default: + throw new AIError( + AIErrorCode.ERROR, + `Unexpected inference mode: ${ + (chromeAdapter as ChromeAdapterImpl).mode + }` + ); + } +} diff --git a/packages/ai/src/types/enums.ts b/packages/ai/src/types/enums.ts index 701cd4a695d..de70d325157 100644 --- a/packages/ai/src/types/enums.ts +++ b/packages/ai/src/types/enums.ts @@ -347,17 +347,36 @@ export type ResponseModality = /** * (EXPERIMENTAL) * Determines whether inference happens on-device or in-cloud. + * + * @remarks + * PREFER_ON_DEVICE: Attempt to make inference calls using an + * on-device model. If on-device inference is not available, the SDK + * will fall back to using a cloud-hosted model. + *
+ * ONLY_ON_DEVICE: Only attempt to make inference calls using an + * on-device model. The SDK will not fall back to a cloud-hosted model. + * If on-device inference is not available, inference methods will throw. + *
+ * ONLY_IN_CLOUD: Only attempt to make inference calls using a + * cloud-hosted model. The SDK will not fall back to an on-device model. + *
+ * PREFER_IN_CLOUD: Attempt to make inference calls to a + * cloud-hosted model. If a call fails with a network-related error, the + * SDK will fall back to an on-device model, if it's available. + * * @public */ export const InferenceMode = { 'PREFER_ON_DEVICE': 'prefer_on_device', 'ONLY_ON_DEVICE': 'only_on_device', - 'ONLY_IN_CLOUD': 'only_in_cloud' + 'ONLY_IN_CLOUD': 'only_in_cloud', + 'PREFER_IN_CLOUD': 'prefer_in_cloud' } as const; /** * (EXPERIMENTAL) * Determines whether inference happens on-device or in-cloud. + * * @public */ export type InferenceMode = (typeof InferenceMode)[keyof typeof InferenceMode];