diff --git a/firebaseai/src/TemplateChatSession.cs b/firebaseai/src/TemplateChatSession.cs new file mode 100644 index 00000000..42e831ae --- /dev/null +++ b/firebaseai/src/TemplateChatSession.cs @@ -0,0 +1,275 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Firebase.AI.Internal; + +namespace Firebase.AI +{ + /// + /// An object that represents a back-and-forth chat with a template model, capturing the history + /// and saving the context in memory between each message sent. + /// + public class TemplateChatSession + { + private readonly TemplateGenerativeModel _generativeModel; + private readonly string _templateId; + private readonly IDictionary _inputs; + private readonly List _chatHistory; + private readonly List _tools; + private readonly TemplateToolConfig _toolConfig; + private readonly Dictionary _autoFunctions; + private readonly int _maxTurns; + + // Use a SemaphoreSlim as a mutex lock. + private readonly SemaphoreSlim _mutex = new SemaphoreSlim(1, 1); + + /// + /// The previous content from the chat that has been successfully sent and received from the + /// model. This will be provided to the model for each message sent as context for the discussion. + /// + public IReadOnlyList History => _chatHistory; + + // Note: No public constructor, get one through TemplateGenerativeModel.StartChat + private TemplateChatSession( + TemplateGenerativeModel model, + string templateId, + IDictionary inputs, + IEnumerable initialHistory, + IEnumerable tools, + TemplateToolConfig toolConfig, + int maxTurns) + { + _generativeModel = model; + _templateId = templateId; + _inputs = inputs ?? new Dictionary(); + _tools = tools?.ToList() ?? new List(); + _toolConfig = toolConfig; + _maxTurns = maxTurns; + + if (initialHistory != null) + { + _chatHistory = new List(initialHistory); + } + else + { + _chatHistory = new List(); + } + + _autoFunctions = new Dictionary(); + if (tools != null) + { + foreach (var tool in tools) + { + foreach (var function in tool.GetAutoFunctionDeclarations()) + { + _autoFunctions[function.Name] = function; + } + } + } + } + + /// + /// Intended for internal use only. + /// Use `TemplateGenerativeModel.StartChat` instead to ensure proper initialization. + /// + internal static TemplateChatSession InternalCreateChat( + TemplateGenerativeModel model, + string templateId, + IDictionary inputs, + IEnumerable initialHistory, + IEnumerable tools, + TemplateToolConfig toolConfig, + int maxTurns) + { + return new TemplateChatSession(model, templateId, inputs, initialHistory, tools, toolConfig, maxTurns); + } + + /// + /// Sends a message using the existing history of this chat as context. + /// + /// The input given to the model as a prompt. + /// An optional token to cancel the operation. + /// The model's response if no error occurred. + public Task SendMessageAsync( + ModelContent content, CancellationToken cancellationToken = default) + { + return SendMessageAsyncInternal(content, cancellationToken); + } + + /// + /// Sends a message using the existing history of this chat as context. + /// + /// The text given to the model as a prompt. + /// An optional token to cancel the operation. + /// The model's response if no error occurred. + public Task SendMessageAsync( + string text, CancellationToken cancellationToken = default) + { + return SendMessageAsync(ModelContent.Text(text), cancellationToken); + } + + /// + /// Sends a message using the existing history of this chat as context. + /// + /// The input given to the model as a prompt. + /// An optional token to cancel the operation. + /// A stream of generated content responses from the model. + public IAsyncEnumerable SendMessageStreamAsync( + ModelContent content, CancellationToken cancellationToken = default) + { + return SendMessageStreamAsyncInternal(content, cancellationToken); + } + + /// + /// Sends a message using the existing history of this chat as context. + /// + /// The text given to the model as a prompt. + /// An optional token to cancel the operation. + /// A stream of generated content responses from the model. + public IAsyncEnumerable SendMessageStreamAsync( + string text, CancellationToken cancellationToken = default) + { + return SendMessageStreamAsync(ModelContent.Text(text), cancellationToken); + } + + private async Task SendMessageAsyncInternal( + ModelContent requestContent, CancellationToken cancellationToken = default) + { + await _mutex.WaitAsync(cancellationToken); + try + { + ModelContent message = FirebaseAIExtensions.ConvertToUser(requestContent); + List requestHistory = new List { message }; + + int turn = 0; + while (turn < _maxTurns) + { + List fullRequest = new List(_chatHistory); + fullRequest.AddRange(requestHistory); + + var response = await _generativeModel.GenerateContentAsyncInternal( + _templateId, _inputs, fullRequest, _tools, _toolConfig, cancellationToken); + + if (!response.Candidates.Any()) + { + return response; + } + + var candidate = response.Candidates.First(); + var parts = candidate.Content.Parts; + var functionCalls = parts.OfType().ToList(); + + bool shouldAutoExecute = _autoFunctions.Count > 0 && functionCalls.Count > 0 && + functionCalls.All(c => _autoFunctions.ContainsKey(c.Name)); + + if (!shouldAutoExecute) + { + _chatHistory.Add(message); + _chatHistory.Add(candidate.Content.ConvertToModel()); + return response; + } + + // Auto function execution + requestHistory.Add(candidate.Content.ConvertToModel()); + var functionResponses = new List(); + + foreach (var call in functionCalls) + { + var function = _autoFunctions[call.Name]; + object result; + try + { + if (function.Callable != null) + { + result = await function.Callable(call.Args.ToDictionary(k => k.Key, k => k.Value)); + } + else + { + result = null; + } + } + catch (Exception ex) + { + result = ex.Message; + } + + // Wrap the result in a {"result": ...} dictionary as expected by the Dart implementation + functionResponses.Add(new ModelContent.FunctionResponsePart( + call.Name, + new Dictionary { { "result", result } }, + call.Id)); + } + + requestHistory.Add(new ModelContent("function", functionResponses)); + turn++; + } + + throw new InvalidOperationException($"Max turns of {_maxTurns} reached."); + } + finally + { + _mutex.Release(); + } + } + + private async IAsyncEnumerable SendMessageStreamAsyncInternal( + ModelContent requestContent, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await _mutex.WaitAsync(cancellationToken); + try + { + ModelContent message = FirebaseAIExtensions.ConvertToUser(requestContent); + List fullRequest = new List(_chatHistory) { message }; + + List responseContents = new List(); + bool saveHistory = true; + + await foreach (GenerateContentResponse response in + _generativeModel.GenerateContentStreamAsyncInternal(_templateId, _inputs, fullRequest, _tools, _toolConfig, cancellationToken)) + { + if (response.Candidates.Any()) + { + ModelContent responseContent = response.Candidates.First().Content; + responseContents.Add(responseContent.ConvertToModel()); + } + else + { + saveHistory = false; + } + + yield return response; + } + + if (saveHistory && responseContents.Count > 0) + { + _chatHistory.Add(message); + _chatHistory.AddRange(responseContents); + } + } + finally + { + _mutex.Release(); + } + } + } +} diff --git a/firebaseai/src/TemplateGenerativeModel.cs b/firebaseai/src/TemplateGenerativeModel.cs index 0eec10a2..44925e61 100644 --- a/firebaseai/src/TemplateGenerativeModel.cs +++ b/firebaseai/src/TemplateGenerativeModel.cs @@ -89,7 +89,9 @@ public IAsyncEnumerable GenerateContentStreamAsync( } private string MakeGenerateContentRequest(IDictionary inputs, - IEnumerable chatHistory) + IEnumerable chatHistory, + IEnumerable tools, + TemplateToolConfig toolConfig) { var jsonDict = new Dictionary() { @@ -99,12 +101,22 @@ private string MakeGenerateContentRequest(IDictionary inputs, { jsonDict["history"] = chatHistory.Select(t => t.ToJson()).ToList(); } + if (tools != null && tools.Any()) + { + jsonDict["tools"] = tools.Select(t => t.ToJson()).ToList(); + } + if (toolConfig != null) + { + jsonDict["toolConfig"] = toolConfig.ToJson(); + } return Json.Serialize(jsonDict); } - private async Task GenerateContentAsyncInternal( + internal async Task GenerateContentAsyncInternal( string templateId, IDictionary inputs, IEnumerable chatHistory, + IEnumerable tools, + TemplateToolConfig toolConfig, CancellationToken cancellationToken) { HttpRequestMessage request = new(HttpMethod.Post, @@ -114,7 +126,7 @@ private async Task GenerateContentAsyncInternal( await Firebase.Internal.HttpHelpers.SetRequestHeaders(request, _firebaseApp); // Set the content - string bodyJson = MakeGenerateContentRequest(inputs, chatHistory); + string bodyJson = MakeGenerateContentRequest(inputs, chatHistory, tools, toolConfig); request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json"); #if FIREBASE_LOG_REST_CALLS @@ -133,9 +145,11 @@ private async Task GenerateContentAsyncInternal( return GenerateContentResponse.FromJson(result, _backend.Provider); } - private async IAsyncEnumerable GenerateContentStreamAsyncInternal( + internal async IAsyncEnumerable GenerateContentStreamAsyncInternal( string templateId, IDictionary inputs, IEnumerable chatHistory, + IEnumerable tools, + TemplateToolConfig toolConfig, [EnumeratorCancellation] CancellationToken cancellationToken) { HttpRequestMessage request = new(HttpMethod.Post, @@ -145,7 +159,7 @@ private async IAsyncEnumerable GenerateContentStreamAsy await Firebase.Internal.HttpHelpers.SetRequestHeaders(request, _firebaseApp); // Set the content - string bodyJson = MakeGenerateContentRequest(inputs, chatHistory); + string bodyJson = MakeGenerateContentRequest(inputs, chatHistory, tools, toolConfig); request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json"); #if FIREBASE_LOG_REST_CALLS @@ -173,5 +187,26 @@ private async IAsyncEnumerable GenerateContentStreamAsy } } } + + /// + /// Starts a that uses this template generative model to respond to messages. + /// + /// The id of the server prompt template to use. + /// Any input parameters expected by the server prompt template. + /// Optional chat history. + /// Optional tools (e.g., auto functions) to use. + /// Optional tool configuration. + /// Maximum number of interactions for auto functions to execute. + /// A new instance. + public TemplateChatSession StartChat( + string templateId, + IDictionary inputs = null, + IEnumerable history = null, + IEnumerable tools = null, + TemplateToolConfig toolConfig = null, + int maxTurns = 5) + { + return TemplateChatSession.InternalCreateChat(this, templateId, inputs, history, tools, toolConfig, maxTurns); + } } } diff --git a/firebaseai/src/TemplateTool.cs b/firebaseai/src/TemplateTool.cs new file mode 100644 index 00000000..c26e52a8 --- /dev/null +++ b/firebaseai/src/TemplateTool.cs @@ -0,0 +1,167 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace Firebase.AI +{ + /// + /// A class representing a generic function declaration used with a template model. + /// + public class TemplateFunctionDeclaration + { + /// + /// The name of the function. + /// + public string Name { get; } + + private readonly Schema _schemaObject; + + /// + /// Constructs a TemplateFunctionDeclaration. + /// + /// The name of the function. + /// Optional dictionary of parameters schema. + /// Optional list of parameter names that are not required. + public TemplateFunctionDeclaration( + string name, + IDictionary parameters = null, + IEnumerable optionalParameters = null) + { + Name = name; + _schemaObject = parameters != null ? Schema.Object(parameters, optionalParameters) : null; + } + + /// + /// Intended for internal use only. + /// This method is used for serializing the object to JSON for the API request. + /// + internal Dictionary ToJson() + { + return new Dictionary() + { + { "name", Name }, + { "input_schema", _schemaObject != null ? _schemaObject.ToJson() : "" } + }; + } + } + + /// + /// A class representing an auto-function declaration that provides an execution callable for template chat. + /// + public class TemplateAutoFunctionDeclaration : TemplateFunctionDeclaration + { + /// + /// The callable function that this declaration represents. + /// + public Func, Task>> Callable { get; } + + /// + /// Constructs a TemplateAutoFunctionDeclaration. + /// + /// The name of the function. + /// The function to execute when requested by the model. + /// Optional dictionary of parameters schema. + /// Optional list of parameter names that are not required. + public TemplateAutoFunctionDeclaration( + string name, + Func, Task>> callable, + IDictionary parameters = null, + IEnumerable optionalParameters = null) + : base(name, parameters, optionalParameters) + { + Callable = callable; + } + } + + /// + /// Describes a set of template tools that can be passed to a . + /// + public readonly struct TemplateTool + { + private readonly List _functionDeclarations; + + /// + /// Creates a TemplateTool containing a collection of TemplateFunctionDeclarations. + /// + public static TemplateTool FunctionDeclarations(IEnumerable functionDeclarations) + { + return new TemplateTool(functionDeclarations); + } + + /// + /// Creates a TemplateTool containing a collection of TemplateFunctionDeclarations. + /// + public static TemplateTool FunctionDeclarations(params TemplateFunctionDeclaration[] functionDeclarations) + { + return new TemplateTool(functionDeclarations); + } + + private TemplateTool(IEnumerable functionDeclarations) + { + _functionDeclarations = functionDeclarations?.ToList() ?? new List(); + } + + /// + /// Intended for internal use only. + /// Returns the subset of TemplateFunctionDeclarations that are TemplateAutoFunctionDeclarations. + /// + internal IEnumerable GetAutoFunctionDeclarations() + { + if (_functionDeclarations == null) return Enumerable.Empty(); + return _functionDeclarations.OfType(); + } + + /// + /// Intended for internal use only. + /// This method is used for serializing the object to JSON for the API request. + /// + internal Dictionary ToJson() + { + var json = new Dictionary(); + if (_functionDeclarations != null && _functionDeclarations.Any()) + { + json["functionDeclarations"] = _functionDeclarations.Select(f => f.ToJson()).ToList(); + } + return json; + } + } + + /// + /// Tool configuration for any specified in the request. + /// + public class TemplateToolConfig + { + /// + /// Constructs a new . + /// + public TemplateToolConfig() + { + } + + /// + /// Intended for internal use only. + /// This method is used for serializing the object to JSON for the API request. + /// + internal Dictionary ToJson() + { + return new Dictionary(); + } + } +}