Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@
"type": "dart",
"program": "lib/logging/logging.dart",
},
{
"name": "function calls",
"cwd": "example",
"request": "launch",
"type": "dart",
"program": "lib/function_calls/function_calls.dart",
},
{
"name": "recipes",
"cwd": "example",
Expand All @@ -86,4 +93,4 @@
"program": "lib/recipes/recipes.dart",
},
]
}
}
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## 0.8.1
* added support for tool calls to the Gemini and Vertex providers. Check out the
new `function_calls` example to see it in action. Thanks to @toshiossada for
[the inspiration](https://github.com/flutter/ai/pull/99). Fixes
[#98](https://github.com/flutter/ai/issues/98): How Can I get functionCalls?

## 0.8.0
* fixed [#90](https://github.com/flutter/ai/issues/90): Input box
shrinks unexpectedly when clicking file attachment button – customization not
Expand Down
3 changes: 3 additions & 0 deletions example/android/app/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
plugins {
id("com.android.application")
// START: FlutterFire Configuration
id("com.google.gms.google-services")
// END: FlutterFire Configuration
id("kotlin-android")
// The Flutter Gradle Plugin must be applied after the Android and Kotlin Gradle plugins.
id("dev.flutter.flutter-gradle-plugin")
Expand Down
3 changes: 3 additions & 0 deletions example/android/settings.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ pluginManagement {
plugins {
id("dev.flutter.flutter-plugin-loader") version "1.0.0"
id("com.android.application") version "8.7.0" apply false
// START: FlutterFire Configuration
id("com.google.gms.google-services") version("4.3.15") apply false
// END: FlutterFire Configuration
id("org.jetbrains.kotlin.android") version "1.8.22" apply false
}

Expand Down
4 changes: 4 additions & 0 deletions example/ios/Runner.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
331C808B294A63AB00263BE5 /* RunnerTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 331C807B294A618700263BE5 /* RunnerTests.swift */; };
3B3967161E833CAA004F5970 /* AppFrameworkInfo.plist in Resources */ = {isa = PBXBuildFile; fileRef = 3B3967151E833CAA004F5970 /* AppFrameworkInfo.plist */; };
74858FAF1ED2DC5600515810 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 74858FAE1ED2DC5600515810 /* AppDelegate.swift */; };
96B60D02F354EA628523F83C /* GoogleService-Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 041275B288B850F18666D4D3 /* GoogleService-Info.plist */; };
97C146FC1CF9000F007C117D /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FA1CF9000F007C117D /* Main.storyboard */; };
97C146FE1CF9000F007C117D /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FD1CF9000F007C117D /* Assets.xcassets */; };
97C147011CF9000F007C117D /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FF1CF9000F007C117D /* LaunchScreen.storyboard */; };
Expand Down Expand Up @@ -43,6 +44,7 @@

/* Begin PBXFileReference section */
00907B8AC1679C03DF3ECF3B /* Pods_RunnerTests.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_RunnerTests.framework; sourceTree = BUILT_PRODUCTS_DIR; };
041275B288B850F18666D4D3 /* GoogleService-Info.plist */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.plist.xml; name = "GoogleService-Info.plist"; path = "Runner/GoogleService-Info.plist"; sourceTree = "<group>"; };
1498D2321E8E86230040F4C2 /* GeneratedPluginRegistrant.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = GeneratedPluginRegistrant.h; sourceTree = "<group>"; };
1498D2331E8E89220040F4C2 /* GeneratedPluginRegistrant.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = GeneratedPluginRegistrant.m; sourceTree = "<group>"; };
1D6994B85815BA0993EDD5A4 /* Pods-RunnerTests.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-RunnerTests.debug.xcconfig"; path = "Target Support Files/Pods-RunnerTests/Pods-RunnerTests.debug.xcconfig"; sourceTree = "<group>"; };
Expand Down Expand Up @@ -129,6 +131,7 @@
331C8082294A63A400263BE5 /* RunnerTests */,
67F7FFA188394B87CE9E2A89 /* Pods */,
F15FFF544ED27837A1711D1A /* Frameworks */,
041275B288B850F18666D4D3 /* GoogleService-Info.plist */,
);
sourceTree = "<group>";
};
Expand Down Expand Up @@ -264,6 +267,7 @@
3B3967161E833CAA004F5970 /* AppFrameworkInfo.plist in Resources */,
97C146FE1CF9000F007C117D /* Assets.xcassets in Resources */,
97C146FC1CF9000F007C117D /* Main.storyboard in Resources */,
96B60D02F354EA628523F83C /* GoogleService-Info.plist in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down
66 changes: 66 additions & 0 deletions example/lib/function_calls/function_calls.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2024 The Flutter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

import 'package:flutter/material.dart';
import 'package:flutter_ai_toolkit/flutter_ai_toolkit.dart';
import 'package:google_generative_ai/google_generative_ai.dart';

import '../gemini_api_key.dart';

void main() => runApp(const App());

class App extends StatelessWidget {
static const title = 'Example: Function Calls';

const App({super.key});

@override
Widget build(BuildContext context) =>
const MaterialApp(title: title, home: ChatPage());
}

class ChatPage extends StatelessWidget {
const ChatPage({super.key});

@override
Widget build(BuildContext context) => Scaffold(
appBar: AppBar(title: const Text(App.title)),
body: LlmChatView(
provider: GeminiProvider(
model: GenerativeModel(
model: 'gemini-2.0-flash',
apiKey: geminiApiKey,
tools: [
Tool(
functionDeclarations: [
FunctionDeclaration(
'get_temperature',
'Get the current local temperature',
Schema.object(properties: {}),
),
FunctionDeclaration(
'get_time',
'Get the current local time',
Schema.object(properties: {}),
),
],
),
],
),
onFunctionCall: _onFunctionCall,
),
),
);

Future<Map<String, Object?>?> _onFunctionCall(
FunctionCall functionCall,
) async {
// note: just as an example, we're not actually calling any external APIs
return switch (functionCall.name) {
'get_temperature' => {'temperature': 60, 'unit': 'F'},
'get_time' => {'time': DateTime(1970, 1, 1).toIso8601String()},
_ => throw Exception('Unknown function call: ${functionCall.name}'),
};
}
}
4 changes: 4 additions & 0 deletions example/macos/Runner.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
33CC10F32044A3C60003C045 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 33CC10F22044A3C60003C045 /* Assets.xcassets */; };
33CC10F62044A3C60003C045 /* MainMenu.xib in Resources */ = {isa = PBXBuildFile; fileRef = 33CC10F42044A3C60003C045 /* MainMenu.xib */; };
33CC11132044BFA00003C045 /* MainFlutterWindow.swift in Sources */ = {isa = PBXBuildFile; fileRef = 33CC11122044BFA00003C045 /* MainFlutterWindow.swift */; };
83C8768E218628E214E0C7F2 /* GoogleService-Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 29DB9CE91787729FA5EF4A32 /* GoogleService-Info.plist */; };
D9168A3AC46A7BD217B5C7C1 /* Pods_RunnerTests.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = B40A05188DFAAEDDB9FB89BA /* Pods_RunnerTests.framework */; };
/* End PBXBuildFile section */

