 
 from fenic._inference.common_openai.openai_utils import convert_messages
 from fenic._inference.common_openai.utils import handle_openai_compatible_response
+from fenic._inference.google.gemini_token_counter import GeminiLocalTokenCounter
 from fenic._inference.model_client import (
     FatalException,
     ModelClient,
@@ -65,16 +66,23 @@ def __init__(
         profiles: Optional[dict[str, object]] = None,
         default_profile_name: Optional[str] = None,
     ):
+        # Choose token counter based on the model's provider
+        token_counter = None
+        provider_and_model = model.split("/")
+        if provider_and_model[0] == "google":
+            token_counter = GeminiLocalTokenCounter(model_name=provider_and_model[1])
+        else:
+            token_counter = TiktokenTokenCounter(
+                model_name=provider_and_model[1], fallback_encoding="o200k_base"
+            )
         super().__init__(
             model=model,
             model_provider=ModelProvider.OPENROUTER,
             model_provider_class=OpenRouterModelProvider(),
             rate_limit_strategy=rate_limit_strategy,
             queue_size=queue_size,
             max_backoffs=max_backoffs,
-            token_counter=TiktokenTokenCounter(
-                model_name=model, fallback_encoding="o200k_base"
-            ),
+            token_counter=token_counter,
         )
         self._model_parameters = model_catalog.get_completion_model_parameters(
             ModelProvider.OPENROUTER, model
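For quick reference, the provider check introduced in this hunk can be read in isolation as the sketch below. The `pick_token_counter` helper and `_Counter` dataclass are illustrative stand-ins for this write-up only, not fenic APIs; the real constructor builds a GeminiLocalTokenCounter or TiktokenTokenCounter instead. (The sketch uses `str.partition`, so a model id containing extra slashes keeps its remainder intact, whereas the diff itself indexes the result of `split("/")`.)

from dataclasses import dataclass

@dataclass
class _Counter:
    kind: str
    model_name: str

def pick_token_counter(model: str) -> _Counter:
    # OpenRouter model ids look like "<provider>/<model>"; route Google models
    # to a Gemini-aware counter and everything else to a tiktoken fallback.
    provider, _, model_name = model.partition("/")
    if provider == "google":
        return _Counter("gemini-local", model_name)
    return _Counter("tiktoken", model_name)

assert pick_token_counter("google/gemini-2.5-flash").kind == "gemini-local"
assert pick_token_counter("openai/gpt-4o-mini").kind == "tiktoken"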
@@ -87,17 +95,22 @@ def __init__(
         self._aio_client = OpenRouterModelProvider().aio_client
         self._metrics = LMMetrics()
 
+
+
     async def make_single_request(
         self, request: FenicCompletionsRequest
     ) -> Union[None, FenicCompletionsResponse, TransientException, FatalException]:
         profile = self._profile_manager.get_profile_by_name(request.model_profile)
         common_params = {
             "model": self.model,
             "messages": convert_messages(request.messages),
-            "max_completion_tokens": self._get_max_output_token_request_limit(request),
             "n": 1,
         }
 
+        max_completion_tokens = self._get_max_output_token_request_limit(request)
+        if max_completion_tokens is not None:
+            common_params["max_completion_tokens"] = max_completion_tokens
+
         if request.top_logprobs:
             common_params.update(
                 {"logprobs": True, "top_logprobs": request.top_logprobs}
@@ -238,8 +251,8 @@ def estimate_tokens_for_request(
         self, request: FenicCompletionsRequest
     ) -> TokenEstimate:
         return TokenEstimate(
-            input_tokens=self.token_counter.count_tokens(request.messages),
-            output_tokens=self.token_counter.count_tokens(request.messages) + self._get_expected_additional_reasoning_tokens(request),
+            input_tokens=self._estimate_input_tokens(request),
+            output_tokens=self._estimate_output_tokens(request),
         )
 
     def reset_metrics(self):
@@ -248,16 +261,39 @@ def reset_metrics(self):
     def get_metrics(self) -> LMMetrics:
         return self._metrics
 
-    def _get_max_output_token_request_limit(self, request: FenicCompletionsRequest) -> int:
-        """Get the upper limit of output tokens for a request.
+    def _estimate_output_tokens(self, request: FenicCompletionsRequest) -> int:
+        """Estimate the number of output tokens for a request."""
+        base_tokens = request.max_completion_tokens or 0
+        if request.max_completion_tokens is None and request.messages.user_file:
+            # TODO(DY): the semantic operator should dictate how the file affects the token estimate
+            base_tokens += self.token_counter.count_file_output_tokens(messages=request.messages)
+        return base_tokens + self._get_expected_additional_reasoning_tokens(request)
 
-        If max_completion_tokens is not set, don't apply a limit and return None.
+    def _get_max_output_token_request_limit(self, request: FenicCompletionsRequest) -> Optional[int]:
+        """Return the maximum output token limit for a request.
 
-        Include the thinking token budget with a safety margin."""
+        Returns None if max_completion_tokens is not provided (no limit should be set).
+        If max_completion_tokens is provided, includes the thinking token budget with a safety margin."""
         if request.max_completion_tokens is None:
             return None
         return request.max_completion_tokens + self._get_expected_additional_reasoning_tokens(request)
 
+    def _estimate_input_tokens(self, request: FenicCompletionsRequest) -> int:
+        """Estimate the number of input tokens for a request."""
+        input_tokens = self.token_counter.count_tokens(request.messages, ignore_file=True)
+        if request.messages.user_file:
+            input_tokens += self._estimate_file_input_tokens(request)
+        return input_tokens
+
+    def _estimate_file_input_tokens(self, request: FenicCompletionsRequest) -> int:
+        """Estimate the number of input tokens from a file in a request."""
+        profile_config = self._profile_manager.get_profile_by_name(request.model_profile)
+        if profile_config.parsing_engine and profile_config.parsing_engine == "native":
+            return self.token_counter.count_file_input_tokens(messages=request.messages)
+        # OpenRouter's engine tool processes the file first and passes annotated text to the model.
+        # We can estimate by extracting the text and tokenizing it (which is what count_file_output_tokens does)
+        return self.token_counter.count_file_output_tokens(messages=request.messages)
+
     # This is a slightly less conservative estimate than the OpenRouter documentation on how reasoning_effort is used to
     # generate a reasoning.max_tokens for models that only support reasoning.max_tokens.
     # These percentages are slightly lower, since our use-cases generally require fewer reasoning tokens.
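Taken together, the new `_estimate_input_tokens` / `_estimate_output_tokens` helpers split the estimate that `estimate_tokens_for_request` previously computed inline (where the old code reused the input-message token count as the output estimate). Below is a self-contained sketch of the combined arithmetic using plain integers in place of the token counter; the names and fields are illustrative for this note, not fenic's:

from dataclasses import dataclass

@dataclass
class TokenEstimateSketch:
    input_tokens: int
    output_tokens: int

def estimate(prompt_tokens: int,
             file_tokens: int,
             has_file: bool,
             max_completion_tokens: int | None,
             reasoning_tokens: int) -> TokenEstimateSketch:
    # Input side: prompt tokens plus the file's contribution when a file is attached.
    input_tokens = prompt_tokens + (file_tokens if has_file else 0)
    # Output side: honour the caller's cap when given; otherwise, if a file is
    # attached, assume the output scales with the file's extracted text.
    if max_completion_tokens is not None:
        output_tokens = max_completion_tokens
    else:
        output_tokens = file_tokens if has_file else 0
    return TokenEstimateSketch(input_tokens, output_tokens + reasoning_tokens)

print(estimate(prompt_tokens=120, file_tokens=900, has_file=True,
               max_completion_tokens=None, reasoning_tokens=256))
# TokenEstimateSketch(input_tokens=1020, output_tokens=1156)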