
from fenic._inference.common_openai.openai_utils import convert_messages
from fenic._inference.common_openai.utils import handle_openai_compatible_response
+from fenic._inference.google.gemini_token_counter import GeminiLocalTokenCounter
from fenic._inference.model_client import (
    FatalException,
    ModelClient,
@@ -87,17 +88,25 @@ def __init__(
        self._aio_client = OpenRouterModelProvider().aio_client
        self._metrics = LMMetrics()

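+        # OpenRouter model ids use the form "<provider>/<model>"; Google models get a Gemini-specific local token counter for estimates.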
+        self._google_token_counter = None
+        provider_and_model = model.split("/")
+        if provider_and_model[0] == "google":
+            self._google_token_counter = GeminiLocalTokenCounter(model_name=provider_and_model[1])
+
    async def make_single_request(
        self, request: FenicCompletionsRequest
    ) -> Union[None, FenicCompletionsResponse, TransientException, FatalException]:
        profile = self._profile_manager.get_profile_by_name(request.model_profile)
        common_params = {
            "model": self.model,
            "messages": convert_messages(request.messages),
-            "max_completion_tokens": self._get_max_output_token_request_limit(request),
            "n": 1,
        }

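+        # Only send a completion-token cap when the request specifies one; otherwise omit the parameter.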
+        max_completion_tokens = self._get_max_output_token_request_limit(request)
+        if max_completion_tokens is not None:
+            common_params["max_completion_tokens"] = max_completion_tokens
+
        if request.top_logprobs:
            common_params.update(
                {"logprobs": True, "top_logprobs": request.top_logprobs}
@@ -238,8 +247,8 @@ def estimate_tokens_for_request(
        self, request: FenicCompletionsRequest
    ) -> TokenEstimate:
        return TokenEstimate(
-            input_tokens=self.token_counter.count_tokens(request.messages),
-            output_tokens=self.token_counter.count_tokens(request.messages) + self._get_expected_additional_reasoning_tokens(request),
+            input_tokens=self._estimate_input_tokens(request),
+            output_tokens=self._estimate_output_tokens(request),
        )

    def reset_metrics(self):
@@ -248,16 +257,48 @@ def reset_metrics(self):
    def get_metrics(self) -> LMMetrics:
        return self._metrics

-    def _get_max_output_token_request_limit(self, request: FenicCompletionsRequest) -> int:
-        """Get the upper limit of output tokens for a request.
+    def _estimate_output_tokens(self, request: FenicCompletionsRequest) -> int:
+        """Estimate the number of output tokens for a request."""
+        base_tokens = request.max_completion_tokens or 0
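+        # No explicit completion cap: if the request includes a file, estimate output from the file contents.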
+        if request.max_completion_tokens is None and request.messages.user_file:
+            # TODO(DY): the semantic operator should dictate how the file affects the token estimate
+            if self._google_token_counter:
+                base_tokens += self._google_token_counter.count_file_output_tokens(messages=request.messages)
+            else:
+                base_tokens += self.token_counter.count_file_output_tokens(messages=request.messages)
+        return base_tokens + self._get_expected_additional_reasoning_tokens(request)

-        If max_completion_tokens is not set, don't apply a limit and return None.
+    def _get_max_output_token_request_limit(self, request: FenicCompletionsRequest) -> Optional[int]:
+        """Get the upper limit of output tokens for a request.

-        Include the thinking token budget with a safety margin."""
+        Returns None if max_completion_tokens is not provided (no limit should be set).
+        If max_completion_tokens is provided, includes the thinking token budget with a safety margin."""
        if request.max_completion_tokens is None:
            return None
        return request.max_completion_tokens + self._get_expected_additional_reasoning_tokens(request)

+    def _estimate_input_tokens(self, request: FenicCompletionsRequest) -> int:
+        """Estimate the number of input tokens for a request."""
+        if self._google_token_counter:
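+            # Skip the attached file here; its input tokens are added separately below.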
+            input_tokens = self._google_token_counter.count_tokens(request.messages, ignore_file=True)
+        else:
+            input_tokens = self.token_counter.count_tokens(request.messages)
+        if request.messages.user_file:
+            input_tokens += self._estimate_file_input_tokens(request)
+        return input_tokens
+
+    def _estimate_file_input_tokens(self, request: FenicCompletionsRequest) -> int:
+        """Estimate the number of input tokens from a file in a request."""
+        profile_config = self._profile_manager.get_profile_by_name(request.model_profile)
+        if profile_config.parsing_engine and profile_config.parsing_engine == "native":
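+            # Native parsing: the file is sent to the model as-is, so count its input tokens directly.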
+            if self._google_token_counter:
+                return self._google_token_counter.count_file_input_tokens(messages=request.messages)
+            else:
+                return self.token_counter.count_file_input_tokens(messages=request.messages)
+        # OpenRouter's engine tool processes the file first and passes annotated text to the model.
+        # We can estimate by extracting the text and tokenizing it (which is what count_file_output_tokens does)
+        return self.token_counter.count_file_output_tokens(messages=request.messages)
+
    # This is a slightly less conservative estimate than the OpenRouter documentation on how reasoning_effort is used to
    # generate a reasoning.max_tokens for models that only support reasoning.max_tokens.
    # These percentages are slightly lower, since our use-cases generally require fewer reasoning tokens.