diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 3ff08d642f3..6bcdb432440 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -66,3 +66,4 @@ /generative-ai/open-models/serving/cloud_run_ollama_qwen3_inference.ipynb @GoogleCloudPlatform/generative-ai-devrel @vladkol /generative-ai/open-models/get_started_with_model_garden_sdk.ipynb @GoogleCloudPlatform/generative-ai-devrel @inardini @lizzij /generative-ai/open-models/use-cases/model_garden_litellm_inference.ipynb @GoogleCloudPlatform/generative-ai-devrel @lizzij +/generative-ai/llmevalkit @GoogleCloudPlatform/generative-ai-devrel @santoromike @lkatherine diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 6787182e4b4..8ed1fa4fed0 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -170,6 +170,7 @@ CASP catus caudatus caxis +cbp cce cctv cdiv @@ -236,6 +237,7 @@ COUNTIF countplot covs Cowabunga +cpb cpet crewai CRF @@ -799,9 +801,11 @@ mapbox marp Masaru maskmode +mathvista mavenrc maxcold Mbappe +mbp mbsdk mcp MCP @@ -856,6 +860,7 @@ Mookie morty Mosi moviepy +mpb mpe mpegps mpga @@ -1332,6 +1337,7 @@ THT THTH Tianli Tianxiang +ticktext tickvals tiktoken timechart @@ -1427,6 +1433,7 @@ vectoral vectordb veo Verilog +versioned Ves Vesia vesselin diff --git a/.github/actions/spelling/line_forbidden.patterns b/.github/actions/spelling/line_forbidden.patterns index e4934eba56f..cf3ec0bdb9c 100644 --- a/.github/actions/spelling/line_forbidden.patterns +++ b/.github/actions/spelling/line_forbidden.patterns @@ -308,7 +308,7 @@ \w \w # Don't use "smart quotes" -(?!'")[‘’“”] +(? to see your current agreements or to +sign a new one. + +### Review our community guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google/conduct/). + +## Contribution process + +### Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. diff --git a/tools/llmevalkit/LICENSE b/tools/llmevalkit/LICENSE new file mode 100644 index 00000000000..7a4a3ea2424 --- /dev/null +++ b/tools/llmevalkit/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/tools/llmevalkit/README.md b/tools/llmevalkit/README.md new file mode 100644 index 00000000000..573e173a39f --- /dev/null +++ b/tools/llmevalkit/README.md @@ -0,0 +1,214 @@ +# LLM EvalKit + +## Summary + +LLMEvalKit is a tool designed to help developers evaluate and improve the performance of Large Language Models (LLMs) on specific tasks. It provides a comprehensive workflow to create, test, and optimize prompts, manage datasets, and analyze evaluation results. With LLMEvalKit, developers can conduct both human and model-based evaluations, compare results, and use automated processes to refine prompts for better accuracy and relevance. This toolkit streamlines the iterative process of prompt engineering and evaluation, enabling developers to build more effective and reliable LLM-powered applications. 
+ + + +**Authors: [Mike Santoro](https://github.com/Michael-Santoro), [Katherine Larson](https://github.com/larsonk)** + +## 🚀 Getting Started + +There are two ways to work through a tutorial of this application one method is more stable one is less stable. + +1. Scroll down to the Tutorial Section here. + +2. Open this [notebook](https://github.com/GoogleCloudPlatform/generative-ai/blob/main/tools/llmevalkit/prompt-management-tutorial.ipynb) in colab running the application on a colab server. + +## Overview + +This tutorial provides a comprehensive guide to prompt engineering, covering the entire lifecycle from creation to evaluation and optimization. It's broken down into the following sections: + +1. **Prompt Management:** This section focuses on the core tasks of creating, editing, and managing prompts. You can: + - **Create new prompts:** Define the prompt's name, text, the model it's designed for, and any system instructions. + - **Load and edit existing prompts:** Browse a library of saved prompts, load a specific version, and make modifications. + - **Test prompts:** Before saving, you can provide sample input and generate a response to see how the prompt performs. + - **Versioning:** Each time you save a change to a prompt, a new version is created, allowing you to track its evolution and compare different iterations. + +2. **Dataset Creation:** A crucial part of prompt engineering is having good data to test and evaluate your prompts. This section allows you to: + + - **Create new datasets:** A dataset is essentially a folder in Google Cloud Storage where you can group related files. + - **Upload data:** You can upload files in CSV, JSON, or JSONL format to your datasets. This data will be used for evaluating your prompts. + +3. **Evaluation:** Once you have a prompt and a dataset, you need to see how well the prompt performs. The evaluation section helps you with this by: + + - **Running evaluations:** You can select a prompt and a dataset and run an evaluation. This will generate responses from the model for each item in your dataset. + - **Human-in-the-loop rating:** For a more nuanced evaluation, you can manually review the model's responses and rate them. + - **Automated metrics:** The tutorial also supports automated evaluation metrics to get a quantitative measure of your prompt's performance. + +4. **Prompt Optimization:** This section helps you automatically improve your prompts. It uses Vertex AI's prompt optimization capabilities to: + + - **Configure and launch optimization jobs:** You can set up and run a job that will take your prompt and a dataset and try to find a better-performing version of the prompt. + +5. **Prompt Optimization Results:** After an optimization job has run, this section allows you to: + + - **View the results:** You can see the different prompt versions that the optimizer came up with and how they performed. + - **Compare versions:** The results are presented in a way that makes it easy to compare the different optimized prompts and choose the best one. + +6. **Prompt Records:** This is a leaderboard that shows you the evaluation results of all your different prompt versions. It helps you to: + + - **Track performance over time:** See how your prompts have improved with each new version. + - **Compare different prompts:** You can compare the performance of different prompts for the same task. 
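+
+The Prompt Management workflow described above is built on the Vertex AI SDK's prompt management service. As a point of reference, the sketch below shows roughly what saving and versioning a prompt looks like when done directly with that SDK; it is a minimal, illustrative example (the project ID, prompt text, and model name are placeholders), not the toolkit's exact implementation.
+
+```python
+# Minimal sketch of the create -> save -> version cycle that the
+# Prompt Management page wraps. Assumes the vertexai.preview.prompts
+# module used elsewhere in this toolkit; all values are placeholders.
+import vertexai
+from vertexai.preview import prompts
+from vertexai.preview.prompts import Prompt
+
+vertexai.init(project="your-project-id", location="us-central1")
+
+# Define a prompt with a template variable and a system instruction.
+prompt = Prompt(
+    prompt_name="customer_sentiment_classifier_v1",
+    prompt_data="Classify the sentiment of the following text: {customer_review}",
+    model_name="gemini-2.0-flash-001",
+    system_instruction="You are an expert in sentiment analysis.",
+)
+
+# Saving creates version 1; saving again after an edit creates version 2.
+saved = prompts.create_version(prompt=prompt)
+print(saved.prompt_id, saved.version_id)
+
+# Load the prompt back and inspect its versions.
+loaded = prompts.get(prompt_id=saved.prompt_id)
+print(loaded.prompt_data)
+for version in prompts.list_versions(saved.prompt_id):
+    print(version.version_id)
+```
+
+Each saved version is what the Evaluation and Prompt Records sections later compare, so it is worth saving a new version whenever you make a meaningful change.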
+ +In summary, this tutorial provides a complete and integrated environment for all your prompt engineering needs, from initial creation to sophisticated optimization and evaluation. + +## Tutorial: Step-by-Step + +This section walks you through using the app. + +### 0. Startup + +First, clone the repository and set up the environment: + +```bash +# Clone the repository +git clone https://github.com/GoogleCloudPlatform/generative-ai.git + +# Navigate to the project directory +cd generative-ai/tools/llmevalkit + +# Create a Python virtual environment +python -m venv venv + +# Activate the virtual environment +source venv/bin/activate + +# Install the required packages +pip install -r requirements.txt + +# Run the Streamlit application +streamlit run index.py +``` + +Next, `cp src/.env.example src/.env` open the file and set `BUCKET_NAME` and `PROJECT_ID` + +### 1. Prompt Management + +In the Prompt Name field enter: + +``` +math_prompt_test +``` + +In the Prompt Data field enter: + +``` +Problem: {{query}} +Image: {{image}} @@@image/jpeg +Answer: {{target}} +``` + +In the Model Name field enter: +``` +gemini-2.0-flash-001 +``` + +In the System Instructions field enter: +``` +Solve the problem given the image. +``` + +Click `Save` + +Copy this text for testing: + +``` +{"query": "Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\nQuestion: As shown in the figure, CD is the diameter of \u2299O, chord DE \u2225 OA, if the degree of \u2220D is 50.0, then the degree of \u2220C is ()", "Choices":\n(A) 25\u00b0\n(B) 30\u00b0\n(C) 40\u00b0\n(D) 50\u00b0", "image": "gs://github-repo/prompts/prompt_optimizer/mathvista_dataset/images/643.jpg", "target": "25\u00b0"} +``` + +🖱️ Click `Generate`. + +### 2. Dataset Creation + +Download a copy of the dataset. Then upload this file in the application. + +**Dataset Name:** `mathvista` + +You can preview the dataset at the bottom of the page. + +To download the dataset, run this command: +```bash +gsutil cp gs://github-repo/prompts/prompt_optimizer/mathvista_dataset/mathvista_input.jsonl . +``` + +### 3. Evaluation + +We will now run an evaluation, prior to doing any tweaking to get a baseline. + +- **Existing Dataset:** 'mathvista' +- **Dataset File:** 'mathvista_input.jsonl' +- **Number of Samples:** '100' +- **Ground Truth Column Name:** 'target' +- **Existing Prompt:** 'math_prompt_test' +- **Version:** '1' + +Click Load Prompt, and Upload and Get Response... ⏰ Wait!! + +Review the responses. + +- **Model-Based:** 'question-answering-quality' + +Launch the Eval... ⏰ Wait!! + +View the Evaluation Results, and save to prompt records. This will save this initial version to the prompt records for the baseline. + +### 4. Prompt Optimization + +🔧 Set-Up Prompt Optimization. + +- **Target Model:** 'gemini-2.0-flash-001' +- **Existing Prompt:** 'math_prompt_test' +- **Version:** '1' + +🖱️ Click Load Prompt. + +- **Select Existing Dataset:** 'mathvista' +- **Select the File:** 'mathvista_input.jsonl' + +🖱️ Click Load Dataset. + +Preview the dataset. + +🖱️ Click Start Optimization. + +**Note:** If Interested in viewing the progress, Navigate to https://console.cloud.google.com/vertex-ai/training/custom-jobs + +⏰ Wait!! This step will take about 20-min to run. + +### 5. Prompt Optimization Results + +View the Optimization Results. + +The last run will be shown at the top of the screen. Pick this from the dropdown menu: + + + +Review the results and select the highest scoring version and copy the instruction. 
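+
+Steps 3 and 7 run the evaluation for you from the UI. For reference, the sketch below shows roughly what launching a model-based eval does behind the scenes, using the same `vertexai.evaluation` API the Evaluation page is built on. It is a simplified, illustrative example: the dataframe contents, column names, and experiment name are placeholders, and it assumes the model responses have already been generated.
+
+```python
+# Rough sketch of a pointwise, model-based evaluation like the
+# question-answering-quality run in step 3. Values are illustrative.
+import pandas as pd
+import vertexai
+from vertexai.evaluation import EvalTask, MetricPromptTemplateExamples
+
+vertexai.init(project="your-project-id", location="us-central1")
+
+# One row per sample: the prompt sent to the model, the model's response,
+# and the ground-truth answer (the 'target' column in the mathvista file).
+eval_df = pd.DataFrame(
+    {
+        "prompt": ["Solve the problem given the image. Problem: ..."],
+        "response": ["The answer is (A) 25°."],
+        "reference": ["25°"],
+    }
+)
+
+eval_task = EvalTask(
+    dataset=eval_df,
+    metrics=[MetricPromptTemplateExamples.Pointwise.QUESTION_ANSWERING_QUALITY],
+    experiment="llmevalkit-tutorial",  # placeholder experiment name
+)
+
+result = eval_task.evaluate()
+print(result.summary_metrics)       # aggregate scores, as saved to prompt records
+print(result.metrics_table.head())  # per-row scores and explanations
+```
+
+However the scores are produced, the results you save to prompt records in steps 3 and 7 are what the leaderboard in step 8 compares.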
+ +### 6. Navigate Back to Prompt for New Version + +Load your existing prompt from before. + +📋 Paste your new instructions from the prompt optimizer, and save new version. + +### 7. Run new Evaluation + +Repeat step 3 with your new version. + +### 8. View the Records + +Navigate to the leaderboard and load the results. + +## License +``` +Copyright 2025 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language +``` diff --git a/tools/llmevalkit/assets/favicon.ico b/tools/llmevalkit/assets/favicon.ico new file mode 100644 index 00000000000..7217989a667 Binary files /dev/null and b/tools/llmevalkit/assets/favicon.ico differ diff --git a/tools/llmevalkit/assets/image.gif b/tools/llmevalkit/assets/image.gif new file mode 100644 index 00000000000..9f9117ed3bb Binary files /dev/null and b/tools/llmevalkit/assets/image.gif differ diff --git a/tools/llmevalkit/assets/prompt_optimization_result.png b/tools/llmevalkit/assets/prompt_optimization_result.png new file mode 100644 index 00000000000..a9e6a206a25 Binary files /dev/null and b/tools/llmevalkit/assets/prompt_optimization_result.png differ diff --git a/tools/llmevalkit/assets/welcome_page.png b/tools/llmevalkit/assets/welcome_page.png new file mode 100644 index 00000000000..62696f50389 Binary files /dev/null and b/tools/llmevalkit/assets/welcome_page.png differ diff --git a/tools/llmevalkit/index.py b/tools/llmevalkit/index.py new file mode 100644 index 00000000000..c7c0572bf9b --- /dev/null +++ b/tools/llmevalkit/index.py @@ -0,0 +1,54 @@ +## Copyright 2025 Google LLC +## Licensed under the Apache License, Version 2.0 (the "License"); +## you may not use this file except in compliance with the License. +## You may obtain a copy of the License at +## https://www.apache.org/licenses/LICENSE-2.0 +## Unless required by applicable law or agreed to in writing, software +## distributed under the License is distributed on an "AS IS" BASIS, +## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +## See the License for the specific language + + +"""The main landing page for the LLM EvalKit Streamlit application.""" + +import streamlit as st +from dotenv import load_dotenv + +load_dotenv("src/.env") + + +def main() -> None: + """Renders the main landing page of the application.""" + st.set_page_config( + page_title="LLM EvalKit", + layout="wide", + initial_sidebar_state="expanded", + page_icon="assets/favicon.ico", + ) + + st.title("Welcome to the LLM EvalKit") + st.markdown( + "A suite of tools for managing, evaluating, and optimizing LLM prompts and datasets." + ) + + st.subheader("Getting Started") + st.markdown( + """ + This application helps you streamline your prompt engineering workflow. + Select a tool from the sidebar on the left to begin. + + **Available Tools:** + * **Prompt Management:** Create, test, and manage your prompts. + * **Dataset Creation:** Create evaluation datasets from CSV files. + * **Simple Evaluation:** Run simple evaluations on your prompts. + * **Evaluation Human Judge:** Manually rate model responses for evaluation. 
+ * **Prompt Optimization:** Optimize your prompts for better performance. + * **Prompt Optimization Results:** View the results of prompt optimization runs. + * **Prompt Records:** View and manage your prompt records. + """ + ) + st.caption("LLM EvalKit | Home") + + +if __name__ == "__main__": + main() diff --git a/tools/llmevalkit/pages/1_Prompt_Management.py b/tools/llmevalkit/pages/1_Prompt_Management.py new file mode 100644 index 00000000000..2aa269ccd5c --- /dev/null +++ b/tools/llmevalkit/pages/1_Prompt_Management.py @@ -0,0 +1,590 @@ +# Copyright 2025 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language + +"""Streamlit user interface for managing prompts in the LLM EvalKit. + +This page provides a comprehensive interface for prompt engineering, allowing users +to create, load, edit, and test prompts that are stored and versioned in a +backend service (e.g., Google Cloud's Vertex AI Prompt Management). + +The page is divided into two main sections: +1. **Create New Prompt**: A form to define a new prompt from scratch, including + its name, text, model, system instructions, and other metadata. Users can + test the prompt with sample input before saving it. +2. **Load & Edit Prompt**: A section to load existing prompts and their specific + versions. Users can modify the loaded prompt's details and save the changes + as a new version, facilitating iterative development and A/B testing. + +Helper functions handle JSON parsing, data type conversions, and interactions +with the `gcp_prompt` object, which abstracts the backend communication. +""" + +import json +import logging +from typing import Any + +import streamlit as st +from dotenv import load_dotenv +from src.gcp_prompt import GcpPrompt as gcp_prompt +from vertexai.preview import prompts + +# --- Initial Configuration --- +load_dotenv("src/.env") +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +# --- Constants --- +AVAILABLE_PROMPT_TASKS = [ + "Classification", + "Summarization", + "Translation", + "Creative Writing", + "Q&A", +] + + +# --- Helper Functions --- +def _parse_json_input(json_string: str, field_name: str) -> dict[str, Any] | None: + """Safely parses a JSON string from a text area. + + Cleans the input string to handle common copy-paste errors and displays + an error in the Streamlit UI if parsing fails. + + Args: + json_string: The raw string from a Streamlit text_area. + field_name: The user-facing name of the field for error messages. + + Returns: + A dictionary if parsing is successful, otherwise None. 
+ """ + if not json_string: + return None + try: + # Clean up common copy-paste issues like smart quotes and newlines + json_string_cleaned = ( + json_string.replace("’", "'") + .replace("\n", " ") + .replace("\t", " ") + .replace("\r", "") + ) + return json.loads(json_string_cleaned) + except json.JSONDecodeError as e: + st.error(f"Invalid JSON format for {field_name}: {e}") + return None + + +def _apply_generation_config_typing(config: dict[str, Any]) -> dict[str, Any]: + """Applies correct data types to generation config parameters. + + Streamlit text inputs return strings, but the underlying API requires + specific types (e.g., float for temperature). This function converts + common configuration values to their expected types. + + Args: + config: The generation configuration dictionary with string values. + + Returns: + The configuration dictionary with values cast to the correct types. + """ + if "temperature" in config: + config["temperature"] = float(config["temperature"]) + if "top_p" in config: + config["top_p"] = float(config["top_p"]) + if "max_output_tokens" in config: + config["max_output_tokens"] = int(config["max_output_tokens"]) + return config + + +# --- Handlers for "Create New Prompt" Tab --- +def _handle_save_new_prompt() -> None: + """Validates inputs and saves a new prompt. + + Retrieves all necessary data from the Streamlit session state for the + "Create New Prompt" tab, validates that required fields are filled, + constructs the prompt object, and calls the backend service to save it. + Displays success or error messages in the UI. + """ + required_fields = { + "new_prompt_name": "Prompt Name", + "new_prompt_data": "Prompt Text", + "new_model_name": "Model Name", + "new_system_instructions": "System Instructions", + } + for key, name in required_fields.items(): + if not st.session_state.get(key): + st.warning(f"Please enter a value for {name}.") + return + + prompt_obj = st.session_state.local_prompt + prompt_obj.prompt_to_run.prompt_name = st.session_state.new_prompt_name + prompt_obj.prompt_to_run.prompt_data = st.session_state.new_prompt_data + prompt_obj.prompt_to_run.model_name = st.session_state.new_model_name.strip() + prompt_obj.prompt_to_run.system_instruction = ( + st.session_state.new_system_instructions + ) + + response_schema = _parse_json_input( + st.session_state.new_response_schema, "Response Schema" + ) + generation_config = _parse_json_input( + st.session_state.new_generation_config, "Generation Config" + ) + + if generation_config: + generation_config = _apply_generation_config_typing(generation_config) + if response_schema: + generation_config["response_schema"] = response_schema + prompt_obj.prompt_meta["generation_config"] = generation_config + + if response_schema: + prompt_obj.prompt_meta["response_schema"] = response_schema + prompt_obj.prompt_meta["meta_tags"] = st.session_state.new_meta_tags + + try: + logger.info("Saving new prompt...") + prompt_meta_info = prompt_obj.save_prompt(check_existing=True) + logger.info("Prompt saved successfully: %s", prompt_meta_info) + st.success("Prompt saved successfully!") + except Exception as e: + logger.error("Failed to save prompt: %s", e, exc_info=True) + st.error(f"Failed to save prompt: {e}") + + +def _handle_generate_test_for_new() -> None: + """Generates a test response for the new prompt form. + + Takes the user-provided sample input and the current prompt configuration + from the "Create" tab, sends it to the model for a response, and displays + the output in the UI. 
This allows for quick testing before saving. + """ + user_input_str = st.session_state.new_sample_user_input + if not user_input_str: + st.warning("Please provide sample user input to generate a response.") + return + + sample_user_input = _parse_json_input(user_input_str, "User Input") + if sample_user_input is None: + return + + try: + prompt_obj = st.session_state.local_prompt + prompt_obj.prompt_to_run.prompt_data = st.session_state.new_prompt_data + prompt_obj.prompt_to_run.model_name = st.session_state.new_model_name.strip() + prompt_obj.prompt_to_run.system_instruction = ( + st.session_state.new_system_instructions + ) + prompt_obj.prompt_meta["sample_user_input"] = sample_user_input + + with st.spinner("Generating response..."): + response = prompt_obj.generate_response(sample_user_input) + st.session_state.new_sample_output = response + st.success("Prompt response generated!") + except Exception as e: + logger.error("Error during test generation: %s", e, exc_info=True) + st.error(f"An error occurred during generation: {e}") + + +# --- Handlers for "Load & Edit Prompt" Tab --- +def _populate_ui_from_prompt() -> None: + """Populates session state for UI widgets from the loaded prompt object. + + After a prompt is loaded from the backend, this function takes the data + from the `gcp_prompt` object and sets the corresponding values in the + Streamlit session state. This updates the "Load & Edit" tab's input + widgets to display the loaded prompt's information. + """ + prompt_obj = st.session_state.local_prompt + st.session_state.edit_prompt_name = prompt_obj.prompt_to_run.prompt_name + st.session_state.edit_prompt_data = prompt_obj.prompt_to_run.prompt_data + st.session_state.edit_model_name = prompt_obj.prompt_to_run.model_name.split("/")[ + -1 + ] + st.session_state.edit_system_instructions = ( + prompt_obj.prompt_to_run.system_instruction + ) + st.session_state.edit_response_schema = json.dumps( + prompt_obj.prompt_meta.get("response_schema", {}), indent=2 + ) + st.session_state.edit_generation_config = json.dumps( + prompt_obj.prompt_meta.get("generation_config", {}), indent=2 + ) + st.session_state.edit_meta_tags = prompt_obj.prompt_meta.get("meta_tags", []) + st.session_state.edit_sample_user_input = json.dumps( + prompt_obj.prompt_meta.get("sample_user_input", {}), indent=2 + ) + st.session_state.edit_sample_output = "" # Clear previous output + + +def _handle_load_prompt() -> None: + """Loads the selected prompt and version and populates the UI. + + Triggered by the 'Load Prompt' button. It retrieves the selected prompt + name and version from the UI, calls the backend to fetch the data, + and then uses `_populate_ui_from_prompt` to display it. 
+ """ + if not st.session_state.get("selected_prompt") or not st.session_state.get( + "selected_version" + ): + st.warning("Please select both a prompt and a version to load.") + return + + prompt_name = st.session_state.selected_prompt + prompt_id = st.session_state.local_prompt.existing_prompts[prompt_name] + version_id = st.session_state.selected_version + + try: + with st.spinner(f"Loading version '{version_id}' of prompt '{prompt_name}'..."): + st.session_state.local_prompt.load_prompt( + prompt_id, prompt_name, version_id + ) + logger.info( + "Successfully loaded prompt '%s' version '%s'.", prompt_name, version_id + ) + _populate_ui_from_prompt() + st.success(f"Loaded prompt '{prompt_name}' (Version: {version_id}).") + except Exception as e: + logger.error("Failed to load prompt: %s", e, exc_info=True) + st.error(f"Failed to load prompt: {e}") + + +def _handle_save_edited_prompt() -> None: + """Validates inputs and saves the current prompt config as a new version. + + Similar to saving a new prompt, but it takes the data from the "Edit" tab's + widgets. It saves the current configuration as a new version of the + already existing prompt. + """ + if not st.session_state.get("edit_prompt_name"): + st.warning("Cannot save. Please load a prompt first.") + return + + required_fields = { + "edit_prompt_data": "Prompt Text", + "edit_model_name": "Model Name", + "edit_system_instructions": "System Instructions", + } + for key, name in required_fields.items(): + if not st.session_state.get(key): + st.warning(f"Please ensure '{name}' is not empty.") + return + + prompt_obj = st.session_state.local_prompt + prompt_obj.prompt_to_run.prompt_name = st.session_state.edit_prompt_name + prompt_obj.prompt_to_run.prompt_data = st.session_state.edit_prompt_data + prompt_obj.prompt_to_run.model_name = st.session_state.edit_model_name.strip() + prompt_obj.prompt_to_run.system_instruction = ( + st.session_state.edit_system_instructions + ) + + response_schema = _parse_json_input( + st.session_state.edit_response_schema, "Response Schema" + ) + generation_config = _parse_json_input( + st.session_state.edit_generation_config, "Generation Config" + ) + + if generation_config: + generation_config = _apply_generation_config_typing(generation_config) + if response_schema: + generation_config["response_schema"] = response_schema + prompt_obj.prompt_meta["generation_config"] = generation_config + + if response_schema: + prompt_obj.prompt_meta["response_schema"] = response_schema + prompt_obj.prompt_meta["meta_tags"] = st.session_state.edit_meta_tags + + try: + with st.spinner("Saving as new version..."): + prompt_meta_info = prompt_obj.save_prompt(check_existing=False) + logger.info("Prompt saved successfully: %s", prompt_meta_info) + st.success("Saved as a new version successfully!") + st.session_state.local_prompt.refresh_prompt_cache() + except Exception as e: + logger.error("Failed to save prompt: %s", e, exc_info=True) + st.error(f"Failed to save prompt: {e}") + + +def _handle_generate_test_for_edit() -> None: + """Generates a test response for the edited prompt. + + Allows users to test changes made in the "Edit" tab before saving them + as a new version. It uses the current values in the UI fields to generate + a response from the model. 
+ """ + if not st.session_state.get("edit_prompt_name"): + st.warning("Please load a prompt before generating a response.") + return + + user_input_str = st.session_state.get("edit_sample_user_input", "") + if not user_input_str: + st.warning("Please provide sample user input to generate a response.") + return + + sample_user_input = _parse_json_input(user_input_str, "Sample User Input") + if sample_user_input is None: + return + + try: + prompt_obj = st.session_state.local_prompt + prompt_obj.prompt_to_run.prompt_data = st.session_state.edit_prompt_data + prompt_obj.prompt_to_run.system_instruction = ( + st.session_state.edit_system_instructions + ) + prompt_obj.prompt_meta["sample_user_input"] = sample_user_input + + with st.spinner("Generating response..."): + response = prompt_obj.generate_response(sample_user_input) + st.session_state.edit_sample_output = response + st.success("Prompt response generated!") + except Exception as e: + logger.error("Error during test generation: %s", e, exc_info=True) + st.error(f"An error occurred during generation: {e}") + + +# --- UI Rendering Functions --- +def render_create_tab() -> None: + """Renders the UI components for the 'Create New Prompt' tab. + + This function defines and lays out all the Streamlit widgets (text inputs, + buttons, etc.) for the prompt creation workflow. + """ + st.subheader("1. Define Prompt Details") + st.text_input( + "**Prompt Name**", + key="new_prompt_name", + placeholder="e.g., customer_sentiment_classifier_v1", + help="A unique name to identify your prompt.", + ) + st.text_area( + "**Prompt Text**", + key="new_prompt_data", + height=150, + placeholder="e.g., Classify the sentiment of the following text: {customer_review}", + help="The core text of your prompt. Use curly braces `{}` for variables.", + ) + st.text_input( + "**Model Name**", + key="new_model_name", + placeholder="gemini-2.5-pro-001", + help="The specific model version to use (e.g., gemini-2.5-pro).", + ) + st.text_area( + "**System Instructions**", + key="new_system_instructions", + height=300, + placeholder="e.g., You are an expert in sentiment analysis...", + help="Optional instructions to guide the model's behavior.", + ) + st.multiselect( + "**Prompt Task**", + options=AVAILABLE_PROMPT_TASKS, + key="new_meta_tags", + help="Select the most appropriate task type for this prompt.", + ) + st.text_area( + "**Response Schema (JSON)**", + key="new_response_schema", + height=150, + placeholder='{\n "type": "object", ... \n}', + help="Define the desired JSON structure for the model's output.", + ) + st.text_area( + "**Generation Config (JSON)**", + key="new_generation_config", + height=150, + placeholder='{\n "temperature": 0.2, ... \n}', + help="A dictionary of generation parameters.", + ) + + if st.button( + "Save Prompt", type="primary", use_container_width=True, key="save_new" + ): + _handle_save_new_prompt() + + st.divider() + + st.subheader("2. 
Test Your Prompt") + st.markdown("You can test your prompt here before saving.") + st.text_area( + "**Sample User Input (JSON)**", + key="new_sample_user_input", + height=150, + placeholder='{\n "customer_review": "The product was amazing!"\n}', + help="A JSON object where keys match the variables in your prompt text.", + ) + + if st.button("Generate Test Response", use_container_width=True, key="test_new"): + _handle_generate_test_for_new() + + st.text_area( + "**Test Output**", + key="new_sample_output", + height=150, + placeholder="The model's response will be displayed here.", + disabled=True, + ) + + +def render_edit_tab() -> None: + """Renders the UI components for the 'Load & Edit Prompt' tab. + + This function defines and lays out all the Streamlit widgets for loading, + editing, and versioning existing prompts. + """ + st.subheader("1. Load Prompt") + if st.button("Refresh List"): + with st.spinner("Refreshing..."): + st.session_state.local_prompt.refresh_prompt_cache() + st.toast("Prompt list refreshed.") + + col1, col2 = st.columns(2) + with col1: + selected_prompt_name = st.selectbox( + "Select Existing Prompt", + options=st.session_state.local_prompt.existing_prompts.keys(), + placeholder="Select Prompt...", + key="selected_prompt", + help="Choose the prompt you want to load.", + ) + + with col2: + versions = [] + if selected_prompt_name: + try: + prompt_id = st.session_state.local_prompt.existing_prompts[ + selected_prompt_name + ] + versions = [v.version_id for v in prompts.list_versions(prompt_id)] + except Exception as e: + st.error(f"Could not fetch versions: {e}") + st.selectbox( + "Select Version", + options=versions, + placeholder="Select Version...", + key="selected_version", + help="Choose the specific version to load.", + ) + + st.button( + "Load Prompt", + on_click=_handle_load_prompt, + use_container_width=True, + type="primary", + ) + + st.divider() + + st.subheader("2. Edit Prompt Details") + st.text_input("Prompt Name", key="edit_prompt_name", disabled=True) + st.text_area("Prompt Text", key="edit_prompt_data", height=150) + st.text_input("Model Name", key="edit_model_name") + st.text_area("System Instructions", key="edit_system_instructions", height=300) + st.multiselect("Prompt Task", options=AVAILABLE_PROMPT_TASKS, key="edit_meta_tags") + + col_schema, col_config = st.columns(2) + with col_schema: + st.text_area("Response Schema (JSON)", key="edit_response_schema", height=200) + with col_config: + st.text_area( + "Generation Config (JSON)", key="edit_generation_config", height=200 + ) + + if st.button( + "Save as New Version", type="primary", use_container_width=True, key="save_edit" + ): + _handle_save_edited_prompt() + + st.divider() + + st.subheader("3. Test Your Prompt") + st.text_area("Sample User Input (JSON)", key="edit_sample_user_input", height=150) + + if st.button("Generate Test Response", use_container_width=True, key="test_edit"): + _handle_generate_test_for_edit() + + st.text_area( + "Test Output", + key="edit_sample_output", + height=150, + placeholder="The model's response will be displayed here.", + disabled=True, + ) + + +# --- Main Application --- +def main() -> None: + """Renders the main Prompt Management page. + + Sets the page configuration, initializes the session state (including the + `gcp_prompt` object and UI field defaults), and renders the main title + and tabbed layout for creating and editing prompts. 
+ """ + st.set_page_config( + layout="wide", + page_title="Prompt Management", + page_icon="assets/favicon.ico", + ) + + # Initialize session state object and UI fields + if "local_prompt" not in st.session_state: + st.session_state.local_prompt = gcp_prompt() + + ui_fields = { + "new_prompt_name": "", + "new_prompt_data": "", + "new_model_name": "", + "new_system_instructions": "", + "new_response_schema": "", + "new_generation_config": "", + "new_meta_tags": [], + "new_sample_user_input": "", + "new_sample_output": "", + "edit_prompt_name": "", + "edit_prompt_data": "", + "edit_model_name": "", + "edit_system_instructions": "", + "edit_response_schema": "", + "edit_generation_config": "", + "edit_meta_tags": [], + "edit_sample_user_input": "", + "edit_sample_output": "", + } + for field, default_val in ui_fields.items(): + if field not in st.session_state: + st.session_state[field] = default_val + + st.title("Prompt Management") + st.markdown( + "Create new prompts or load, edit, and test existing ones from the Prompt Management service." + ) + st.divider() + + # Use st.radio to create stateful tabs that persist across reruns. + # This prevents the UI from resetting to the first tab on every interaction. + selected_tab = st.radio( + "Select Action", + ["Create New Prompt", "Load & Edit Prompt"], + key="prompt_management_tab", + horizontal=True, + label_visibility="collapsed", + ) + + if selected_tab == "Create New Prompt": + render_create_tab() + elif selected_tab == "Load & Edit Prompt": + render_edit_tab() + + st.caption("LLM EvalKit | Prompt Management") + + +if __name__ == "__main__": + main() diff --git a/tools/llmevalkit/pages/2_Dataset_Creation.py b/tools/llmevalkit/pages/2_Dataset_Creation.py new file mode 100644 index 00000000000..f84df2aa7bf --- /dev/null +++ b/tools/llmevalkit/pages/2_Dataset_Creation.py @@ -0,0 +1,250 @@ +# Copyright 2025 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language + +"""Streamlit page for creating and managing datasets in Google Cloud Storage.""" + +import logging +import os + +import streamlit as st +from dotenv import load_dotenv +from google.cloud import storage +from streamlit.runtime.uploaded_file_manager import UploadedFile + +# Load environment variables from .env file +load_dotenv("src/.env") + +# Configure logging to the console +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +@st.cache_data(ttl=300) +def get_existing_datasets( + _storage_client: storage.Client, bucket_name: str +) -> list[str]: + """Lists 'directories' in GCS under the 'datasets/' prefix. + These directories represent the existing datasets. + """ + if not bucket_name or not _storage_client: + return [] + + bucket = _storage_client.bucket(bucket_name) + prefix = "datasets/" + retrieved_prefixes = set() + + try: + # Explicitly iterate through pages for robustness. + iterator = bucket.list_blobs(prefix=prefix, delimiter="/") + for page in iterator.pages: + retrieved_prefixes.update(page.prefixes) + + # The retrieved prefixes are the "subdirectories". 
+ # e.g., {'datasets/my_dataset_1/', 'datasets/my_dataset_2/'} + dir_names = [] + for p in retrieved_prefixes: + # Extract 'my_dataset_1' from 'datasets/my_dataset_1/' + name = p[len(prefix) :].strip("/") + if name: + dir_names.append(name) + logger.info(f"Found datasets: {dir_names}") + return sorted(dir_names) + except Exception as e: + st.error(f"Error listing datasets from GCS: {e}") + logger.error(f"Error in get_existing_datasets: {e}", exc_info=True) + return [] + + +def _handle_upload( + storage_client: storage.Client, + bucket_name: str, + dataset_name: str, + uploaded_file: UploadedFile, +) -> None: + """Handles the logic of uploading a file to GCS.""" + if not all([storage_client, bucket_name, dataset_name, uploaded_file]): + st.warning("Missing required information for upload.") + return + + try: + file_name = uploaded_file.name + content_type = "text/plain" # Default + if file_name.endswith(".csv"): + content_type = "text/csv" + elif file_name.endswith(".json"): + content_type = "application/json" + elif file_name.endswith(".jsonl"): + content_type = "application/x-jsonlines" + + blob_path = f"datasets/{dataset_name}/{uploaded_file.name}" + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(blob_path) + + with st.spinner(f"Uploading '{uploaded_file.name}' to '{dataset_name}'..."): + blob.upload_from_string(uploaded_file.getvalue(), content_type=content_type) + + st.success( + f"Successfully uploaded '{uploaded_file.name}' to dataset '{dataset_name}'!" + ) + logger.info(f"Uploaded file to gs://{bucket_name}/{blob_path}") + # Clear the cache for get_existing_datasets to reflect the new dataset if created + get_existing_datasets.clear() + st.rerun() + except Exception as e: + st.error(f"Failed to upload file: {e}") + logger.error(f"Error during GCS upload: {e}", exc_info=True) + + +def _ensure_datasets_folder_exists( + storage_client: storage.Client, bucket_name: str +) -> None: + """Ensures the 'datasets/' folder exists by creating a placeholder object if needed. + + This helps it appear in the GCS UI even when empty. + """ + if not storage_client or not bucket_name: + return + try: + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob("datasets/") + if not blob.exists(): + blob.upload_from_string("", content_type="application/x-directory") + logger.info( + f"Created placeholder for 'datasets/' folder in bucket '{bucket_name}'." + ) + except Exception as e: + # This is not a critical failure, so just log a warning. + logger.warning(f"Could not ensure 'datasets/' folder exists: {e}") + + +def main() -> None: + """Renders the Dataset Creation page.""" + st.set_page_config( + layout="wide", page_title="Dataset Management", page_icon="assets/favicon.ico" + ) + + # --- Initialize Session State & GCS Client --- + if "storage_client" not in st.session_state: + try: + st.session_state.storage_client = storage.Client() + except Exception as e: + st.error(f"Could not connect to Google Cloud Storage: {e}") + st.stop() + + BUCKET_NAME = os.getenv("BUCKET") + if not BUCKET_NAME: + st.error("BUCKET environment variable is not set. Please configure it in .env.") + st.stop() + + # Ensure the base 'datasets/' folder exists for UI consistency + _ensure_datasets_folder_exists(st.session_state.storage_client, BUCKET_NAME) + + st.title("Dataset Management") + st.markdown( + "Create new datasets or upload files (CSV, JSON, or JSONL) to existing ones. " + "A 'Dataset' is a folder in your GCS bucket used to group related evaluation files." 
+ ) + st.divider() + + # --- Section 1: Upload File --- + st.subheader("1. Upload a File") + + existing_datasets = get_existing_datasets( + st.session_state.storage_client, BUCKET_NAME + ) + + # Let user choose whether to create a new dataset or add to an existing one + upload_mode = st.radio( + "Choose an action:", + ("Create a new dataset", "Add to an existing dataset"), + key="upload_mode", + horizontal=True, + ) + + dataset_name = "" + if upload_mode == "Create a new dataset": + dataset_name = st.text_input( + "Enter a name for the new dataset:", + key="new_dataset_name", + help="Use a descriptive name, e.g., 'sentiment_analysis_v1'.", + ) + else: + dataset_name = st.selectbox( + "Select an existing dataset:", + options=existing_datasets, + key="selected_dataset_for_upload", + help="Choose the dataset folder to upload your file into.", + index=None, + placeholder="Select a dataset...", + ) + + uploaded_file = st.file_uploader( + "Select a file to upload", + type=["csv", "json", "jsonl"], + key="file_uploader", + ) + + if st.button("Upload to Cloud Storage", type="primary", use_container_width=True): + if not dataset_name: + st.warning("Please provide or select a dataset name.") + elif not uploaded_file: + st.warning("Please select a file to upload.") + else: + _handle_upload( + st.session_state.storage_client, + BUCKET_NAME, + dataset_name, + uploaded_file, + ) + + st.divider() + + # --- Section 2: View Existing Datasets --- + st.subheader("2. View Existing Datasets") + + with st.expander("Browse datasets and their contents", expanded=True): + selected_dataset_to_view = st.selectbox( + "Select a dataset to view its contents:", + options=existing_datasets, + key="selected_dataset_for_view", + index=None, + placeholder="Select a dataset...", + ) + + if selected_dataset_to_view: + prefix = f"datasets/{selected_dataset_to_view}/" + blobs = st.session_state.storage_client.list_blobs( + BUCKET_NAME, prefix=prefix + ) + filenames = [ + os.path.basename(b.name) + for b in blobs + if b.name.endswith((".csv", ".json", ".jsonl")) + ] + + if filenames: + st.write(f"**Files in '{selected_dataset_to_view}':**") + st.text_area( + "Files", + value="\n".join(filenames), + height=150, + disabled=True, + label_visibility="collapsed", + ) + else: + st.info(f"No files found in the '{selected_dataset_to_view}' dataset.") + + st.caption("LLM EvalKit | Dataset Management") + + +if __name__ == "__main__": + main() diff --git a/tools/llmevalkit/pages/3_Evaluation.py b/tools/llmevalkit/pages/3_Evaluation.py new file mode 100644 index 00000000000..a350218681f --- /dev/null +++ b/tools/llmevalkit/pages/3_Evaluation.py @@ -0,0 +1,992 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import datetime +import json +import logging +import os +import re + +import pandas as pd +import streamlit as st +import vertexai +from dotenv import load_dotenv +from google.cloud import storage +from src.gcp_prompt import GcpPrompt as gcp_prompt +from vertexai.evaluation import ( + EvalTask, + MetricPromptTemplateExamples, + PairwiseMetricPromptTemplate, + PointwiseMetricPromptTemplate, +) +from vertexai.generative_models import ( + GenerationConfig, + GenerativeModel, + HarmBlockThreshold, + HarmCategory, +) +from vertexai.preview import prompts + +load_dotenv("src/.env") + + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def get_metric_object_by_name(metric_name: str): + """Retrieves a metric template object from its string name.""" + try: + return MetricPromptTemplateExamples._PROMPT_TEMPLATE_MAP[metric_name] + except Exception as e: + logger.exception(f"Failed to get metric object for {metric_name}: {e}") + raise + + +def refresh_bucket() -> list[str]: + """Refreshes the list of available dataset URIs from the GCS bucket. + + This function lists all blobs in the configured GCS bucket, filters for + CSV and JSONL files located within the 'datasets/' prefix, and constructs a list + of their full gs:// URI paths. + + Returns: + A list of strings, where each string is a GCS URI to a dataset file. + """ + logger.info("Bucket: %s", os.getenv("BUCKET")) + bucket = st.session_state.storage_client.bucket(os.getenv("BUCKET")) + blobs = bucket.list_blobs() + data_uris = [] + for i in blobs: + if i.name.split("/")[0] == "datasets" and ( + i.name.endswith(".csv") or i.name.endswith(".jsonl") + ): + data_uris.append(f"gs://{i.bucket.name}/{i.name}") + logger.info("Data URIs: %s", data_uris) + return data_uris + + +def get_autorater_pairwise_response(metric_prompt: str, model: str) -> dict: + """Gets a response from the autorater model for pairwise evaluation. + + Args: + metric_prompt: The prompt to send to the autorater model. + model: The name of the evaluation model to use. + + Returns: + A dictionary containing the autorater's response. + """ + metric_response_schema = { + "type": "OBJECT", + "properties": { + "pairwise_choice": {"type": "STRING"}, + "explanation": {"type": "STRING"}, + }, + "required": ["pairwise_choice", "explanation"], + } + + autorater = GenerativeModel( + model, + generation_config=GenerationConfig( + response_mime_type="application/json", + response_schema=metric_response_schema, + ), + safety_settings={ + HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, + }, + ) + + response = autorater.generate_content(metric_prompt) + response_json = {} + + if response.candidates and len(response.candidates) > 0: + candidate = response.candidates[0] + if ( + candidate.content + and candidate.content.parts + and len(candidate.content.parts) > 0 + ): + part = candidate.content.parts[0] + if part.text: + response_json = json.loads(part.text) + + return response_json + + +def main() -> None: + """Initializes and runs the Streamlit evaluation application. 
+ + This function sets up the Streamlit page configuration, initializes session state + variables, and orchestrates the user interface for the evaluation workflow. + It handles dataset and prompt selection, response generation or loading, + human-in-the-loop rating, and the launching of automated evaluations. + """ + st.set_page_config( + layout="wide", + page_title="Prompt Engineering App", + page_icon="assets/favicon.ico", + ) + + st.header("Evaluation Set-Up") + + if "storage_client" not in st.session_state: + st.session_state["storage_client"] = storage.Client() + + if "data_uris" not in st.session_state: + st.session_state["data_uris"] = refresh_bucket() + + if "current_index" not in st.session_state: + st.session_state.current_index = 0 + + if "eval_result" not in st.session_state: + st.session_state.eval_result = None + + if "custom_eval_result" not in st.session_state: + st.session_state.custom_eval_result = None + + if "df_data" not in st.session_state: + st.session_state.df_data = None + + if "df_dataset_eval" not in st.session_state: + st.session_state.df_dataset_eval = None + + if "all_metrics_eval" not in st.session_state: + st.session_state.all_metrics_eval = None + + if "metrics" not in st.session_state: + st.session_state.metrics = None + + if "local_prompt" not in st.session_state: + st.session_state.local_prompt = gcp_prompt() + + if "cached_data_files" not in st.session_state: + st.session_state.cached_data_files = {} + if "last_selected_dataset_for_cache" not in st.session_state: + st.session_state.last_selected_dataset_for_cache = None + + if "cached_prompt_versions" not in st.session_state: + st.session_state.cached_prompt_versions = {} + + if "last_selected_prompt_for_versions_cache" not in st.session_state: + st.session_state.last_selected_prompt_for_versions_cache = None + + if "human_rated_dict" not in st.session_state: + st.session_state.human_rated_dict = {} + + if "metric_preview_index" not in st.session_state: + st.session_state.metric_preview_index = 0 + + if "vertex_session_init" not in st.session_state: + vertexai.init( + project=os.getenv("PROJECT_ID"), + location=os.getenv("LOCATION"), + staging_bucket=os.getenv("BUCKET"), + experiment=os.getenv("EXPERIMENT_NAME"), + ) + st.session_state.vertex_session_init = True + + data_sets = list({i.split("/")[4] for i in st.session_state.data_uris}) + logger.info(f"Data Sets: {data_sets}") + st.selectbox( + "Select an Existing Dataset", options=[None, *data_sets], key="selected_dataset" + ) + + files_to_display_in_selectbox = [] + if st.session_state.selected_dataset: + if ( + st.session_state.selected_dataset + != st.session_state.last_selected_dataset_for_cache + or st.session_state.selected_dataset + not in st.session_state.cached_data_files + ): + logger.info( + "Cache miss or dataset changed for files. 
Fetching for: %s", + st.session_state.selected_dataset, + ) + bucket = st.session_state.storage_client.bucket(os.getenv("BUCKET")) + prefix = f"datasets/{st.session_state.selected_dataset}/" + blobs_iterator = bucket.list_blobs(prefix=prefix) + + current_dataset_files = [] + for blob in blobs_iterator: + if ( + blob.name.endswith(".csv") or blob.name.endswith(".jsonl") + ) and not blob.name.endswith("/"): + filename = blob.name[len(prefix) :] + if filename: + current_dataset_files.append(filename) + + st.session_state.cached_data_files[st.session_state.selected_dataset] = ( + sorted(set(current_dataset_files)) + ) + st.session_state.last_selected_dataset_for_cache = ( + st.session_state.selected_dataset + ) + logger.info( + "Cached files for %s: %s", + st.session_state.selected_dataset, + st.session_state.cached_data_files[st.session_state.selected_dataset], + ) + + if "selected_file_from_dataset" in st.session_state: + st.session_state.selected_file_from_dataset = None + logger.info("Reset selected_file_from_dataset due to dataset change.") + + files_to_display_in_selectbox = st.session_state.cached_data_files.get( + st.session_state.selected_dataset, [] + ) + + st.selectbox( + "Select a file from this dataset:", + options=[None, *files_to_display_in_selectbox], + key="selected_file_from_dataset", + ) + + st.text_input("Number of Samples", key="n_samples") + + st.text_input( + "Ground Truth Column Name", + key="ground_truth_column_name", + value="target", + help="The name of the column in your dataset that contains the ground truth or target response.", + ) + + st.selectbox( + "Select Existing Prompt", + options=[None, *list(st.session_state.local_prompt.existing_prompts.keys())], + placeholder="Select Prompt...", + key="selected_prompt", + ) + + versions_to_display_in_selectbox = [] + if st.session_state.selected_prompt: + st.session_state.local_prompt.prompt_meta["name"] = ( + st.session_state.selected_prompt + ) + + selected_prompt_obj = st.session_state.local_prompt.existing_prompts[ + st.session_state.selected_prompt + ] + prompt_resource_name_for_cache = str(selected_prompt_obj) + + if ( + st.session_state.selected_prompt + != st.session_state.last_selected_prompt_for_versions_cache + or prompt_resource_name_for_cache + not in st.session_state.cached_prompt_versions + ): + logger.info( + "Cache miss or prompt changed for versions. 
Fetching for: %s", + st.session_state.selected_prompt, + ) + fetched_versions = [ + v.version_id for v in prompts.list_versions(selected_prompt_obj) + ] + + st.session_state.cached_prompt_versions[prompt_resource_name_for_cache] = ( + fetched_versions + ) + st.session_state.last_selected_prompt_for_versions_cache = ( + st.session_state.selected_prompt + ) + logger.info( + "Cached versions for %s: %s", + st.session_state.selected_prompt, + fetched_versions, + ) + + if "selected_version" in st.session_state: + st.session_state.selected_version = None + logger.info("Reset selected_version due to prompt change.") + versions_to_display_in_selectbox = st.session_state.cached_prompt_versions.get( + prompt_resource_name_for_cache, [] + ) + + st.selectbox( + "Select Version", + options=versions_to_display_in_selectbox, + placeholder="Select Version...", + key="selected_version", + ) + + st.button("Load Prompt", key="load_prompt_button") + if st.session_state.load_prompt_button: + logger.info( + f"Selected Prompt ID: {st.session_state.local_prompt.existing_prompts[st.session_state.selected_prompt]}" + ) + logger.info(f"Version: {st.session_state.selected_version}") + st.session_state.local_prompt.load_prompt( + st.session_state.local_prompt.existing_prompts[ + st.session_state.selected_prompt + ], + st.session_state.selected_prompt, + st.session_state.selected_version, + ) + logger.info(f"Local Prompt Meta: {st.session_state.local_prompt.prompt_meta}") + logger.info( + f"Local Prompt Meta Dict Keys: {st.session_state.local_prompt.prompt_meta.keys()}" + ) + + st.session_state.prompt_name = ( + st.session_state.local_prompt.prompt_to_run.prompt_name + ) + st.session_state.prompt_data = ( + st.session_state.local_prompt.prompt_to_run.prompt_data + ) + st.session_state.model_name = ( + st.session_state.local_prompt.prompt_to_run.model_name.split("/")[-1] + ) + st.session_state.system_instructions = ( + st.session_state.local_prompt.prompt_to_run.system_instruction + ) + st.session_state.response_schema = json.dumps( + st.session_state.local_prompt.prompt_meta.get("response_schema", {}) + ) + st.session_state.generation_config = json.dumps( + st.session_state.local_prompt.prompt_meta.get("generation_config", {}) + ) + st.session_state.meta_tags = st.session_state.local_prompt.prompt_meta[ + "meta_tags" + ] + + st.button("Upload Data and Get Responses", key="upload_data_get_responses_button") + + if ( + st.session_state.upload_data_get_responses_button + and st.session_state.n_samples + and st.session_state.selected_dataset + ): + if not st.session_state.n_samples: + st.warning("Please enter the Number of Samples.") + return + if not st.session_state.selected_dataset: + st.warning("Please select an Existing Dataset.") + return + if not st.session_state.selected_file_from_dataset: + st.warning("Please select a file from the dataset.") + return + + try: + num_samples = int(st.session_state.n_samples) + if num_samples <= 0: + st.warning("Number of Samples must be a positive integer.") + return + except ValueError: + st.warning("Number of Samples must be a valid integer.") + return + + gcs_path = f"gs://{os.getenv('BUCKET')}/datasets/{st.session_state.selected_dataset}/{st.session_state.selected_file_from_dataset}" + st.session_state["input_data_uri"] = gcs_path + try: + if gcs_path.endswith(".csv"): + df_full = pd.read_csv(gcs_path) + elif gcs_path.endswith(".jsonl"): + df_full = pd.read_json(gcs_path, lines=True) + else: + st.error(f"Unsupported file type: {gcs_path.split('.')[-1]}") + return + except 
Exception as e: + st.error(f"Error reading data from {gcs_path}: {e}") + return + + df = df_full.iloc[:num_samples] + if df.empty: + st.warning( + "No data found for the first %s samples in %s, or the file is smaller than requested.", + num_samples, + st.session_state.selected_file_from_dataset, + ) + st.session_state.human_rated_dict = {} + st.session_state.ratings = [] + st.session_state.include_in_evaluations = [] + st.session_state.current_index = 0 + return + + user_input_list = [] + expected_result_list = [] + assistant_response_list = [] + baseline_model_response_list = [] + + generate = False + ground_truth_col = st.session_state.ground_truth_column_name + required_cols_for_loading_existing = [ + "user_input", + ground_truth_col, + "assistant_response", + ] + + if "assistant_response" in df.columns: + missing_loading_cols = [ + col + for col in required_cols_for_loading_existing + if col not in df.columns + ] + if not missing_loading_cols: + logger.info("Sufficient columns found to load existing responses.") + generate = False + else: + st.error( + f"The file has 'assistant_response' column, but is missing other essential columns for loading: {missing_loading_cols}. Required for loading: {required_cols_for_loading_existing}. Found columns: {df.columns.tolist()}", + ) + st.session_state.human_rated_dict = {} + st.session_state.ratings = [] + st.session_state.include_in_evaluations = [] + st.session_state.current_index = 0 + return + else: + if not st.session_state.get("prompt_data"): + st.error( + "To generate new responses, please load a prompt first using the 'Load Prompt' button." + ) + st.session_state.human_rated_dict = {} + st.session_state.ratings = [] + st.session_state.include_in_evaluations = [] + st.session_state.current_index = 0 + return + + template_vars = re.findall(r"{(\w+)}", st.session_state.prompt_data) + required_cols_for_generating_new = list(set(template_vars)) + + all_required_cols = [*required_cols_for_generating_new, ground_truth_col] + missing_generating_cols = [ + col for col in all_required_cols if col not in df.columns + ] + + if not missing_generating_cols: + logger.info( + "'assistant_response' column not found. Required columns for generating new responses are present. Will generate." + ) + generate = True + else: + st.error( + f"The file does not have 'assistant_response' column, and is also missing columns required for generating new responses based on the loaded prompt: {missing_generating_cols}. Required for generation: {all_required_cols}. Found columns: {df.columns.tolist()}", + ) + st.session_state.human_rated_dict = {} + st.session_state.ratings = [] + st.session_state.include_in_evaluations = [] + st.session_state.current_index = 0 + return + + logger.info("Generate flag set to: %s for %s samples.", generate, len(df)) + + if generate: + logger.info("Proceeding with generating new assistant responses.") + if ( + not st.session_state.selected_prompt + or not st.session_state.selected_version + ): + st.error( + "A prompt and version must be loaded to generate new responses. Please use the 'Load Prompt' button." + ) + st.session_state.human_rated_dict = {} + st.session_state.ratings = [] + st.session_state.include_in_evaluations = [] + st.session_state.current_index = 0 + return + if not st.session_state.local_prompt.prompt_to_run.prompt_data: + st.error( + "Prompt data is missing from the loaded prompt. Cannot generate. Please re-load the prompt using 'Load Prompt'." 
+ ) + st.session_state.human_rated_dict = {} + st.session_state.ratings = [] + st.session_state.include_in_evaluations = [] + st.session_state.current_index = 0 + return + + if len(df) > 0: + st.session_state.generation_progress_bar = st.progress( + 0, text="Starting response generation..." + ) + + template_vars = re.findall(r"{(\w+)}", st.session_state.prompt_data) + required_cols_for_generating_new = list(set(template_vars)) + for idx, r in df.iterrows(): + current_user_input_item = { + col: r[col] for col in required_cols_for_generating_new + } + try: + generated_text = st.session_state.local_prompt.generate_response( + current_user_input_item + ) + user_input_list.append(current_user_input_item) + expected_result_list.append( + r[st.session_state.ground_truth_column_name] + ) + assistant_response_list.append(generated_text) + if "generation_progress_bar" in st.session_state: + progress_text = f"Generating response {idx + 1} of {len(df)}..." + st.session_state.generation_progress_bar.progress( + (idx + 1) / len(df), text=progress_text + ) + except Exception as e: + logger.exception( + "Error generating response for row index %s (data: %s): %s", + idx, + current_user_input_item, + e, + ) + st.warning( + f"Skipped generating response for one item (row index {idx}) due to error: {e}" + ) + if "generation_progress_bar" in st.session_state: + progress_text = f"Generating response {idx + 1} of {len(df)}... (Error, skipped)" + st.session_state.generation_progress_bar.progress( + (idx + 1) / len(df), text=progress_text + ) + continue + if "generation_progress_bar" in st.session_state: + st.session_state.generation_progress_bar.empty() + del st.session_state.generation_progress_bar + + if len(user_input_list) < len(df): + st.info( + "Successfully generated responses for %s out of %s requested samples due to errors during generation.", + len(user_input_list), + len(df), + ) + else: + logger.info( + "Proceeding with loading existing assistant responses from file." + ) + parsed_user_inputs_temp = [] + for item_str in df["user_input"].astype(str).tolist(): + try: + parsed_user_inputs_temp.append(json.loads(item_str)) + except json.JSONDecodeError: + logger.debug( + "User input item is not valid JSON, using as raw string: %s", + item_str[:100], + ) + parsed_user_inputs_temp.append(item_str) + user_input_list = parsed_user_inputs_temp + expected_result_list = df[ + st.session_state.ground_truth_column_name + ].tolist() + assistant_response_list = df.assistant_response.tolist() + baseline_model_response_list = [] + if "baseline_model_response" in df.columns: + baseline_model_response_list = df.baseline_model_response.tolist() + + st.session_state.human_rated_dict = { + "user_input": user_input_list, + "ground_truth": expected_result_list, + "assistant_response": assistant_response_list, + } + if baseline_model_response_list: + st.session_state.human_rated_dict["baseline_model_response"] = ( + baseline_model_response_list + ) + num_processed_items = len(user_input_list) + + if num_processed_items > 0: + st.session_state.include_in_evaluations = [True] * num_processed_items + st.session_state.current_index = 0 + st.success(f"Successfully processed {num_processed_items} samples.") + else: + st.warning( + "No data items were processed successfully. Check logs for errors or review file structure." 
+ ) + st.session_state.human_rated_dict = {} + st.session_state.include_in_evaluations = [] + st.session_state.current_index = 0 + + st.divider() + + if st.session_state.human_rated_dict: + st.title("Review Responses") + + col1, col2, col3 = st.columns(3) + + with col1: + st.subheader("User Input") + st.text_area( + label="User's original query/text", + value=st.session_state.human_rated_dict["user_input"][ + st.session_state.current_index + ], + height=200, + key="user_input_text", + disabled=True, + ) + + with col2: + st.subheader("Ground Truth") + st.text_area( + label="The ideal/target response", + value=st.session_state.human_rated_dict["ground_truth"][ + st.session_state.current_index + ], + height=200, + key="ground_truth_text", + disabled=True, + ) + + with col3: + st.subheader("Assistant Response") + st.text_area( + label="The assistant's generated response", + value=st.session_state.human_rated_dict["assistant_response"][ + st.session_state.current_index + ], + height=200, + key="assistant_response_text", + disabled=True, + ) + + eval_include = st.checkbox( + "Include in Evaluation", + value=st.session_state.include_in_evaluations[ + st.session_state.current_index + ], + key="evaluation_checkbox", + ) + + if ( + st.session_state.include_in_evaluations + and eval_include + != st.session_state.include_in_evaluations[st.session_state.current_index] + ): + st.session_state.include_in_evaluations[st.session_state.current_index] = ( + eval_include + ) + + st.markdown("---") + col_prev, col_spacer, col_next = st.columns([1, 3, 1]) + + with col_prev: + if st.button("Previous", disabled=(st.session_state.current_index == 0)): + st.session_state.current_index -= 1 + st.rerun() + + with col_next: + if st.button( + "Next", + disabled=( + st.session_state.current_index + == len(st.session_state.human_rated_dict["user_input"]) - 1 + ), + ): + st.session_state.current_index += 1 + st.rerun() + + st.markdown( + f"
Case {st.session_state.current_index + 1} of {len(st.session_state.human_rated_dict['user_input'])}
", + unsafe_allow_html=True, + ) + + st.markdown("---") + st.subheader("Launch Eval") + + st.subheader("Metrics Selection") + + col1, col2 = st.columns(2) + + with col1: + st.write("**Model-Based**") + metric_names = MetricPromptTemplateExamples.list_example_metric_names() + selected_model_based_metrics = st.multiselect( + "Select from model-based metrics", + metric_names, + key="selected_model_based_metrics", + label_visibility="collapsed", + ) + if selected_model_based_metrics: + num_metrics = len(selected_model_based_metrics) + if st.session_state.metric_preview_index >= num_metrics: + st.session_state.metric_preview_index = 0 + + current_metric_name = selected_model_based_metrics[ + st.session_state.metric_preview_index + ] + + st.markdown( + f"**Previewing Template: {current_metric_name} ({st.session_state.metric_preview_index + 1}/{num_metrics})**" + ) + + try: + metric_object = get_metric_object_by_name(current_metric_name) + if isinstance( + metric_object, + PointwiseMetricPromptTemplate | PairwiseMetricPromptTemplate, + ): + st.text_area( + "Template Preview", + metric_object.metric_prompt_template, + height=200, + ) + except Exception as e: + st.error( + f"Could not retrieve template for {current_metric_name}: {e}" + ) + + if num_metrics > 1: + prev_col, next_col = st.columns(2) + with prev_col: + if st.button( + "Previous Template", + disabled=st.session_state.metric_preview_index <= 0, + ): + st.session_state.metric_preview_index -= 1 + st.rerun() + with next_col: + if st.button( + "Next Template", + disabled=st.session_state.metric_preview_index + >= num_metrics - 1, + ): + st.session_state.metric_preview_index += 1 + st.rerun() + + with col2: + st.write("**Computation-Based Pointwise**") + computation_based_pointwise = [ + "bleu", + "rouge_1", + "rouge_2", + "rouge_l", + "rouge_l_sum", + "exact_match", + ] + st.multiselect( + "Select from computation-based pointwise metrics", + computation_based_pointwise, + key="selected_cbp", + label_visibility="collapsed", + ) + + st.selectbox( + "Select Evaluation Model", + options=[ + "gemini-2.0-flash-lite", + "gemini-2.5-flash", + "gemini-2.5-pro", + ], + key="selected_evaluation_model", + ) + + st.button("Launch Eval", key="launch_eval_button") + + if st.session_state.launch_eval_button: + selected_mbp_names = st.session_state.get( + "selected_model_based_metrics", [] + ) + selected_cbp_metrics = st.session_state.get("selected_cbp", []) + + all_metrics = selected_mbp_names + selected_cbp_metrics + + if not all_metrics: + st.warning("Please select at least one evaluation metric.") + return + + evaluation_data_list = [] + for idx, include_item in enumerate(st.session_state.include_in_evaluations): + if include_item: + user_input_values = st.session_state.human_rated_dict["user_input"][ + idx + ] + prompt_template = ( + st.session_state.local_prompt.prompt_to_run.prompt_data + ) + logger.info(f"Prompt template: {prompt_template}") + system_instruction = ( + st.session_state.local_prompt.prompt_to_run.system_instruction + ) + prediction = str( + st.session_state.human_rated_dict["assistant_response"][idx] + ) + + # Process reference value like in the old code + reference_val = st.session_state.human_rated_dict["ground_truth"][ + idx + ] + final_reference_str = "" + if isinstance(reference_val, int | float | bool): + final_reference_str = json.dumps({"value": reference_val}) + elif isinstance(reference_val, str): + try: + parsed_json = json.loads(reference_val) + if isinstance(parsed_json, int | float | bool): + final_reference_str = 
json.dumps({"value": parsed_json}) + else: + final_reference_str = reference_val + except json.JSONDecodeError: + final_reference_str = reference_val + elif isinstance(reference_val, dict | list): + final_reference_str = json.dumps(reference_val) + else: + final_reference_str = str(reference_val) + + instruction = ( + prompt_template.format(**user_input_values) + if isinstance(user_input_values, dict) + else prompt_template + ) + context = system_instruction if system_instruction else "" + prompt_str = ( + json.dumps(user_input_values) + if isinstance(user_input_values, dict) + else str(user_input_values) + ) + + eval_item = { + "context": context, + "instruction": instruction, + "prompt": prompt_str, + "prediction": prediction, + "reference": final_reference_str, + } + evaluation_data_list.append(eval_item) + + if not evaluation_data_list: + st.warning( + "No items were selected for evaluation. Please check the 'Include in Evaluation' checkboxes." + ) + return + + df_dataset = pd.DataFrame(evaluation_data_list) + st.session_state.df_dataset_eval = df_dataset + st.session_state.all_metrics_eval = all_metrics + logger.info(f"Evaluation DataFrame columns: {df_dataset.columns.tolist()}") + logger.info(f"Evaluation DataFrame head:\n{df_dataset.head()}") + + task = EvalTask( + dataset=df_dataset, + metrics=all_metrics, + experiment=os.getenv("EXPERIMENT_NAME"), + ) + st.session_state.eval_result = task.evaluate( + response_column_name="prediction", + baseline_model_response_column_name="reference", + ) + + st.markdown("---") + st.subheader("View Eval") + + st.button("View Evaluation Results", key="eval_results_button") + + if st.session_state.eval_result and st.session_state.eval_results_button: + print(st.session_state.eval_result.metrics_table) + st.dataframe(st.session_state.eval_result.metrics_table) + if st.session_state.eval_result: + st.markdown("---") + st.subheader("Summary Scores") + + mean_scores = {} + if ( + st.session_state.eval_result + and hasattr(st.session_state.eval_result, "metrics_table") + and not st.session_state.eval_result.metrics_table.empty + ): + for col in st.session_state.eval_result.metrics_table.columns: + if col.endswith("/score"): + scores = pd.to_numeric( + st.session_state.eval_result.metrics_table[col], + errors="coerce", + ) + if not scores.dropna().empty: + mean_scores[col] = scores.dropna().mean() + if mean_scores: + for metric, score in mean_scores.items(): + st.metric(label=f"Mean {metric}", value=f"{score:.2f}") + else: + st.metric( + label="Mean Automated Score", + value="N/A", + ) + + st.markdown("---") + st.subheader("Save to Prompt Records") + save_to_records = st.checkbox( + "I want to save the results of this evaluation to the prompt records.", + key="save_to_records_checkbox", + ) + if st.button("Save to Prompt Records", key="save_to_records_button"): + if save_to_records: + prompt_name = st.session_state.selected_prompt + prompt_version = st.session_state.selected_version + data_file = st.session_state.input_data_uri + + if "df_dataset_eval" in st.session_state: + data = st.session_state.df_dataset_eval.to_dict( + orient="records" + ) + else: + st.error( + "Evaluation data not found in session state. Please re-run evaluation." + ) + return + + if "all_metrics_eval" in st.session_state: + metrics = st.session_state.all_metrics_eval + else: + st.error( + "Metrics not found in session state. Please re-run evaluation." 
+ ) + return + + scores_df = st.session_state.eval_result.metrics_table + scores = scores_df.to_dict(orient="records") + + mean_scores_to_save = {} + for col in scores_df.columns: + if col.endswith("/score"): + s = pd.to_numeric(scores_df[col], errors="coerce") + if not s.dropna().empty: + mean_scores_to_save[col] = s.dropna().mean() + + record_data = { + "prompt_name": prompt_name, + "prompt_version": prompt_version, + "data_file": data_file, + "metrics": metrics, + "mean_scores": mean_scores_to_save, + "scores": scores, + "evaluation_data": data, + "timestamp": datetime.datetime.now().isoformat(), + } + + try: + filename = f"record_{prompt_name}_v{prompt_version}_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.json" + bucket = st.session_state.storage_client.bucket( + os.getenv("BUCKET") + ) + blob = bucket.blob(f"records/{filename}") + + json_data = json.dumps(record_data, indent=4) + blob.upload_from_string( + json_data, content_type="application/json" + ) + + gcs_path = f"gs://{os.getenv('BUCKET')}/records/{filename}" + + st.success( + f"Successfully saved to prompt records at: {gcs_path}" + ) + st.json(json_data) + except Exception as e: + st.error(f"Failed to save to GCS: {e}") + else: + st.warning( + "Please check the box to confirm you want to save to prompt records." + ) + + +if __name__ == "__main__": + main() diff --git a/tools/llmevalkit/pages/4_Prompt_Optimization.py b/tools/llmevalkit/pages/4_Prompt_Optimization.py new file mode 100644 index 00000000000..702a1ea3d19 --- /dev/null +++ b/tools/llmevalkit/pages/4_Prompt_Optimization.py @@ -0,0 +1,418 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Streamlit page for running Vertex AI Prompt Optimization. + +This script provides a user interface for: +- Loading existing prompts from Vertex AI Prompt Registry. +- Loading datasets from a Google Cloud Storage bucket. +- Generating baseline responses and evaluating them against a ground truth. +- Configuring and launching a Vertex AI CustomJob for prompt optimization. +- Displaying baseline evaluation results. 
+ +File Source: +https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/prompts/prompt_optimizer/vapo_lib.py +""" + +import json +import logging +import os +from argparse import Namespace + +import pandas as pd +import streamlit as st +from dotenv import load_dotenv +from etils import epath +from google.cloud import aiplatform, storage +from src import vapo_lib +from src.gcp_prompt import GcpPrompt as gcp_prompt +from vertexai.preview import prompts + +load_dotenv("src/.env") + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +TARGET_MODELS = ["gemini-2.0-flash-001", "gemini-2.0-flash-lite-001"] + + +def initialize_session_state() -> None: + """Initializes the session state variables.""" + if "op_id" not in st.session_state: + st.session_state.op_id = vapo_lib.get_id() + + if "local_prompt" not in st.session_state: + st.session_state.local_prompt = gcp_prompt() + + if "storage_client" not in st.session_state: + st.session_state["storage_client"] = storage.Client() + + if "data_uris" not in st.session_state: + st.session_state["data_uris"] = refresh_bucket() + + if "dataset" not in st.session_state: + st.session_state["dataset"] = None + + if "cached_data_files" not in st.session_state: + st.session_state.cached_data_files = {} + if "last_selected_dataset_for_cache" not in st.session_state: + st.session_state.last_selected_dataset_for_cache = None + + +def refresh_bucket() -> list[str]: + """Refreshes the list of available dataset URIs from the GCS bucket. + + This function lists all blobs in the configured GCS bucket, filters for + CSV and JSONL files located within the 'datasets/' prefix, and constructs a list + of their full gs:// URI paths. + + Returns: + A list of strings, where each string is a GCS URI to a dataset file. 
+ """ + logger.info("Bucket: %s", os.getenv("BUCKET")) + bucket = st.session_state.storage_client.bucket(os.getenv("BUCKET")) + blobs = bucket.list_blobs() + data_uris = [] + for i in blobs: + if i.name.split("/")[0] == "datasets" and ( + i.name.endswith(".csv") or i.name.endswith(".jsonl") + ): + data_uris.append(f"gs://{i.bucket.name}/{i.name}") + logger.info("Data URIs: %s", data_uris) + return data_uris + + +def prompt_selection() -> None: + """Handles the prompt selection and loading.""" + st.selectbox( + "Select Existing Prompt", + options=st.session_state.local_prompt.existing_prompts.keys(), + placeholder="Select Prompt...", + key="selected_prompt", + ) + if st.session_state.selected_prompt: + logger.info("Prompt Meta: %s", st.session_state.local_prompt.prompt_meta) + st.session_state.local_prompt.prompt_meta["name"] = ( + st.session_state.selected_prompt + ) + versions = [ + i.version_id + for i in prompts.list_versions( + st.session_state.local_prompt.existing_prompts[ + st.session_state.selected_prompt + ] + ) + ] + + st.selectbox( + "Select Version", + options=versions, + placeholder="Select Version...", + key="selected_version", + ) + + st.button("Load Prompt", key="load_existing_prompt_button") + if st.session_state.load_existing_prompt_button: + logger.info( + "Selected Prompt ID: %s", + st.session_state.local_prompt.existing_prompts[ + st.session_state.selected_prompt + ], + ) + logger.info("Version: %s", st.session_state.selected_version) + st.session_state.local_prompt.load_prompt( + st.session_state.local_prompt.existing_prompts[ + st.session_state.selected_prompt + ], + st.session_state.selected_prompt, + st.session_state.selected_version, + ) + logger.info("Local Prompt Meta: %s", st.session_state.local_prompt.prompt_meta) + logger.info( + "Local Prompt Meta Dict Keys: %s", + st.session_state.local_prompt.prompt_meta.keys(), + ) + + st.session_state.prompt_name = ( + st.session_state.local_prompt.prompt_to_run.prompt_name + ) + st.session_state.prompt_data = ( + st.session_state.local_prompt.prompt_to_run.prompt_data + ) + st.session_state.model_name = ( + st.session_state.local_prompt.prompt_to_run.model_name.split("/")[-1] + ) + st.session_state.system_instructions = ( + st.session_state.local_prompt.prompt_to_run.system_instruction + ) + st.session_state.response_schema = json.dumps( + st.session_state.local_prompt.prompt_meta.get("response_schema", {}) + ) + st.session_state.generation_config = json.dumps( + st.session_state.local_prompt.prompt_meta.get("generation_config", {}) + ) + st.session_state.meta_tags = st.session_state.local_prompt.prompt_meta[ + "meta_tags" + ] + + +def dataset_selection() -> None: + """Handles the dataset selection and loading.""" + data_sets = list({i.split("/")[4] for i in st.session_state.data_uris}) + logger.info("Data Sets: %s", data_sets) + st.selectbox( + "Select an Existing Dataset", options=[None, *data_sets], key="selected_dataset" + ) + + files_to_display_in_selectbox = [] + if st.session_state.selected_dataset: + if ( + st.session_state.selected_dataset + != st.session_state.last_selected_dataset_for_cache + or st.session_state.selected_dataset + not in st.session_state.cached_data_files + ): + logger.info( + "Cache miss or dataset changed for files. 
Fetching for: %s", + st.session_state.selected_dataset, + ) + bucket = st.session_state.storage_client.bucket(os.getenv("BUCKET")) + prefix = f"datasets/{st.session_state.selected_dataset}/" + blobs_iterator = bucket.list_blobs(prefix=prefix) + + current_dataset_files = [] + for blob in blobs_iterator: + if ( + blob.name.endswith(".csv") or blob.name.endswith(".jsonl") + ) and not blob.name.endswith("/"): + filename = blob.name[len(prefix) :] + if filename: + current_dataset_files.append(filename) + + st.session_state.cached_data_files[st.session_state.selected_dataset] = ( + sorted(set(current_dataset_files)) + ) + st.session_state.last_selected_dataset_for_cache = ( + st.session_state.selected_dataset + ) + logger.info( + "Cached files for %s: %s", + st.session_state.selected_dataset, + st.session_state.cached_data_files[st.session_state.selected_dataset], + ) + + if "selected_file_from_dataset" in st.session_state: + st.session_state.selected_file_from_dataset = None + logger.info("Reset selected_file_from_dataset due to dataset change.") + + files_to_display_in_selectbox = st.session_state.cached_data_files.get( + st.session_state.selected_dataset, [] + ) + + st.selectbox( + "Select a file from this dataset:", + options=[None, *files_to_display_in_selectbox], + key="selected_file_from_dataset", + ) + + st.button("Load Dataset", key="load_existing_dataset_button") + if st.session_state.load_existing_dataset_button: + gcs_uri = f"gs://{os.getenv('BUCKET')}/datasets/{st.session_state.selected_dataset}/{st.session_state.selected_file_from_dataset}" + logger.info("Loading file: %s", gcs_uri) + if st.session_state.selected_file_from_dataset.endswith(".jsonl"): + st.session_state.dataset = pd.read_json(gcs_uri, lines=True) + else: + st.session_state.dataset = pd.read_csv(gcs_uri) + + if st.session_state.dataset is not None: + st.dataframe(st.session_state.dataset) + + +def get_optimization_args( + input_optimization_data_file_uri, output_optimization_run_uri +): + """Gets the arguments for the optimization job.""" + response_schema_str = st.session_state.local_prompt.prompt_meta.get( + "response_schema", "{}" + ) + try: + response_schema = ( + json.loads(response_schema_str) + if isinstance(response_schema_str, str) + else response_schema_str + ) + except json.JSONDecodeError: + response_schema = {} + + if response_schema and response_schema != {}: + response_mime_type = "application/json" + response_schema_arg = response_schema + else: + response_mime_type = "text/plain" + response_schema_arg = "" + + has_multimodal = False + if ( + st.session_state.dataset is not None + and "image" in st.session_state.dataset.columns + ): + has_multimodal = True + + return Namespace( + system_instruction=st.session_state.local_prompt.prompt_to_run.system_instruction, + prompt_template=( + f"{st.session_state.local_prompt.prompt_to_run.prompt_data}" + "\n\tAnswer: {target}" + ), + target_model="gemini-2.0-flash-001", + optimization_mode="instruction", + eval_metrics_types=[ + "question_answering_correctness", + ], + eval_metrics_weights=[ + 1.0, + ], + aggregation_type="weighted_sum", + input_data_path=input_optimization_data_file_uri, + output_path=f"gs://{output_optimization_run_uri}", + project=os.getenv("PROJECT_ID"), + num_steps=10, + num_demo_set_candidates=10, + demo_set_size=3, + target_model_location="us-central1", + source_model="", + source_model_location="", + target_model_qps=1, + optimizer_model_qps=1, + eval_qps=1, + source_model_qps="", + response_mime_type=response_mime_type, + 
response_schema=response_schema_arg, + language="English", + placeholder_to_content=json.loads("{}"), + data_limit=10, + translation_source_field_name="", + has_multimodal_inputs=has_multimodal, + ) + + +def start_optimization() -> None: + """Starts the optimization job.""" + st.divider() + + st.subheader("Run Optimization") + st.button("Start Optimization", key="start_optimization_button") + + if st.session_state.start_optimization_button: + workspace_uri = ( + epath.Path(os.getenv("BUCKET")) / "optimization" / st.session_state.op_id + ) + logger.info("Workspace URI: %s", workspace_uri) + + input_data_uri = epath.Path(workspace_uri) / "data" + logger.info("Input Data URI: %s", input_data_uri) + + workspace_uri.mkdir(parents=True, exist_ok=True) + input_data_uri.mkdir(parents=True, exist_ok=True) + + output_optimization_data_uri = epath.Path(workspace_uri) / "optimization_jobs" + logger.info("Output Data URI: %s", output_optimization_data_uri) + + prompt_optimization_job = ( + f"{st.session_state.selected_prompt}-" + f"{st.session_state.selected_version}-" + f"{st.session_state.selected_dataset}-" + f"{st.session_state.op_id}" + ) + output_optimization_run_uri = str( + output_optimization_data_uri / prompt_optimization_job + ) + input_optimization_data_file_uri = ( + f"gs://{input_data_uri}/{prompt_optimization_job}.jsonl" + ) + logger.info("Input Optimization Data URI: %s", input_optimization_data_file_uri) + if st.session_state.dataset is not None: + st.session_state.dataset.to_json( + str(input_optimization_data_file_uri), orient="records", lines=True + ) + else: + st.error("Please load a dataset first.") + return + + args = get_optimization_args( + input_optimization_data_file_uri, output_optimization_run_uri + ) + + with st.expander("Prompt Otimization Config"): + st.json(vars(args)) + + args = vars(args) + + config_file_uri = "gs://" + str(workspace_uri / "config" / "config.json") + + with epath.Path(config_file_uri).open("w") as config_file: + json.dump(args, config_file) + config_file.close() + st.success(f"Successfully wrote config file to {config_file_uri}") + + worker_pool_specs = [ + { + "machine_spec": { + "machine_type": "n1-standard-4", + }, + "replica_count": 1, + "container_spec": { + "image_uri": os.getenv("APD_CONTAINER_URI"), + "args": ["--config=" + config_file_uri], + }, + } + ] + + custom_job = aiplatform.CustomJob( + display_name=prompt_optimization_job, + worker_pool_specs=worker_pool_specs, + staging_bucket=str(workspace_uri), + ) + + custom_job.run(service_account=os.getenv("APD_SERVICE_ACCOUNT"), sync=False) + + st.success("Successfully Started Job!!") + + +def main() -> None: + """Streamlit page for Prompt Optimization.""" + st.set_page_config( + layout="wide", page_title="Prompt Optimization", page_icon="assets/favicon.ico" + ) + + initialize_session_state() + + st.header("Prompt Optimization") + + st.selectbox( + "Select Target Model for Optimization:", + options=TARGET_MODELS, + key="target_model_optimization", + ) + + prompt_selection() + dataset_selection() + start_optimization() + + +if __name__ == "__main__": + main() diff --git a/tools/llmevalkit/pages/5_Prompt_Optimization_Results.py b/tools/llmevalkit/pages/5_Prompt_Optimization_Results.py new file mode 100644 index 00000000000..d803bb9ce26 --- /dev/null +++ b/tools/llmevalkit/pages/5_Prompt_Optimization_Results.py @@ -0,0 +1,478 @@ +## Copyright 2025 Google LLC +## Licensed under the Apache License, Version 2.0 (the "License"); +## you may not use this file except in compliance with the License. 
+## You may obtain a copy of the License at +## https://www.apache.org/licenses/LICENSE-2.0 +## Unless required by applicable law or agreed to in writing, software +## distributed under the License is distributed on an "AS IS" BASIS, +## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +## See the License for the specific language + +import json +import logging +import os + +import pandas as pd +import streamlit as st +from dotenv import load_dotenv +from google.cloud import storage +from src import vapo_lib + +# Load environment variables +load_dotenv("src/.env") + +# Configure logging to the console +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +# --- Constants --- +BASE_OPTIMIZATION_PREFIX = "optimization/" +OPTIMIZATION_JOBS_SUBDIR = "optimization_jobs/" + +from google.cloud import aiplatform + + +def list_custom_training_jobs(project_id: str, location: str): + """Lists all custom training jobs and their statuses in a given project and location. + + Args: + project_id: The Google Cloud project ID. + location: The region for the Vertex AI jobs, e.g., "us-central1". + + Returns: + A list of dictionaries, where each dictionary contains details of a custom job. + """ + # Initialize the Vertex AI client + # The API endpoint is determined by the location + client_options = {"api_endpoint": f"{location}-aiplatform.googleapis.com"} + client = aiplatform.gapic.JobServiceClient(client_options=client_options) + + # The parent resource path format + parent = f"projects/{project_id}/locations/{location}" + + # Make the API request to list custom jobs + response = client.list_custom_jobs(parent=parent) + + # Process the response and format the output + jobs_list = [] + print(f"Fetching jobs from project '{project_id}' in '{location}'...") + for job in response: + job_info = { + "display_name": job.display_name, + "name": job.name, + "status": job.state.name, # .name gets the string representation of the enum + } + jobs_list.append(job_info) + + print(f"Found {len(jobs_list)} jobs.") + return jobs_list + + +# --- Example Usage --- +if __name__ == "__main__": + # Replace with your project ID and desired location + PROJECT_ID = os.getenv("PROJECT_ID") + LOCATION = os.getenv("LOCATION") + + # Ensure you have authenticated with Google Cloud CLI: + # gcloud auth application-default login + + # And have the necessary permissions (e.g., "Vertex AI User" role) + + try: + all_jobs = list_custom_training_jobs(project_id=PROJECT_ID, location=LOCATION) + + # Print the results + if all_jobs: + print("\n--- Job Statuses ---") + for job in all_jobs: + print(f" - Name: {job['display_name']:<40} Status: {job['status']}") + print("--------------------\n") + else: + print("No custom jobs found.") + + except Exception as e: + print( + "\nAn error occurred. Please ensure your project ID and location are correct," + ) + print(f"and that you have authenticated correctly. Error: {e}") + + +def safe_json_loads(s): + """Safely loads a JSON string, returning the original value on failure.""" + if not isinstance(s, str): + return s + try: + return json.loads(s) + except (json.JSONDecodeError, TypeError): + return s + + +@st.cache_data(ttl=300) +def list_gcs_directories( + bucket_name: str, prefix: str, _storage_client: storage.Client +) -> list[str]: + """Lists 'directories' in GCS under a given prefix. + A 'directory' is inferred from the common prefixes of objects. 
+ Caches the result for 5 minutes to improve performance. + """ + if not bucket_name: + st.warning("BUCKET environment variable is not set.") + return [] + if not _storage_client: + st.warning("Storage client is not initialized.") + return [] + + bucket = _storage_client.bucket(bucket_name) + retrieved_prefixes = set() + try: + for page in bucket.list_blobs(prefix=prefix, delimiter="/").pages: + retrieved_prefixes.update(page.prefixes) + + # The retrieved prefixes are the "subdirectories". + # e.g., for prefix 'optimization/', a retrieved prefix might be 'optimization/op_id/'. + # We want to extract just 'op_id'. + dir_names = [] + for p in retrieved_prefixes: + name = p.replace(prefix, "").strip("/") + if name: + dir_names.append(name) + return sorted(set(dir_names)) + except Exception as e: + st.error( + f"Error listing GCS directories under gs://{bucket_name}/{prefix}: {e}" + ) + logger.error( + f"Error listing GCS directories under gs://{bucket_name}/{prefix}: {e}", + exc_info=True, + ) + return [] + + +def _display_interactive_results(results_ui: vapo_lib.ResultsUI) -> None: + """Processes results from a VAPO run and displays them in an interactive + Streamlit UI with tabs for each prompt version. + """ + try: + if ( + not hasattr(results_ui, "templates") + or not results_ui.templates + or not hasattr(results_ui, "eval_results") + ): + logger.info( + "ResultsUI object does not have 'templates' or 'eval_results', or templates list is empty. Falling back." + ) + else: + processed_results_for_tabs = [] + for i, template_summary_df in enumerate(results_ui.templates): + if ( + not isinstance(template_summary_df, pd.DataFrame) + or template_summary_df.empty + ): + logger.warning( + f"Template summary data at index {i} is not a non-empty DataFrame. Skipping." + ) + continue + + # Get the detailed results to perform the custom calculation + detailed_eval_df = pd.DataFrame() + if i < len(results_ui.eval_results) and isinstance( + results_ui.eval_results[i], pd.DataFrame + ): + detailed_eval_df = results_ui.eval_results[i] + + # Add a custom exact_match calculation. This is more robust than simple + # string comparison as it handles differences in JSON key order and whitespace. + if ( + not detailed_eval_df.empty + and "ground_truth" in detailed_eval_df.columns + and "reference" in detailed_eval_df.columns + ): + # Parse the JSON strings into Python objects before comparing. + parsed_ground_truths = detailed_eval_df["ground_truth"].apply( + safe_json_loads + ) + parsed_references = detailed_eval_df["reference"].apply( + safe_json_loads + ) + + # Create a boolean series for the comparison + is_match = parsed_ground_truths.eq(parsed_references) + + # Map boolean to 'yes'/'no' for display in the detailed table + detailed_eval_df["calculated_exact_match"] = is_match.map( + {True: "yes", False: "no"} + ) + + # Calculate the mean from the boolean series for the summary metric + new_exact_match_mean = is_match.mean() + template_summary_df["metrics.calculated_exact_match/mean"] = ( + new_exact_match_mean + ) + + prompt_text = "Prompt text not found in template data." + if "prompt" in template_summary_df.columns: + prompt_text = template_summary_df["prompt"].iloc[0] + else: + logger.warning( + f"Column 'prompt' not found in template_summary_df at index {i}." + ) + + # Determine the primary score and build the tab name. 
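+ # "metrics.calculated_exact_match/mean" only exists when the detailed
+ # eval data had ground_truth/reference columns; otherwise fall back to
+ # the first "metrics.*/mean" column from the templates summary.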
+ primary_score_label = "Score" + primary_score_value = "N/A" + if "metrics.calculated_exact_match/mean" in template_summary_df.columns: + primary_score_label = "Calculated Exact Match" + primary_score_value = template_summary_df[ + "metrics.calculated_exact_match/mean" + ].iloc[0] + else: + # Fallback to the first available metric + mean_metric_columns = [ + col + for col in template_summary_df.columns + if col.startswith("metrics.") and "/mean" in col + ] + if mean_metric_columns: + first_metric_col = mean_metric_columns[0] + primary_score_label = ( + first_metric_col.replace("metrics.", "") + .replace("/mean", "") + .replace("_", " ") + .title() + ) + primary_score_value = template_summary_df[ + first_metric_col + ].iloc[0] + + # Build the tab name with all available metrics for a quick overview. + tab_name_metrics_parts = [] + mean_metric_columns = [ + col + for col in template_summary_df.columns + if col.startswith("metrics.") and "/mean" in col + ] + for metric_col in mean_metric_columns: + metric_name_short = metric_col.replace("metrics.", "").replace( + "/mean", "" + ) + metric_val = template_summary_df[metric_col].iloc[0] + if metric_name_short == "calculated_exact_match" and isinstance( + metric_val, float + ): + tab_name_metrics_parts.append( + f"{metric_name_short}: {metric_val:.1%}" + ) + else: + tab_name_metrics_parts.append( + f"{metric_name_short}: {metric_val:.3f}" + if isinstance(metric_val, float) + else f"{metric_name_short}: {metric_val}" + ) + + tab_name = f"Template {i}" + if tab_name_metrics_parts: + tab_name += f" ({', '.join(tab_name_metrics_parts)})" + + current_summary_df_display = template_summary_df.copy() + if "prompt" in current_summary_df_display.columns: + current_summary_df_display = current_summary_df_display.drop( + columns=["prompt"] + ) + + processed_results_for_tabs.append( + { + "name": tab_name, + "template_text": prompt_text, + "primary_score_label": primary_score_label, + "primary_score_value": primary_score_value, + "summary_metrics_df": current_summary_df_display, + "detailed_eval_df": detailed_eval_df, + } + ) + + if ( + processed_results_for_tabs + ): # If we successfully processed data, show the new UI + st.write("### Interactive Prompt Versions") + tab_titles = [res["name"] for res in processed_results_for_tabs] + tabs = st.tabs(tab_titles) + + for i, tab_content in enumerate(tabs): + with tab_content: + result_data = processed_results_for_tabs[i] + + st.subheader("Prompt Template") + # Sanitize tab name for key + clean_key_name = "".join( + filter(str.isalnum, result_data["name"]) + ) + st.text_area( + "Template", + value=result_data["template_text"], + height=200, + disabled=True, + key=f"template_view_{clean_key_name}_{i}", + ) + + st.subheader("Primary Score") + score_val = result_data["primary_score_value"] + score_label = result_data["primary_score_label"] + if score_label == "Calculated Exact Match" and isinstance( + score_val, float + ): + st.metric(label=score_label, value=f"{score_val:.2%}") + else: + st.metric( + label=score_label, + value=f"{score_val:.4f}" + if isinstance(score_val, float) + else str(score_val), + ) + + if not result_data["summary_metrics_df"].empty: + st.subheader("Summary Metrics (from templates.json)") + st.dataframe(result_data["summary_metrics_df"]) + + if not result_data["detailed_eval_df"].empty: + st.subheader( + "Detailed Evaluation Results (from eval_results.json)" + ) + st.dataframe(result_data["detailed_eval_df"]) + else: + st.caption( + "No detailed evaluation results available for this template." 
+ ) + else: + st.warning("No valid results could be processed for display.") + + except Exception as e: + st.error(f"An error occurred while trying to display results: {e}") + logger.error(f"Error in results display section: {e}", exc_info=True) + st.markdown( + "For now, you can access the results directly at the GCS path shown above." + ) + + +def main() -> None: + """Renders the Streamlit page for viewing Prompt Optimization Results.""" + st.set_page_config( + layout="wide", + page_title="Prompt Optimization Results", + page_icon="assets/favicon.ico", + ) + st.header("Prompt Optimization Results Browser") + + if "storage_client" not in st.session_state: + try: + st.session_state["storage_client"] = storage.Client() + logger.info("Storage client initialized.") + except Exception as e: + st.error(f"Failed to initialize Google Cloud Storage client: {e}") + logger.error( + f"Failed to initialize Google Cloud Storage client: {e}", exc_info=True + ) + st.session_state["storage_client"] = None + return + + bucket_name = os.getenv("BUCKET") + if not bucket_name: + st.error("BUCKET environment variable is not set. Please configure it in .env.") + return + + # --- Step 1: Select Operation ID --- + op_ids = list_gcs_directories( + bucket_name, BASE_OPTIMIZATION_PREFIX, st.session_state.storage_client + ) + if not op_ids: + st.info( + f"No optimization operation IDs found under gs://{bucket_name}/{BASE_OPTIMIZATION_PREFIX}" + ) + return + + if "op_id" in st.session_state and st.session_state.op_id: + st.caption( + f"Hint: The last optimization run you initiated had the ID: `{st.session_state.op_id}`." + ) + + selected_op_id = st.selectbox( + "Select an Operation ID:", options=[None, *op_ids], key="selected_op_id_results" + ) + if not selected_op_id: + st.write("Please select an Operation ID to see its optimization job runs.") + return + + st.divider() + + # --- Step 2: Select Experiment Run --- + st.subheader(f"Optimization Job Runs for Operation ID: {selected_op_id}") + optimization_jobs_prefix = ( + f"{BASE_OPTIMIZATION_PREFIX}{selected_op_id}/{OPTIMIZATION_JOBS_SUBDIR}" + ) + experiment_runs = list_gcs_directories( + bucket_name, optimization_jobs_prefix, st.session_state.storage_client + ) + + if not experiment_runs: + st.info( + f"No completed optimization job runs found under gs://{bucket_name}/{optimization_jobs_prefix}" + ) + return + + selected_run = st.selectbox( + "Select an Optimization Job Run:", + options=[None, *experiment_runs], + key="selected_experiment_run", + ) + if not selected_run: + st.write("Please select an optimization job run to view its results.") + return + + st.divider() + + # --- Step 3: Check Job Status and Display Results --- + st.subheader(f"Results for: {selected_run}") + + project_id = os.getenv("PROJECT_ID") + location = os.getenv("LOCATION") + + if not project_id or not location: + st.error("PROJECT_ID or REGION environment variables are not set.") + return + + try: + jobs = list_custom_training_jobs(project_id=project_id, location=location) + job_status = "Not Found" + for job in jobs: + if job["display_name"] == selected_run: + job_status = job["status"] + break + + st.info(f"Status for job '{selected_run}': **{job_status}**") + + if job_status == "JOB_STATE_FAILED": + st.error( + "This optimization job has failed. Please check the logs in the Vertex AI console for more details." + ) + return + if job_status not in ["JOB_STATE_SUCCEEDED", "JOB_STATE_CANCELLED"]: + st.warning( + f"Job is currently in status: {job_status}. Results may be incomplete." 
+ ) + + except Exception as e: + st.error(f"Could not retrieve job status. Error: {e}") + logger.error( + f"Failed to retrieve job status for {selected_run}: {e}", exc_info=True + ) + + run_uri = f"gs://{bucket_name}/{optimization_jobs_prefix}{selected_run}" + st.info(f"Loading results from: {run_uri}") + results_ui = vapo_lib.ResultsUI(run_uri) + _display_interactive_results(results_ui) + + +if __name__ == "__main__": + main() diff --git a/tools/llmevalkit/pages/6_Prompt_Records.py b/tools/llmevalkit/pages/6_Prompt_Records.py new file mode 100644 index 00000000000..d1bc6e2734d --- /dev/null +++ b/tools/llmevalkit/pages/6_Prompt_Records.py @@ -0,0 +1,132 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os + +import pandas as pd +import streamlit as st +from dotenv import load_dotenv +from google.cloud import storage + +load_dotenv("src/.env") + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def load_records_from_gcs(bucket_name: str, prefix: str) -> pd.DataFrame: + """Loads all JSON record files from a GCS prefix and returns a DataFrame.""" + try: + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) + blobs = bucket.list_blobs(prefix=prefix) + + all_records = [] + for blob in blobs: + if blob.name.endswith(".json"): + logger.info(f"Loading record from {blob.name}") + try: + record_data = json.loads(blob.download_as_string()) + if isinstance(record_data, list): + all_records.extend(record_data) + else: + all_records.append(record_data) + except json.JSONDecodeError: + logger.warning(f"Could not decode JSON from {blob.name}") + except Exception as e: + logger.exception(f"Failed to process blob {blob.name}: {e}") + + if not all_records: + st.warning(f"No JSON records found at gs://{bucket_name}/{prefix}") + return pd.DataFrame() + + return pd.json_normalize(all_records) + + except Exception as e: + st.error(f"Failed to load or parse records from GCS: {e}") + logger.error("Error loading records: %s", e, exc_info=True) + return pd.DataFrame() + + +def main() -> None: + """Renders the Prompt Records Leaderboard page.""" + st.set_page_config( + layout="wide", + page_title="Prompt Records Leaderboard", + page_icon="assets/favicon.ico", + ) + st.header("Prompt Records Leaderboard") + st.markdown( + "This page allows you to view and compare the evaluation results of different prompt versions." 
+ ) + + records_prefix = "records/" + + if "leaderboard_df" not in st.session_state: + st.session_state.leaderboard_df = pd.DataFrame() + + if st.button("Load/Refresh Leaderboard"): + with st.spinner("Loading records from GCS..."): + st.session_state.leaderboard_df = load_records_from_gcs( + os.getenv("BUCKET"), records_prefix + ) + if not st.session_state.leaderboard_df.empty: + st.success("Leaderboard loaded successfully.") + else: + st.info("Leaderboard is empty or could not be loaded.") + + if st.session_state.leaderboard_df.empty: + st.info("Click the button above to load the leaderboard data.") + return + + st.divider() + + prompt_names = st.session_state.leaderboard_df["prompt_name"].unique().tolist() + selected_prompt = st.selectbox( + "Select a Prompt to Compare Versions", options=[None, *prompt_names] + ) + + if selected_prompt: + st.subheader(f"Comparison for: {selected_prompt}") + + prompt_df = st.session_state.leaderboard_df[ + st.session_state.leaderboard_df["prompt_name"] == selected_prompt + ].copy() + + if prompt_df.empty: + st.info("No records found for the selected prompt.") + return + + # Extract scores into separate columns for easier analysis + scores_df = pd.json_normalize(prompt_df["scores"]) + scores_df.columns = [f"score.{col}" for col in scores_df.columns] + + comparison_df = pd.concat([prompt_df.reset_index(drop=True), scores_df], axis=1) + + # Clean up the view + display_columns = [ + col + for col in comparison_df.columns + if col not in ["scores", "evaluation_data"] + and not (col.startswith("score.") and col[6:].isdigit()) + ] + st.dataframe(comparison_df[display_columns]) + + +if __name__ == "__main__": + main() diff --git a/tools/llmevalkit/prompt-management-tutorial.ipynb b/tools/llmevalkit/prompt-management-tutorial.ipynb new file mode 100644 index 00000000000..60d9d54d068 --- /dev/null +++ b/tools/llmevalkit/prompt-management-tutorial.ipynb @@ -0,0 +1,580 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "81450b47de75" + }, + "outputs": [], + "source": [ + "# Copyright 2025 Google LLC\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a204d0ab284d" + }, + "source": [ + "# Tutorial for Running Prompt Management and Evaluation\n", + "\n", + "\n",
+ " \n",
+ " Open in Colab\n", + " \n", + " | \n",
+ " \n",
+ " \n",
+ " Open in Colab Enterprise\n", + " \n", + " | \n",
+ " \n",
+ " \n",
+ " Open in Vertex AI Workbench\n", + " \n", + " | \n",
+ " \n",
+ " \n",
+ " View on GitHub\n", + " \n", + " | \n",
+ "