@@ -23,6 +23,8 @@
 
 import argparse
 import logging
+import asyncio
+import time
 
 import pandas as pd
 
@@ -37,15 +39,15 @@
 )
 
 
-def evaluate_response(model: str, instructions: str, input: str) -> pd.DataFrame:
+async def evaluate_response(model: str, instructions: str, input: str) -> pd.DataFrame:
     """
     Test a prompt with a set of test data by scoring each item in the data set
     """
 
     try:
         handler = ModelFactory.get_handler(model)
 
-        generated_text, token_usage, pricing, duration = handler.handle_request(
+        generated_text, token_usage, pricing, duration = await handler.handle_request(
            instructions, input
         )
 
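
Note: the await here assumes each handler's handle_request is now a coroutine. A minimal sketch of the contract the call site expects, with an entirely hypothetical handler class and assumed shapes for the returned values:

    import time

    class ExampleHandler:
        """Hypothetical handler, for illustration only."""

        async def handle_request(self, instructions: str, input: str):
            start = time.monotonic()
            # ... await the provider's async API call here ...
            generated_text = "..."                   # model output (placeholder)
            token_usage = {"input": 0, "output": 0}  # assumed shape
            pricing = {"input": 0.0, "output": 0.0}  # assumed shape
            duration = time.monotonic() - start
            return generated_text, token_usage, pricing, duration
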
@@ -116,19 +118,22 @@ def calculate_cost_metrics(token_usage: dict, pricing: dict) -> dict:
     }
 
 
-def load_csv(file_path: str, required_columns: list) -> pd.DataFrame:
+def load_csv(file_path: str, required_columns: list, nrows: int | None = None) -> pd.DataFrame:
     """
     Load a CSV file and validate that it contains the required columns
 
     Args:
         file_path (str): Path to the CSV file
         required_columns (list): List of required column names
-
+        nrows (int, optional): Number of rows to read from the CSV file
     Returns:
         pd.DataFrame
     """
 
-    df = pd.read_csv(file_path)
+    if nrows is not None:
+        logging.info(f"Test mode enabled: reading first {nrows} rows of {file_path}")
+
+    df = pd.read_csv(file_path, nrows=nrows)
 
     # Remove leading/trailing whitespace from column names
     df.columns = df.columns.str.strip()
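
The nrows pass-through maps directly onto pandas' read_csv(nrows=...) parameter, which stops parsing after the first n data rows. An illustrative call, with an assumed file name:

    # Read only the first 10 rows of a (hypothetical) dataset.csv
    df = load_csv("dataset.csv", required_columns=["INPUT"], nrows=10)
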
@@ -145,9 +150,7 @@ def load_csv(file_path: str, required_columns: list) -> pd.DataFrame:
     return df
 
 
-if __name__ == "__main__":
-    # TODO: Add test evaluation argument to run on the first 10 rows of the dataset file
-
+async def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--experiments", "-e", required=True, help="Path to experiments CSV file"
@@ -158,33 +161,35 @@ def load_csv(file_path: str, required_columns: list) -> pd.DataFrame:
     parser.add_argument(
         "--results", "-r", required=True, help="Path to results CSV file"
     )
+    parser.add_argument(
+        "--test", "-t", type=int, help="Run evaluation on the first n rows of the dataset only"
+    )
 
     args = parser.parse_args()
 
+    # Load the experiment DataFrame
     df_experiment = load_csv(
         args.experiments, required_columns=["MODEL", "INSTRUCTIONS"]
     )
-    # Check if all models are supported by ModelFactory
-    if not all(
-        model in ModelFactory.HANDLERS.keys()
-        for model in df_experiment["MODEL"].unique()
-    ):
-        raise ValueError(
-            f"Unsupported model(s) found: {set(df_experiment['MODEL'].unique()) - set(ModelFactory.HANDLERS.keys())}"
-        )
-    df_dataset = load_csv(args.dataset, required_columns=["INPUT"])
+
+    # Load the dataset DataFrame
+    df_dataset = load_csv(args.dataset, required_columns=["INPUT"], nrows=args.test)
 
     # Bulk model and prompt experimentation: Cross join the experiment and dataset DataFrames
     df_in = df_experiment.merge(df_dataset, how="cross")
 
-    # Evaluate each row in the input DataFrame
-    results = []
-    for index, row in enumerate(df_in.itertuples(index=False)):
-        result = evaluate_response(row.MODEL, row.INSTRUCTIONS, row.INPUT)
-        results.append(result)
+    # Evaluate each row in the input DataFrame concurrently
+    logging.info(f"Starting evaluation of {len(df_in)} rows")
+    start_time = time.time()
+    tasks = [
+        evaluate_response(row.MODEL, row.INSTRUCTIONS, row.INPUT)
+        for row in df_in.itertuples(index=False)
+    ]
 
-    # TODO: Use tqdm or similar library to show progress bar
-    logging.info(f"Processed row {index + 1}/{len(df_in)}")
+    results = await asyncio.gather(*tasks)
+    end_time = time.time()
+    duration = end_time - start_time
+    logging.info(f"Completed evaluation of {len(results)} rows in {duration:.2f} seconds")
 
     df_evals = pd.concat(results, axis=0, ignore_index=True)
 
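
One caveat worth noting: asyncio.gather schedules every task at once, so a large experiment-dataset cross join fires len(df_in) concurrent API calls. If provider rate limits become a problem, bounding concurrency with asyncio.Semaphore is one option; a sketch, not part of this commit (the cap of 10 is arbitrary):

    semaphore = asyncio.Semaphore(10)  # arbitrary cap on in-flight requests

    async def evaluate_with_limit(model, instructions, input):
        # Wait for a free slot before dispatching the request
        async with semaphore:
            return await evaluate_response(model, instructions, input)

    tasks = [
        evaluate_with_limit(row.MODEL, row.INSTRUCTIONS, row.INPUT)
        for row in df_in.itertuples(index=False)
    ]
    results = await asyncio.gather(*tasks)
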
@@ -195,3 +200,7 @@ def load_csv(file_path: str, required_columns: list) -> pd.DataFrame:
     df_out.to_csv(args.results, index=False)
     logging.info(f"Results saved to {args.results}")
     logging.info("Evaluation completed successfully.")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
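
With the new flag, a smoke-test run over the first 10 dataset rows might look like this (script and CSV file names assumed):

    python evaluate.py --experiments experiments.csv --dataset dataset.csv --results results.csv --test 10
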