1,000 changes: 1,000 additions & 0 deletions examples/fenic_in_120_seconds/18_pdf_processing.ipynb

Large diffs are not rendered by default.

1,346 changes: 1,346 additions & 0 deletions specs/llm_cache_design.md

Large diffs are not rendered by default.

100 changes: 73 additions & 27 deletions src/fenic/_backends/local/model_registry.py
@@ -34,16 +34,19 @@

logger = logging.getLogger(__name__)


@dataclass
class LanguageModelRegistry:
models: dict[str, LanguageModel]
default_model: LanguageModel


@dataclass
class EmbeddingModelRegistry:
models: dict[str, EmbeddingModel]
default_model: EmbeddingModel


class SessionModelRegistry:
"""Registry for managing language and embedding models in a session.

@@ -59,18 +62,19 @@ class SessionModelRegistry:
language_model_registry: Optional[LanguageModelRegistry] = None
embedding_model_registry: Optional[EmbeddingModelRegistry] = None

def __init__(self, config: ResolvedSemanticConfig):
def __init__(self, config: ResolvedSemanticConfig, cache=None):
"""Initialize the model registry with configuration.

Args:
config (ResolvedSemanticConfig): Configuration containing model settings and defaults.
cache: Optional LLM response cache instance to pass to language models.
"""
validate_providers: set[ModelProviderClass] = set()
if config.language_models:
language_model_config = config.language_models
models: dict[str, LanguageModel] = {}
for alias, model_config in language_model_config.model_configs.items():
model = self._initialize_language_model(model_config)
model = self._initialize_language_model(model_config, cache)
models[alias] = model
validate_providers.add(model.client.model_provider_class)
self.language_model_registry = LanguageModelRegistry(
@@ -91,7 +95,9 @@ def __init__(self, config: ResolvedSemanticConfig):
)
if len(validate_providers) > 0:
with EventLoopManager().loop_context() as loop:
future = asyncio.run_coroutine_threadsafe(_validate_provider_api_keys(validate_providers), loop)
future = asyncio.run_coroutine_threadsafe(
_validate_provider_api_keys(validate_providers), loop
)
future.result()

def get_language_model_metrics(self) -> LMMetrics:
@@ -132,7 +138,9 @@ def reset_embedding_model_metrics(self):
for embedding_model in self.embedding_model_registry.models.values():
embedding_model.reset_metrics()

def get_language_model(self, alias: Optional[ResolvedModelAlias] = None) -> LanguageModel:
def get_language_model(
self, alias: Optional[ResolvedModelAlias] = None
) -> LanguageModel:
"""Get a language model by alias or return the default model.

Args:
@@ -145,15 +153,21 @@ def get_language_model(self, alias: Optional[ResolvedModelAlias] = None) -> Lang
SessionError: If the requested model is not found.
"""
if not self.language_model_registry:
raise InternalError("Requested language model, but no language models are configured.")
raise InternalError(
"Requested language model, but no language models are configured."
)
if alias is None:
return self.language_model_registry.default_model
language_model_for_alias = self.language_model_registry.models.get(alias.name)
if language_model_for_alias is None:
raise InternalError(f"Language Model with alias '{alias.name}' not found in configured models: {sorted(list(self.language_model_registry.models.keys()))}")
raise InternalError(
f"Language Model with alias '{alias.name}' not found in configured models: {sorted(list(self.language_model_registry.models.keys()))}"
)
return language_model_for_alias

def get_embedding_model(self, alias: Optional[ResolvedModelAlias] = None) -> EmbeddingModel:
def get_embedding_model(
self, alias: Optional[ResolvedModelAlias] = None
) -> EmbeddingModel:
"""Get an embedding model by alias or return the default model.

Args:
@@ -166,12 +180,18 @@ def get_embedding_model(self, alias: Optional[ResolvedModelAlias] = None) -> Emb
SessionError: If no embedding models are configured or if the requested model is not found.
"""
if not self.embedding_model_registry:
raise InternalError("Requested embedding model, but no embedding models are configured.")
raise InternalError(
"Requested embedding model, but no embedding models are configured."
)
if alias is None:
return self.embedding_model_registry.default_model
embedding_model_for_model_alias = self.embedding_model_registry.models.get(alias.name)
embedding_model_for_model_alias = self.embedding_model_registry.models.get(
alias.name
)
if embedding_model_for_model_alias is None:
raise InternalError(f"Embedding Model with model name '{alias.name}' not found in configured models: {sorted(list(self.embedding_model_registry.models.keys()))}")
raise InternalError(
f"Embedding Model with model name '{alias.name}' not found in configured models: {sorted(list(self.embedding_model_registry.models.keys()))}"
)
return embedding_model_for_model_alias

def shutdown_models(self):
@@ -181,15 +201,21 @@ def shutdown_models(self):
try:
language_model.client.shutdown()
except Exception as e:
logger.warning(f"Failed graceful shutdown of language model client {alias}: {e}")
logger.warning(
f"Failed graceful shutdown of language model client {alias}: {e}"
)
if self.embedding_model_registry:
for alias, embedding_model in self.embedding_model_registry.models.items():
try:
embedding_model.client.shutdown()
except Exception as e:
logger.warning(f"Failed graceful shutdown of embedding model client {alias}: {e}")
logger.warning(
f"Failed graceful shutdown of embedding model client {alias}: {e}"
)

def _initialize_embedding_model(self, model_config: ResolvedModelConfig) -> EmbeddingModel:
def _initialize_embedding_model(
self, model_config: ResolvedModelConfig
) -> EmbeddingModel:
"""Initialize an embedding model with the given configuration.

Args:
@@ -203,7 +229,9 @@ def _initialize_embedding_model(self, model_config: ResolvedModelConfig) -> Embe
"""
try:
if isinstance(model_config, ResolvedOpenAIModelConfig):
rate_limit_strategy = UnifiedTokenRateLimitStrategy(rpm=model_config.rpm, tpm=model_config.tpm)
rate_limit_strategy = UnifiedTokenRateLimitStrategy(
rpm=model_config.rpm, tpm=model_config.tpm
)
client = OpenAIBatchEmbeddingsClient(
rate_limit_strategy=rate_limit_strategy,
model=model_config.model_name,
@@ -217,16 +245,20 @@ def _initialize_embedding_model(self, model_config: ResolvedModelConfig) -> Embe
raise ImportError(
"To use Google models, please install the required dependencies by running: pip install fenic[google]"
) from err
rate_limit_strategy = UnifiedTokenRateLimitStrategy(rpm=model_config.rpm, tpm=model_config.tpm)
rate_limit_strategy = UnifiedTokenRateLimitStrategy(
rpm=model_config.rpm, tpm=model_config.tpm
)
client = GoogleBatchEmbeddingsClient(
rate_limit_strategy=rate_limit_strategy,
model=model_config.model_name,
model_provider=model_config.model_provider,
profiles=model_config.profiles,
default_profile_name=model_config.default_profile
default_profile_name=model_config.default_profile,
)
elif isinstance(model_config, ResolvedCohereModelConfig):
rate_limit_strategy = UnifiedTokenRateLimitStrategy(rpm=model_config.rpm, tpm=model_config.tpm)
rate_limit_strategy = UnifiedTokenRateLimitStrategy(
rpm=model_config.rpm, tpm=model_config.tpm
)
try:
from fenic._inference.cohere.cohere_batch_embeddings_client import (
CohereBatchEmbeddingsClient,
@@ -239,25 +271,29 @@ def _initialize_embedding_model(self, model_config: ResolvedModelConfig) -> Embe
rate_limit_strategy=rate_limit_strategy,
model=model_config.model_name,
profile_configurations=model_config.profiles,
default_profile_name=model_config.default_profile
default_profile_name=model_config.default_profile,
)
else:
raise ConfigurationError(f"Unsupported model configuration: {model_config}")
raise ConfigurationError(
f"Unsupported model configuration: {model_config}"
)

except Exception as e:
raise SessionError(f"Failed to create embedding model client: {e}") from e

return EmbeddingModel(client=client)

def _initialize_language_model(self, model_config: ResolvedModelConfig) -> LanguageModel:
def _initialize_language_model(
self, model_config: ResolvedModelConfig, cache=None
) -> LanguageModel:
"""Initialize a language client model with the given configuration.

Args:
model_alias: Base alias for the model
model_config (ModelConfig): Configuration for the language model.
cache: Optional LLM response cache instance.

Returns:
dict[str, LanguageModel]: Dictionary mapping alias to initialized language models.
LanguageModel: Initialized language model.

Raises:
SessionError: If model initialization fails.
@@ -266,12 +302,15 @@ def _initialize_language_model(self, model_config: ResolvedModelConfig) -> Langu
"""
try:
if isinstance(model_config, ResolvedOpenAIModelConfig):
rate_limit_strategy = UnifiedTokenRateLimitStrategy(rpm=model_config.rpm, tpm=model_config.tpm)
rate_limit_strategy = UnifiedTokenRateLimitStrategy(
rpm=model_config.rpm, tpm=model_config.tpm
)
client = OpenAIBatchChatCompletionsClient(
model=model_config.model_name,
rate_limit_strategy=rate_limit_strategy,
profiles=model_config.profiles,
default_profile_name=model_config.default_profile,
cache=cache,
)

elif isinstance(model_config, ResolvedAnthropicModelConfig):
@@ -286,13 +325,14 @@ def _initialize_language_model(self, model_config: ResolvedModelConfig) -> Langu
rate_limit_strategy = SeparatedTokenRateLimitStrategy(
rpm=model_config.rpm,
input_tpm=model_config.input_tpm,
output_tpm=model_config.output_tpm
output_tpm=model_config.output_tpm,
)
client = AnthropicBatchCompletionsClient(
model=model_config.model_name,
rate_limit_strategy=rate_limit_strategy,
profiles=model_config.profiles,
default_profile_name=model_config.default_profile,
cache=cache,
)

elif isinstance(model_config, ResolvedGoogleModelConfig):
@@ -304,13 +344,16 @@ def _initialize_language_model(self, model_config: ResolvedModelConfig) -> Langu
raise ImportError(
"To use Google models, please install the required dependencies by running: pip install fenic[google]"
) from err
rate_limit_strategy = UnifiedTokenRateLimitStrategy(rpm=model_config.rpm, tpm=model_config.tpm)
rate_limit_strategy = UnifiedTokenRateLimitStrategy(
rpm=model_config.rpm, tpm=model_config.tpm
)
client = GeminiNativeChatCompletionsClient(
model=model_config.model_name,
model_provider=model_config.model_provider,
rate_limit_strategy=rate_limit_strategy,
profiles=model_config.profiles,
default_profile_name=model_config.default_profile
default_profile_name=model_config.default_profile,
cache=cache,
)
elif isinstance(model_config, ResolvedOpenRouterModelConfig):
rate_limit_strategy = AdaptiveBackoffRateLimitStrategy()
@@ -319,9 +362,12 @@ def _initialize_language_model(self, model_config: ResolvedModelConfig) -> Langu
rate_limit_strategy=rate_limit_strategy,
profiles=model_config.profiles,
default_profile_name=model_config.default_profile,
cache=cache,
)
else:
raise ConfigurationError(f"Unsupported model configuration: {model_config}")
raise ConfigurationError(
f"Unsupported model configuration: {model_config}"
)
return LanguageModel(client=client)

except Exception as e:
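
The registry diff above threads an optional `cache` through `SessionModelRegistry.__init__` into every chat-completions client it builds. A minimal sketch of how a caller might wire this up, assuming the cache interface below purely for illustration; the real cache class comes from specs/llm_cache_design.md and is not shown in this diff:

```python
# Sketch only: DictLLMCache is a trivial dict-backed stand-in, not the cache
# class this PR's spec defines.
from fenic._backends.local.model_registry import SessionModelRegistry


class DictLLMCache:
    """Illustrative placeholder for an LLM response cache."""

    def __init__(self):
        self._entries = {}

    def get(self, key):
        return self._entries.get(key)

    def put(self, key, value):
        self._entries[key] = value


def build_registry(resolved_config):
    # `resolved_config` is a ResolvedSemanticConfig built by the session as usual.
    # The registry forwards `cache` to each language-model client it constructs
    # (OpenAI, Anthropic, Gemini, OpenRouter); embedding clients are unchanged.
    return SessionModelRegistry(resolved_config, cache=DictLLMCache())
```
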
@@ -134,6 +134,7 @@ def __init__(
model: LanguageModel,
temperature: float,
model_alias: Optional[ResolvedModelAlias] = None,
request_timeout: Optional[float] = None,
):
super().__init__(
input,
@@ -145,6 +146,7 @@ def __init__(
temperature=temperature,
response_format=SENTIMENT_ANALYSIS_FORMAT,
model_profile=model_alias.profile if model_alias else None,
request_timeout=request_timeout,
),
),
EXAMPLES,
1 change: 1 addition & 0 deletions src/fenic/_backends/local/semantic_operators/base.py
@@ -78,6 +78,7 @@ def send_requests(
response_format=self.inference_config.response_format,
top_logprobs=self.inference_config.top_logprobs,
model_profile=self.inference_config.model_profile,
request_timeout=self.inference_config.request_timeout,
)

completions = [
2 changes: 2 additions & 0 deletions src/fenic/_backends/local/semantic_operators/classify.py
@@ -43,6 +43,7 @@ def __init__(
temperature: float,
examples: Optional[ClassifyExampleCollection] = None,
model_alias: Optional[ResolvedModelAlias] = None,
request_timeout: Optional[float] = None,
):
self.classes = classes
self.valid_labels = {class_def.label for class_def in classes}
@@ -59,6 +60,7 @@ def __init__(
temperature=temperature,
response_format=ResolvedResponseFormat.from_pydantic_model(self.output_model, generate_struct_type=False),
model_profile=model_alias.profile if model_alias else None,
request_timeout=request_timeout,
),
),
examples,
2 changes: 2 additions & 0 deletions src/fenic/_backends/local/semantic_operators/extract.py
@@ -50,6 +50,7 @@ def __init__(
max_output_tokens: int,
temperature: float,
model_alias: Optional[ResolvedModelAlias] = None,
request_timeout: Optional[float] = None,
):
self.resolved_format = response_format
super().__init__(
@@ -61,6 +62,7 @@ def __init__(
temperature=temperature,
response_format=response_format,
model_profile=model_alias.profile if model_alias else None,
request_timeout=request_timeout,
),
model=model,
),
2 changes: 2 additions & 0 deletions src/fenic/_backends/local/semantic_operators/map.py
@@ -53,6 +53,7 @@ def __init__(
model_alias: Optional[ResolvedModelAlias] = None,
response_format: Optional[ResolvedResponseFormat] = None,
examples: Optional[MapExampleCollection] = None,
request_timeout: Optional[float] = None,
):
super().__init__(
input,
@@ -64,6 +65,7 @@ def __init__(
response_format=response_format,
temperature=temperature,
model_profile=model_alias.profile if model_alias else None,
request_timeout=request_timeout,
),
),
jinja_template=jinja2.Template(jinja_template),
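
The semantic-operator diffs above (sentiment analysis, classify, extract, map) each accept a new optional `request_timeout` and pass it into their inference configuration, which `base.py` then forwards with every request. A rough sketch of that forwarding pattern, using a simplified config object as a stand-in; only the fields visible in this diff are modeled, and the request dict below is illustrative rather than the clients' real API:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _SketchInferenceConfig:
    # Simplified stand-in for an operator's inference configuration.
    max_output_tokens: int
    temperature: float
    model_profile: Optional[str] = None
    request_timeout: Optional[float] = None  # new in this PR


def send_requests_sketch(config: _SketchInferenceConfig, prompts: list[str]) -> list[dict]:
    """Illustrates the pattern from base.py: the timeout configured on the
    operator is attached to every request handed to the model client."""
    return [
        {
            "prompt": prompt,
            "max_completion_tokens": config.max_output_tokens,
            "temperature": config.temperature,
            "model_profile": config.model_profile,
            "request_timeout": config.request_timeout,
        }
        for prompt in prompts
    ]


# Usage: an operator built with request_timeout=30.0 emits requests that all
# carry that timeout, so one slow call cannot stall the whole batch.
requests = send_requests_sketch(
    _SketchInferenceConfig(max_output_tokens=512, temperature=0.0, request_timeout=30.0),
    ["classify this", "and this"],
)
```
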