
Commit e3f58cd

feat: Add openai support for semantic parse_pdf (#253)
### TL;DR

Added PDF file support for OpenAI models, with proper token counting and estimation.

### What changed?

- Made `max_completion_tokens` optional in OpenAI chat completions requests.
- Implemented PDF file token counting for OpenAI models:
  - Added methods to count tokens for PDF files in input and output.
  - Updated the token counter to handle PDF files with proper token estimation.
- Added support for PDF parsing to various OpenAI models in the model catalog.
- Refactored the OpenAI token counting logic.
- Mini refactor: separated the `_max_output_tokens` user-limit concept from `_estimate_output_tokens`, which is used for cost estimation and throttling. Added to OpenAI and updated Gemini; see the sketch after this commit message.
  - Max tokens (set on the request):
    - Use the max tokens provided by the semantic operator, if present.
    - **Otherwise, for page parsing specifically, use an upper limit based on the output limit of the smallest VLM we support (8000 tokens).**
    - Add the expected reasoning effort.
  - Estimated output tokens (for cost estimation and throttling):
    - Use the max tokens provided by the semantic operator, if present.
    - Otherwise, estimate the output tokens from the file.
    - Add the expected reasoning effort.
- Added OpenAI models to the semantic_parse_pdf tests.

### Out of scope

Token estimation should happen at the semantic operator level, since the operator has the context of what it expects from the model. Currently, the semantic operator only passes a max-token limit to the client, and we use that upper limit in our estimates. As a future improvement, we should refactor so the semantic operator decides on the output token limit for the request.

### How to test?

1. Run the new token counter tests: `pytest tests/_inference/test_openai_token_counter.py`
2. Test PDF parsing with OpenAI models: `pytest tests/_backends/local/functions/test_semantic_parse_pdf.py`
3. Verify that token estimation works correctly with PDF files by using a model that supports PDF parsing.
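For readers skimming the diff, here is a minimal, self-contained sketch of the request-limit vs. estimate split described above. The `SketchRequest` type, its field names, and the 8192-token guardrail constant are illustrative assumptions for this sketch, not fenic's actual classes or values.

```python
# Sketch of the two concepts this commit separates: the hard cap placed on the
# provider request vs. the estimate used for cost accounting and throttling.
from dataclasses import dataclass
from typing import Optional

SMALLEST_VLM_OUTPUT_LIMIT = 8192  # assumed guardrail for page-parsing requests


@dataclass
class SketchRequest:
    max_completion_tokens: Optional[int]  # limit passed down by the semantic operator, if any
    has_file: bool                        # True when the request carries a PDF
    expected_reasoning_tokens: int        # profile-specific reasoning budget (with safety margin)


def request_limit(req: SketchRequest) -> Optional[int]:
    """Upper limit to put in the provider request."""
    if req.max_completion_tokens is not None:
        return req.max_completion_tokens + req.expected_reasoning_tokens
    if req.has_file:
        # Page parsing without an explicit limit: fall back to the guardrail.
        return SMALLEST_VLM_OUTPUT_LIMIT + req.expected_reasoning_tokens
    return None  # no cap: let the provider default apply


def estimated_output_tokens(req: SketchRequest, file_output_estimate: int = 0) -> int:
    """Estimate used for cost estimation and throttling."""
    base = req.max_completion_tokens if req.max_completion_tokens is not None else file_output_estimate
    return base + req.expected_reasoning_tokens


pdf_request = SketchRequest(max_completion_tokens=None, has_file=True, expected_reasoning_tokens=1024)
print(request_limit(pdf_request))                  # 9216
print(estimated_output_tokens(pdf_request, 2500))  # 3524
```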
1 parent c99b9ac commit e3f58cd

20 files changed: +256 −104 lines changed

src/fenic/_backends/local/semantic_operators/parse_pdf.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -48,11 +48,13 @@ def __init__(
         page_separator: Optional[str] = None,
         describe_images: bool = False,
         model_alias: Optional[ResolvedModelAlias] = None,
+        max_output_tokens: Optional[int] = None,
     ):
         self.page_separator = page_separator
         self.describe_images = describe_images
         self.model = model
         self.model_alias = model_alias
+        self.max_output_tokens = max_output_tokens

         DocFolderLoader.check_file_extensions(input.to_list(), "pdf")

@@ -62,7 +64,7 @@ def __init__(
             model=model,
             operator_name="semantic.parse_pdf",
             inference_config=InferenceConfiguration(
-                max_output_tokens=None,
+                max_output_tokens=max_output_tokens,
                 temperature=1.0,  # Use a higher temperature so gemini flash models can handle complex table formatting. For more info see the conversation here: https://discuss.ai.google.dev/t/gemini-2-0-flash-has-a-weird-bug/65119/26
                 model_profile=model_alias.profile if model_alias else None,
             ),
```

src/fenic/_backends/local/transpiler/expr_converter.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -712,8 +712,9 @@ def parse_pdf_fn(batch: pl.Series) -> pl.Series:
                 page_separator=logical.page_separator,
                 describe_images=logical.describe_images,
                 model_alias=logical.model_alias,
+                max_output_tokens=logical.max_output_tokens,
             ).execute()
-
+
         return self._convert_expr(logical.expr).map_batches(
             parse_pdf_fn, return_dtype=pl.Utf8
         )
```

src/fenic/_inference/anthropic/anthropic_batch_chat_completions_client.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -275,7 +275,7 @@ def _estimate_structured_output_overhead(self, response_format) -> int:
         """
         return self.estimate_response_format_tokens(response_format)

-    def _get_max_output_tokens(self, request: FenicCompletionsRequest) -> int:
+    def _get_max_output_token_request_limit(self, request: FenicCompletionsRequest) -> int:
         """Get maximum output tokens including thinking budget.

         Args:
@@ -329,7 +329,7 @@ def estimate_tokens_for_request(self, request: FenicCompletionsRequest):
         input_tokens += self._count_auxiliary_input_tokens(request)

         # Estimate output tokens
-        output_tokens = self._get_max_output_tokens(request)
+        output_tokens = self._get_max_output_token_request_limit(request)

         return TokenEstimate(
             input_tokens=input_tokens,
```

src/fenic/_inference/cohere/cohere_batch_embeddings_client.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -171,7 +171,7 @@ def estimate_tokens_for_request(self, request: FenicEmbeddingsRequest) -> TokenE
             output_tokens=0
         )

-    def _get_max_output_tokens(self, request: FenicEmbeddingsRequest) -> int:
+    def _get_max_output_token_request_limit(self, request: FenicEmbeddingsRequest) -> int:
         """Get maximum output tokens (always 0 for embeddings).

         Returns:
```

src/fenic/_inference/common_openai/openai_chat_completions_core.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -90,9 +90,13 @@ async def make_single_request(
         common_params: dict[str, Any] = {
             "model": self._model,
             "messages": convert_messages(request.messages),
-            "max_completion_tokens": request.max_completion_tokens + profile_configuration.expected_additional_reasoning_tokens,
             "n": 1,
         }
+
+        max_completion_tokens = self.get_max_output_token_request_limit(request, profile_configuration)
+        if max_completion_tokens is not None:
+            common_params["max_completion_tokens"] = max_completion_tokens
+
         if request.temperature:
             common_params.update({"temperature": request.temperature})

@@ -213,3 +217,13 @@ def get_request_key(self, request: FenicCompletionsRequest) -> str:
             A unique key for the request
         """
         return generate_completion_request_key(request)
+
+    def get_max_output_token_request_limit(self, request: FenicCompletionsRequest, profile_config: OpenAICompletionProfileConfiguration) -> Optional[int]:
+        """Return the maximum output token limit for a request.
+
+        Returns None if max_completion_tokens is not provided (no limit should be set).
+        If max_completion_tokens is provided, includes the thinking token budget with a safety margin.
+        """
+        if request.max_completion_tokens is None:
+            return None
+        return request.max_completion_tokens + profile_config.expected_additional_reasoning_tokens
```
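As a quick illustration of the optional cap introduced above: when no `max_completion_tokens` is supplied, the request simply omits the cap instead of sending a bogus sum. The following standalone sketch uses simplified stand-in names rather than fenic's actual request and profile classes.

```python
# Standalone sketch of the optional-cap behavior; names are stand-ins, not fenic types.
from typing import Any, Optional


def build_common_params(
    model: str,
    messages: list[dict[str, str]],
    max_completion_tokens: Optional[int],
    expected_additional_reasoning_tokens: int,
) -> dict[str, Any]:
    params: dict[str, Any] = {"model": model, "messages": messages, "n": 1}
    # Only set the cap when the caller provided one; otherwise leave it to the API default.
    if max_completion_tokens is not None:
        params["max_completion_tokens"] = (
            max_completion_tokens + expected_additional_reasoning_tokens
        )
    return params


print(build_common_params("gpt-4.1-mini", [{"role": "user", "content": "hi"}], None, 1024))
print(build_common_params("gpt-4.1-mini", [{"role": "user", "content": "hi"}], 512, 1024))
```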

src/fenic/_inference/google/gemini_batch_embeddings_client.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -121,7 +121,7 @@ def estimate_tokens_for_request(self, request: FenicEmbeddingsRequest) -> TokenE
             input_tokens=self.token_counter.count_tokens(request.doc), output_tokens=0
         )

-    def _get_max_output_tokens(self, request: FenicEmbeddingsRequest) -> int:
+    def _get_max_output_token_request_limit(self, request: FenicEmbeddingsRequest) -> int:
         return 0

     def reset_metrics(self):
```

src/fenic/_inference/google/gemini_native_chat_completions_client.py

Lines changed: 57 additions & 57 deletions
```diff
@@ -132,56 +132,6 @@ def count_tokens(self, messages: Tokenizable) -> int:  # type: ignore[override]
         # Re-expose for mypy – same implementation as parent.
         return super().count_tokens(messages)

-    def _estimate_structured_output_overhead(self, response_format: ResolvedResponseFormat) -> int:
-        """Use Google-specific response schema token estimation.
-
-        Args:
-            response_format: Pydantic model class defining the response format
-
-        Returns:
-            Estimated token overhead for structured output
-        """
-        return self._estimate_response_schema_tokens(response_format)
-
-    def _get_max_output_tokens(self, request: FenicCompletionsRequest) -> Optional[int]:
-        """Get maximum output tokens including thinking budget.
-
-        If max_completion_tokens is not set, return None.
-
-        Conservative estimate that includes both completion tokens and
-        thinking token budget with a safety margin.
-
-        Args:
-            request: The completion request
-
-        Returns:
-            Maximum output tokens (completion + thinking budget with safety margin)
-        """
-        if request.max_completion_tokens is None:
-            return None
-        profile_config = self._profile_manager.get_profile_by_name(
-            request.model_profile
-        )
-        return request.max_completion_tokens + int(
-            1.5 * profile_config.thinking_token_budget
-        )
-
-    @cache  # noqa: B019 – builtin cache OK here.
-    def _estimate_response_schema_tokens(self, response_format: ResolvedResponseFormat) -> int:
-        """Estimate token count for a response format schema.
-
-        Uses Google's tokenizer to count tokens in a JSON schema representation
-        of the response format. Results are cached for performance.
-
-        Args:
-            response_format: Pydantic model class defining the response format
-
-        Returns:
-            Estimated token count for the response format
-        """
-        schema_str = response_format.schema_fingerprint
-        return self._token_counter.count_tokens(schema_str)
-
     def get_request_key(self, request: FenicCompletionsRequest) -> str:
         """Generate a unique key for the request.

@@ -196,19 +146,17 @@ def get_request_key(self, request: FenicCompletionsRequest) -> str:
     def estimate_tokens_for_request(self, request: FenicCompletionsRequest):
         """Estimate the number of tokens for a request.

+        If the request provides a max_completion_tokens value, use that. Otherwise, estimate the output tokens based on the file size.
+
         Args:
            request: The request to estimate tokens for

         Returns:
             TokenEstimate: The estimated token usage
         """
-
-        # Count input tokens
         input_tokens = self.count_tokens(request.messages)
         input_tokens += self._count_auxiliary_input_tokens(request)
-
-        output_tokens = self._get_max_output_tokens(request) or self._model_parameters.max_output_tokens
-
+        output_tokens = self._estimate_output_tokens(request)
         return TokenEstimate(input_tokens=input_tokens, output_tokens=output_tokens)

     async def make_single_request(
@@ -228,16 +176,17 @@ async def make_single_request(
         """

         profile_config = self._profile_manager.get_profile_by_name(request.model_profile)
-        max_output_tokens = self._get_max_output_tokens(request)
-
         generation_config: GenerateContentConfigDict = {
             "temperature": request.temperature,
             "response_logprobs": request.top_logprobs is not None,
             "logprobs": request.top_logprobs,
             "system_instruction": request.messages.system,
         }
+
+        max_output_tokens = self._get_max_output_token_request_limit(request)
         if max_output_tokens is not None:
             generation_config["max_output_tokens"] = max_output_tokens
+
         generation_config.update(profile_config.additional_generation_config)
         if request.structured_output is not None:
             generation_config.update(
@@ -355,3 +304,54 @@ async def make_single_request(
         finally:
             if file_obj:
                 await delete_file(self._client, file_obj.name)
+
+    @cache  # noqa: B019 – builtin cache OK here.
+    def _estimate_response_schema_tokens(self, response_format: ResolvedResponseFormat) -> int:
+        """Estimate token count for a response format schema.
+
+        Uses Google's tokenizer to count tokens in a JSON schema representation
+        of the response format. Results are cached for performance.
+
+        Args:
+            response_format: Pydantic model class defining the response format
+
+        Returns:
+            Estimated token count for the response format
+        """
+        schema_str = response_format.schema_fingerprint
+        return self._token_counter.count_tokens(schema_str)
+
+    def _estimate_structured_output_overhead(self, response_format: ResolvedResponseFormat) -> int:
+        """Use Google-specific response schema token estimation.
+
+        Args:
+            response_format: Pydantic model class defining the response format
+
+        Returns:
+            Estimated token overhead for structured output
+        """
+        return self._estimate_response_schema_tokens(response_format)
+
+    def _estimate_output_tokens(self, request: FenicCompletionsRequest) -> int:
+        """Estimate the number of output tokens for a request."""
+        estimated_output_tokens = request.max_completion_tokens or 0
+        if request.max_completion_tokens is None and request.messages.user_file:
+            # TODO(DY): the semantic operator should dictate how the file affects the token estimate
+            estimated_output_tokens = self.token_counter.count_file_output_tokens(request.messages)
+        return estimated_output_tokens + self._get_expected_additional_reasoning_tokens(request)
+
+    def _get_max_output_token_request_limit(self, request: FenicCompletionsRequest) -> Optional[int]:
+        """Get the upper limit of output tokens for a request.
+
+        Returns None if max_completion_tokens is not provided (no limit should be set).
+        If max_completion_tokens is provided, includes the thinking token budget with a safety margin."""
+        if request.max_completion_tokens is None:
+            return None
+        return request.max_completion_tokens + self._get_expected_additional_reasoning_tokens(request)
+
+    def _get_expected_additional_reasoning_tokens(self, request: FenicCompletionsRequest) -> int:
+        """Get the expected additional reasoning tokens for a request. Include a safety margin."""
+        profile_config = self._profile_manager.get_profile_by_name(request.model_profile)
+        return int(
+            1.5 * profile_config.thinking_token_budget
+        )
```
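To make the Gemini safety margin above concrete, here is a small worked example; the 2048-token `thinking_token_budget` and the other inputs are assumed values for illustration.

```python
# Illustrative arithmetic only; 2048 is an assumed thinking_token_budget.
thinking_token_budget = 2048
expected_additional_reasoning = int(1.5 * thinking_token_budget)  # 3072

# Request limit when the operator passes max_completion_tokens=1024:
request_limit = 1024 + expected_additional_reasoning              # 4096

# Estimate when no limit is set and the PDF's output is counted at ~2500 tokens:
estimated_output = 2500 + expected_additional_reasoning           # 5572

print(request_limit, estimated_output)
```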

src/fenic/_inference/language_model.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -22,7 +22,7 @@

 @dataclass
 class InferenceConfiguration:
-    # If max_output_tokens is not provided, do not include it in the request.
+    # If max_output_tokens is not provided, model_client will add a guardrail based on the estimated output tokens.
     max_output_tokens: Optional[int]
     temperature: float
     top_logprobs: Optional[int] = None
```

src/fenic/_inference/model_client.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -245,8 +245,8 @@ def _estimate_structured_output_overhead(self, response_format: ResolvedResponse


     @abstractmethod
-    def _get_max_output_tokens(self, request: RequestT) -> int:
-        """Get conservative output token estimate. Override in subclasses for provider-specific logic."""
+    def _get_max_output_token_request_limit(self, request: RequestT) -> int:
+        """Get the upper limit of output tokens to set on a request."""
         pass

     #
```

src/fenic/_inference/openai/openai_batch_chat_completions_client.py

Lines changed: 18 additions & 6 deletions
```diff
@@ -65,6 +65,7 @@ def __init__(
             profile_configurations=profiles,
             default_profile_name=default_profile_name,
         )
+
         self._core = OpenAIChatCompletionsCore(
             model=model,
             model_provider=ModelProvider.OPENAI,
@@ -108,7 +109,7 @@ def estimate_tokens_for_request(self, request: FenicCompletionsRequest) -> Token
         """
         return TokenEstimate(
             input_tokens=self.token_counter.count_tokens(request.messages),
-            output_tokens=self._get_max_output_tokens(request)
+            output_tokens=self._estimate_output_tokens(request)
         )

     def reset_metrics(self):
@@ -123,10 +124,21 @@ def get_metrics(self) -> LMMetrics:
         """
         return self._core.get_metrics()

-    def _get_max_output_tokens(self, request: FenicCompletionsRequest) -> int:
-        """Conservative estimate: max_completion_tokens + reasoning effort-based thinking tokens."""
-        base_tokens = request.max_completion_tokens
-
-        # Get profile-specific reasoning effort
+    def _estimate_output_tokens(self, request: FenicCompletionsRequest) -> int:
+        """Estimate the number of output tokens for a request."""
+        base_tokens = request.max_completion_tokens or 0
+        if request.max_completion_tokens is None and request.messages.user_file:
+            # TODO(DY): the semantic operator should dictate how the file affects the token estimate
+            base_tokens += self.token_counter.count_file_output_tokens(messages=request.messages)
         profile_config = self._profile_manager.get_profile_by_name(request.model_profile)
         return base_tokens + profile_config.expected_additional_reasoning_tokens
+
+    def _get_max_output_token_request_limit(self, request: FenicCompletionsRequest) -> int:
+        """Return the maximum output token limit for a request.
+
+        For file parsing requests, use a guardrail limit of 8192 tokens (the lowest output limit of a VLM model we support).
+
+        Include the thinking token budget with a safety margin.
+        """
+        profile_config = self._profile_manager.get_profile_by_name(request.model_profile)
+        return self._core.get_max_output_token_request_limit(request, profile_config)
```
