-
Notifications
You must be signed in to change notification settings - Fork 36
Update travel app to talk directly to Firebase AI Logic rather than use firebase_ai_client to enable easier tool integration #276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
91fb159
a7f5c86
30ba46b
cc65a18
6f7abea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,18 +2,22 @@ | |
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
import 'package:dart_schema_builder/dart_schema_builder.dart'; | ||
import 'package:firebase_ai/firebase_ai.dart'; | ||
import 'package:firebase_app_check/firebase_app_check.dart'; | ||
import 'package:firebase_core/firebase_core.dart'; | ||
import 'package:flutter/material.dart'; | ||
import 'package:flutter_genui/flutter_genui.dart'; | ||
import 'package:flutter_genui/flutter_genui.dart' hide ChatMessage, TextPart; | ||
import 'package:logging/logging.dart'; | ||
|
||
import 'firebase_options.dart'; | ||
import 'src/asset_images.dart'; | ||
import 'src/catalog.dart'; | ||
import 'src/gemini_client.dart'; | ||
import 'src/turn.dart'; | ||
import 'src/widgets/conversation.dart'; | ||
|
||
final _logger = Logger('TravelApp'); | ||
|
||
void main() async { | ||
WidgetsFlutterBinding.ensureInitialized(); | ||
await Firebase.initializeApp(options: DefaultFirebaseOptions.currentPlatform); | ||
|
@@ -24,26 +28,30 @@ void main() async { | |
); | ||
_imagesJson = await assetImageCatalogJson(); | ||
configureGenUiLogging(level: Level.ALL); | ||
_configureLogging(); | ||
runApp(const TravelApp()); | ||
} | ||
|
||
void _configureLogging() { | ||
hierarchicalLoggingEnabled = true; | ||
Logger.root.level = Level.ALL; | ||
Logger.root.onRecord.listen((record) { | ||
// ignore: avoid_print | ||
print( | ||
'[${record.level.name}] ${record.time}: ' | ||
'${record.loggerName}: ${record.message}', | ||
); | ||
}); | ||
} | ||
|
||
/// The root widget for the travel application. | ||
/// | ||
/// This widget sets up the [MaterialApp], which configures the overall theme, | ||
/// title, and home page for the app. It serves as the main entry point for the | ||
/// user interface. | ||
class TravelApp extends StatelessWidget { | ||
/// Creates a new [TravelApp]. | ||
/// | ||
/// The optional [aiClient] can be used to inject a specific AI client, | ||
/// which is useful for testing with a mock implementation. | ||
const TravelApp({this.aiClient, super.key}); | ||
|
||
/// The AI client to use for the application. | ||
/// | ||
/// If null, a default [FirebaseAiClient] will be created by the | ||
/// [TravelPlannerPage]. | ||
final AiClient? aiClient; | ||
const TravelApp({super.key}); | ||
|
||
@override | ||
Widget build(BuildContext context) { | ||
|
@@ -53,7 +61,7 @@ class TravelApp extends StatelessWidget { | |
theme: ThemeData( | ||
colorScheme: ColorScheme.fromSeed(seedColor: Colors.blue), | ||
), | ||
home: TravelPlannerPage(aiClient: aiClient), | ||
home: const TravelPlannerPage(), | ||
); | ||
} | ||
} | ||
|
@@ -70,27 +78,17 @@ class TravelApp extends StatelessWidget { | |
/// generated UI, and a menu to switch between different AI models. | ||
class TravelPlannerPage extends StatefulWidget { | ||
/// Creates a new [TravelPlannerPage]. | ||
/// | ||
/// An optional [aiClient] can be provided, which is useful for testing | ||
/// or using a custom AI client implementation. If not provided, a default | ||
/// [FirebaseAiClient] is created. | ||
const TravelPlannerPage({this.aiClient, super.key}); | ||
|
||
/// The AI client to use for the application. | ||
/// | ||
/// If null, a default instance of [FirebaseAiClient] will be created within | ||
/// the page's state. | ||
final AiClient? aiClient; | ||
const TravelPlannerPage({super.key}); | ||
Comment on lines
80
to
+81
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The removal of dependency injection for the AI client has made To maintain testability, please consider re-introducing dependency injection. You could pass a |
||
|
||
@override | ||
State<TravelPlannerPage> createState() => _TravelPlannerPageState(); | ||
} | ||
|
||
class _TravelPlannerPageState extends State<TravelPlannerPage> { | ||
late final GenUiManager _genUiManager; | ||
late final AiClient _aiClient; | ||
late final GeminiClient _geminiClient; | ||
late final UiEventManager _eventManager; | ||
final List<ChatMessage> _conversation = []; | ||
final List<Turn> _conversation = []; | ||
final _textController = TextEditingController(); | ||
final _scrollController = ScrollController(); | ||
bool _isThinking = false; | ||
|
@@ -109,31 +107,29 @@ class _TravelPlannerPageState extends State<TravelPlannerPage> { | |
), | ||
); | ||
_eventManager = UiEventManager(callback: _onUiEvents); | ||
_aiClient = | ||
widget.aiClient ?? | ||
FirebaseAiClient( | ||
tools: _genUiManager.getTools(), | ||
systemInstruction: prompt, | ||
); | ||
_geminiClient = GeminiClient( | ||
tools: _genUiManager.getTools(), | ||
systemInstruction: prompt, | ||
); | ||
_genUiManager.surfaceUpdates.listen((update) { | ||
setState(() { | ||
switch (update) { | ||
case SurfaceAdded(:final surfaceId, :final definition): | ||
_conversation.add( | ||
AiUiMessage(definition: definition, surfaceId: surfaceId), | ||
GenUiTurn(definition: definition, surfaceId: surfaceId), | ||
); | ||
_scrollToBottom(); | ||
|
||
case SurfaceRemoved(:final surfaceId): | ||
_conversation.removeWhere( | ||
(m) => m is AiUiMessage && m.surfaceId == surfaceId, | ||
(m) => m is GenUiTurn && m.surfaceId == surfaceId, | ||
); | ||
case SurfaceUpdated(:final surfaceId, :final definition): | ||
final index = _conversation.lastIndexWhere( | ||
(m) => m is AiUiMessage && m.surfaceId == surfaceId, | ||
(m) => m is GenUiTurn && m.surfaceId == surfaceId, | ||
); | ||
if (index != -1) { | ||
_conversation[index] = AiUiMessage( | ||
_conversation[index] = GenUiTurn( | ||
definition: definition, | ||
surfaceId: surfaceId, | ||
); | ||
|
@@ -169,32 +165,18 @@ class _TravelPlannerPageState extends State<TravelPlannerPage> { | |
_isThinking = true; | ||
}); | ||
try { | ||
final result = await _aiClient.generateContent( | ||
_conversation, | ||
S.object( | ||
properties: { | ||
'result': S.boolean( | ||
description: 'Successfully generated a response UI.', | ||
), | ||
'message': S.string( | ||
description: | ||
'A message about what went wrong, or a message responding to ' | ||
'the request. Take into account any UI that has been ' | ||
"generated, so there's no need to duplicate requests or " | ||
'information already present in the UI.', | ||
), | ||
}, | ||
required: ['result'], | ||
), | ||
); | ||
if (result == null) { | ||
return; | ||
} | ||
final value = | ||
(result as Map).cast<String, Object?>()['message'] as String? ?? ''; | ||
final contentHistory = _conversation | ||
.map((turn) => turn.toContent()) | ||
.whereType<Content>() | ||
.toList(); | ||
final result = await _geminiClient.generate(contentHistory); | ||
final value = result.candidates.first.content.parts | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Accessing |
||
.whereType<TextPart>() | ||
.map((part) => part.text) | ||
.join(''); | ||
if (value.isNotEmpty) { | ||
setState(() { | ||
_conversation.add(AiTextMessage.text(value)); | ||
_conversation.add(AiTextTurn(value)); | ||
}); | ||
_scrollToBottom(); | ||
} | ||
|
@@ -227,7 +209,7 @@ class _TravelPlannerPageState extends State<TravelPlannerPage> { | |
} | ||
|
||
setState(() { | ||
_conversation.add(UserUiInteractionMessage.text(message.toString())); | ||
_conversation.add(UserUiInteractionTurn(message.toString())); | ||
}); | ||
_scrollToBottom(); | ||
_triggerInference(); | ||
|
@@ -240,7 +222,7 @@ class _TravelPlannerPageState extends State<TravelPlannerPage> { | |
void _sendPrompt(String text) { | ||
if (_isThinking || text.trim().isEmpty) return; | ||
setState(() { | ||
_conversation.add(UserMessage.text(text)); | ||
_conversation.add(UserTurn(text)); | ||
}); | ||
_scrollToBottom(); | ||
_textController.clear(); | ||
|
@@ -386,6 +368,7 @@ to the user. | |
3. Create an initial itinerary, which will be iterated over in subsequent | ||
steps. This involves planning out each day of the trip, including the | ||
specific locations and draft activities. For shorter trips where the | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
customer is just staying in one location, this may just involve choosing | ||
activities, while for longer trips this likely involves choosing which | ||
specific places to stay in and how many nights in each place. | ||
|
@@ -447,10 +430,10 @@ because it avoids confusing the conversation with many versions of the same | |
itinerary etc. | ||
|
||
When processing a user message or event, you should add or update one surface | ||
and then call provideFinalOutput to return control to the user. Never continue | ||
to add or update surfaces until you receive another user event. If the last | ||
entry in the context is a functionResponse, just call provideFinalOutput | ||
immediately - don't try to update the UI. | ||
and then output an explanatory message to return control to the user. Never | ||
continue to add or update surfaces until you receive another user event. | ||
If the last entry in the context is a functionResponse from addOrUpdateSurface, | ||
*do not* call addOrUpdateSurface again - just return. | ||
|
||
# UI style | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
// Copyright 2025 The Flutter Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
import 'dart:convert'; | ||
|
||
import 'package:firebase_ai/firebase_ai.dart' as fai; | ||
import 'package:flutter_genui/flutter_genui.dart'; | ||
import 'package:flutter_genui/src/ai_client/gemini_schema_adapter.dart'; | ||
import 'package:logging/logging.dart'; | ||
|
||
class GeminiClient { | ||
GeminiClient({required this.tools, required String systemInstruction}) { | ||
final functionDeclarations = <fai.FunctionDeclaration>[]; | ||
final adapter = GeminiSchemaAdapter(); | ||
for (final tool in tools) { | ||
fai.Schema? adaptedParameters; | ||
if (tool.parameters != null) { | ||
final result = adapter.adapt(tool.parameters!); | ||
if (result.errors.isNotEmpty) { | ||
_logger.warning( | ||
'Errors adapting parameters for tool ${tool.name}: ' | ||
'${result.errors.join('\n')}', | ||
); | ||
} | ||
adaptedParameters = result.schema; | ||
} | ||
final parameters = adaptedParameters?.properties; | ||
functionDeclarations.add( | ||
fai.FunctionDeclaration( | ||
tool.name, | ||
tool.description, | ||
parameters: parameters ?? const {}, | ||
), | ||
); | ||
} | ||
|
||
_logger.info( | ||
'Registered tools: ${functionDeclarations.map((d) => d.toJson()).join(', ')}', | ||
); | ||
|
||
_model = fai.FirebaseAI.googleAI().generativeModel( | ||
model: 'gemini-2.5-flash', | ||
systemInstruction: fai.Content.system(systemInstruction), | ||
tools: [fai.Tool.functionDeclarations(functionDeclarations)], | ||
); | ||
} | ||
|
||
late final fai.GenerativeModel _model; | ||
final List<AiTool> tools; | ||
final _logger = Logger('GeminiClient'); | ||
|
||
Future<fai.GenerateContentResponse> generate( | ||
Iterable<fai.Content> history, | ||
) async { | ||
final mutableHistory = List.of(history); | ||
var toolUsageCycle = 0; | ||
const maxToolUsageCycles = 10; | ||
|
||
while (toolUsageCycle < maxToolUsageCycles) { | ||
toolUsageCycle++; | ||
|
||
final concatenatedContents = mutableHistory | ||
.map((c) => const JsonEncoder.withIndent(' ').convert(c.toJson())) | ||
.join('\n'); | ||
|
||
_logger.info( | ||
'****** Performing Inference ******\n$concatenatedContents\n' | ||
'With functions:\n' | ||
' ${tools.map((t) => t.name).join(', ')}', | ||
); | ||
|
||
final inferenceStartTime = DateTime.now(); | ||
final response = await _model.generateContent(mutableHistory); | ||
final elapsed = DateTime.now().difference(inferenceStartTime); | ||
|
||
final candidate = response.candidates.first; | ||
final content = candidate.content; | ||
mutableHistory.add(content); | ||
|
||
_logger.info( | ||
'****** Completed Inference ******\n' | ||
'Latency = ${elapsed.inMilliseconds}ms\n' | ||
'Output tokens = ${response.usageMetadata?.candidatesTokenCount ?? 0}\n' | ||
'Prompt tokens = ${response.usageMetadata?.promptTokenCount ?? 0}\n' | ||
'${const JsonEncoder.withIndent(' ').convert(content.toJson())}', | ||
); | ||
|
||
final functionCalls = content.parts | ||
.whereType<fai.FunctionCall>() | ||
.toList(); | ||
|
||
if (functionCalls.isEmpty) { | ||
return response; | ||
} | ||
|
||
final functionResponses = <fai.FunctionResponse>[]; | ||
for (final call in functionCalls) { | ||
final tool = tools.firstWhere((t) => t.name == call.name); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using A safer approach would be to use a method that doesn't throw, like |
||
final result = await tool.invoke(call.args); | ||
functionResponses.add(fai.FunctionResponse(call.name, result)); | ||
} | ||
|
||
mutableHistory.add(fai.Content.functionResponses(functionResponses)); | ||
} | ||
|
||
throw Exception('Max tool usage cycles reached'); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
_configureLogging
function usesprint
to output log records. The// ignore: avoid_print
indicates you're aware of the lint. For better practice, consider usinglog
fromdart:developer
(you'll need to import it as e.g. 'developer'). It's automatically stripped from release builds and integrates well with developer tools.