diff --git a/lib/src/providers/implementations/gemini_provider.dart b/lib/src/providers/implementations/gemini_provider.dart index 5c0e8d0..12fdb3c 100644 --- a/lib/src/providers/implementations/gemini_provider.dart +++ b/lib/src/providers/implementations/gemini_provider.dart @@ -29,6 +29,7 @@ class GeminiProvider extends LlmProvider with ChangeNotifier { /// model's generation behavior. GeminiProvider({ required GenerativeModel model, + this.onFunctionCalls, Iterable? history, List? chatSafetySettings, GenerationConfig? chatGenerationConfig, @@ -38,7 +39,7 @@ class GeminiProvider extends LlmProvider with ChangeNotifier { _chatGenerationConfig = chatGenerationConfig { _chat = _startChat(history); } - + final void Function(Iterable)? onFunctionCalls; final GenerativeModel _model; final List? _chatSafetySettings; final GenerationConfig? _chatGenerationConfig; @@ -53,6 +54,7 @@ class GeminiProvider extends LlmProvider with ChangeNotifier { prompt: prompt, attachments: attachments, contentStreamGenerator: (c) => _model.generateContentStream([c]), + onFunctionCalls: onFunctionCalls, ); @override @@ -68,6 +70,7 @@ class GeminiProvider extends LlmProvider with ChangeNotifier { prompt: prompt, attachments: attachments, contentStreamGenerator: _chat!.sendMessageStream, + onFunctionCalls: onFunctionCalls, ); // don't write this code if you're targeting the web until this is fixed: @@ -90,6 +93,7 @@ class GeminiProvider extends LlmProvider with ChangeNotifier { required Iterable attachments, required Stream Function(Content) contentStreamGenerator, + required void Function(Iterable)? onFunctionCalls, }) async* { final content = Content('user', [ TextPart(prompt), @@ -104,7 +108,13 @@ class GeminiProvider extends LlmProvider with ChangeNotifier { // if (text != null) yield text; // } yield* response - .map((chunk) => chunk.text) + .map((chunk) { + if (chunk.candidates.any((e) => e.finishReason != null) && + chunk.functionCalls.isNotEmpty) { + onFunctionCalls?.call(chunk.functionCalls); + } + return chunk.text; + }) .where((text) => text != null) .cast(); } diff --git a/lib/src/providers/implementations/vertex_provider.dart b/lib/src/providers/implementations/vertex_provider.dart index d9858d3..293260a 100644 --- a/lib/src/providers/implementations/vertex_provider.dart +++ b/lib/src/providers/implementations/vertex_provider.dart @@ -29,6 +29,7 @@ class VertexProvider extends LlmProvider with ChangeNotifier { /// model's generation behavior. VertexProvider({ required GenerativeModel model, + this.onFunctionCalls, Iterable? history, List? chatSafetySettings, GenerationConfig? chatGenerationConfig, @@ -38,7 +39,7 @@ class VertexProvider extends LlmProvider with ChangeNotifier { _chatGenerationConfig = chatGenerationConfig { _chat = _startChat(history); } - + final void Function(Iterable)? onFunctionCalls; final GenerativeModel _model; final List? _chatSafetySettings; final GenerationConfig? _chatGenerationConfig;