Expand Down Expand Up @@ -63,6 +64,7 @@

/* Begin PBXFileReference section */
214AF9134787478DD9FCBE9B /* Pods-Runner.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Runner.release.xcconfig"; path = "Target Support Files/Pods-Runner/Pods-Runner.release.xcconfig"; sourceTree = "<group>"; };
29DB9CE91787729FA5EF4A32 /* GoogleService-Info.plist */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.plist.xml; name = "GoogleService-Info.plist"; path = "Runner/GoogleService-Info.plist"; sourceTree = "<group>"; };
2D628886CA8B87FB4043E6B3 /* Pods-RunnerTests.profile.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-RunnerTests.profile.xcconfig"; path = "Target Support Files/Pods-RunnerTests/Pods-RunnerTests.profile.xcconfig"; sourceTree = "<group>"; };
331C80D5294CF71000263BE5 /* RunnerTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = RunnerTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; };
331C80D7294CF71000263BE5 /* RunnerTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = RunnerTests.swift; sourceTree = "<group>"; };
Expand Down Expand Up @@ -152,6 +154,7 @@
33CC10EE2044A3C60003C045 /* Products */,
D73912EC22F37F3D000D13A0 /* Frameworks */,
1D8694EA61F695DDF92A68F4 /* Pods */,
29DB9CE91787729FA5EF4A32 /* GoogleService-Info.plist */,
);
sourceTree = "<group>";
};
Expand Down Expand Up @@ -317,6 +320,7 @@
files = (
33CC10F32044A3C60003C045 /* Assets.xcassets in Resources */,
33CC10F62044A3C60003C045 /* MainMenu.xib in Resources */,
83C8768E218628E214E0C7F2 /* GoogleService-Info.plist in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down
52 changes: 35 additions & 17 deletions lib/src/providers/implementations/gemini_provider.dart
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,29 @@ class GeminiProvider extends LlmProvider with ChangeNotifier {
///
/// [chatGenerationConfig] is an optional configuration for controlling the
/// model's generation behavior.
///
/// [onFunctionCall] is an optional function that will be called when the LLM
/// needs to call a function.
GeminiProvider({
required GenerativeModel model,
this.onFunctionCalls,
Iterable<ChatMessage>? history,
List<SafetySetting>? chatSafetySettings,
GenerationConfig? chatGenerationConfig,
Future<Map<String, Object?>?> Function(FunctionCall)? onFunctionCall,
}) : _model = model,
_history = history?.toList() ?? [],
_chatSafetySettings = chatSafetySettings,
_chatGenerationConfig = chatGenerationConfig {
_chatGenerationConfig = chatGenerationConfig,
_onFunctionCall = onFunctionCall {
_chat = _startChat(history);
}
final void Function(Iterable<FunctionCall>)? onFunctionCalls;

final GenerativeModel _model;
final List<SafetySetting>? _chatSafetySettings;
final GenerationConfig? _chatGenerationConfig;
final List<ChatMessage> _history;
ChatSession? _chat;
final Future<Map<String, Object?>?> Function(FunctionCall)? _onFunctionCall;

@override
Stream<String> generateStream(
Expand All @@ -54,7 +59,6 @@ class GeminiProvider extends LlmProvider with ChangeNotifier {
prompt: prompt,
attachments: attachments,
contentStreamGenerator: (c) => _model.generateContentStream([c]),
onFunctionCalls: onFunctionCalls,
);

@override
Expand All @@ -70,7 +74,6 @@ class GeminiProvider extends LlmProvider with ChangeNotifier {
prompt: prompt,
attachments: attachments,
contentStreamGenerator: _chat!.sendMessageStream,
onFunctionCalls: onFunctionCalls,
);

// don't write this code if you're targeting the web until this is fixed:
Expand All @@ -93,30 +96,45 @@ class GeminiProvider extends LlmProvider with ChangeNotifier {
required Iterable<Attachment> attachments,
required Stream<GenerateContentResponse> Function(Content)
contentStreamGenerator,
required void Function(Iterable<FunctionCall>)? onFunctionCalls,
}) async* {
final content = Content('user', [
TextPart(prompt),
...attachments.map(_partFrom),
]);

final response = contentStreamGenerator(content);
final contentResponse = contentStreamGenerator(content);

// don't write this code if you're targeting the web until this is fixed:
// https://github.com/dart-lang/sdk/issues/47764
// await for (final chunk in response) {
// final text = chunk.text;
// if (text != null) yield text;
// }
yield* response
.map((chunk) {
if (chunk.candidates.any((e) => e.finishReason != null) &&
chunk.functionCalls.isNotEmpty) {
onFunctionCalls?.call(chunk.functionCalls);
}
return chunk.text;
})
.where((text) => text != null)
.cast<String>();
yield* contentResponse.asyncMap((chunk) async {
if (chunk.functionCalls.isEmpty) return chunk.text ?? '';

final functionResponses = <FunctionResponse>[];
for (final functionCall in chunk.functionCalls) {
try {
functionResponses.add(
FunctionResponse(
functionCall.name,
await _onFunctionCall?.call(functionCall) ?? {},
),
);
} catch (ex) {
functionResponses.add(
FunctionResponse(functionCall.name, {'error': ex.toString()}),
);
}
}

final functionContentResponse = await _chat!.sendMessage(
Content.functionResponses(functionResponses),
);

return '${chunk.text ?? ''}${functionContentResponse.text ?? ''}';
});
}

@override
Expand Down
37 changes: 31 additions & 6 deletions lib/src/providers/implementations/vertex_provider.dart
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,20 @@ class VertexProvider extends LlmProvider with ChangeNotifier {
Iterable<ChatMessage>? history,
List<SafetySetting>? chatSafetySettings,
GenerationConfig? chatGenerationConfig,
Future<Map<String, Object?>?> Function(FunctionCall)? onFunctionCall,
}) : _model = model,
_history = history?.toList() ?? [],
_chatSafetySettings = chatSafetySettings,
_chatGenerationConfig = chatGenerationConfig {
_chatGenerationConfig = chatGenerationConfig,
_onFunctionCall = onFunctionCall {
_chat = _startChat(history);
}
final void Function(Iterable<FunctionCall>)? onFunctionCalls;
final GenerativeModel _model;
final List<SafetySetting>? _chatSafetySettings;
final GenerationConfig? _chatGenerationConfig;
final List<ChatMessage> _history;
final Future<Map<String, Object?>?> Function(FunctionCall)? _onFunctionCall;
ChatSession? _chat;

@override
Expand Down Expand Up @@ -97,17 +100,39 @@ class VertexProvider extends LlmProvider with ChangeNotifier {
...attachments.map(_partFrom),
]);

final response = contentStreamGenerator(content);
final contentResponse = contentStreamGenerator(content);

// don't write this code if you're targeting the web until this is fixed:
// https://github.com/dart-lang/sdk/issues/47764
// await for (final chunk in response) {
// final text = chunk.text;
// if (text != null) yield text;
// }
yield* response
.map((chunk) => chunk.text)
.where((text) => text != null)
.cast<String>();
yield* contentResponse.asyncMap((chunk) async {
if (chunk.functionCalls.isEmpty) return chunk.text ?? '';

final functionResponses = <FunctionResponse>[];
for (final functionCall in chunk.functionCalls) {
try {
functionResponses.add(
FunctionResponse(
functionCall.name,
await _onFunctionCall?.call(functionCall) ?? {},
),
);
} catch (ex) {
functionResponses.add(
FunctionResponse(functionCall.name, {'error': ex.toString()}),
);
}
}

final functionContentResponse = await _chat!.sendMessage(
Content.functionResponses(functionResponses),
);

return '${chunk.text ?? ''}${functionContentResponse.text ?? ''}';
});
}

@override
Expand Down
2 changes: 1 addition & 1 deletion pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: flutter_ai_toolkit
description: >-
A set of AI chat-related widgets for Flutter apps targeting
mobile, desktop, and web.
version: 0.8.0
version: 0.8.1
repository: https://github.com/flutter/ai

topics:
Expand Down