[ChatQnA Core] Ollama Integration #890
base: main
Changes from 29 commits
@@ -10,3 +10,4 @@ coverage
*.lock
!poetry.lock
.vscode
nginx_config/*.conf
@@ -1,17 +1,14 @@
from .config import config
from .utils import login_to_huggingface, download_huggingface_model, convert_model
from .document import load_file_document
from .logger import logger
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OpenVINOBgeEmbeddings
from langchain_community.document_compressors.openvino_rerank import OpenVINOReranker
from langchain.retrievers import ContextualCompressionRetriever
from langchain_huggingface import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
import os
import importlib
import pandas as pd

vectorstore = None
@@ -20,56 +17,23 @@
# If RUN_TEST is set to "True", the model download and conversion steps are skipped.
# This flag is set in the conftest.py file before running the tests.
if os.getenv("RUN_TEST", "").lower() != "true":
# login huggingface
login_to_huggingface(config.HF_ACCESS_TOKEN)
if config.MODEL_BACKEND == "openvino":
backend_module = importlib.import_module("app.openvino_backend")
backend_instance = backend_module.OpenVINOBackend()

# Download convert the model to openvino optimized
download_huggingface_model(config.EMBEDDING_MODEL_ID, config._CACHE_DIR)
download_huggingface_model(config.RERANKER_MODEL_ID, config._CACHE_DIR)
download_huggingface_model(config.LLM_MODEL_ID, config._CACHE_DIR)
elif config.MODEL_BACKEND == "ollama":
backend_module = importlib.import_module("app.ollama_backend")
backend_instance = backend_module.OllamaBackend()

# Convert to openvino IR
convert_model(config.EMBEDDING_MODEL_ID, config._CACHE_DIR, "embedding")
convert_model(config.RERANKER_MODEL_ID, config._CACHE_DIR, "reranker")
convert_model(config.LLM_MODEL_ID, config._CACHE_DIR, "llm")
else:
raise ValueError(f"Unsupported model backend: {config.MODEL_BACKEND}")

embedding, llm, reranker = backend_instance.init_models()

template = config.PROMPT_TEMPLATE

prompt = ChatPromptTemplate.from_template(template)

# Initialize Embedding Model
embedding = OpenVINOBgeEmbeddings(
model_name_or_path=f"{config._CACHE_DIR}/{config.EMBEDDING_MODEL_ID}",
model_kwargs={"device": config.EMBEDDING_DEVICE, "compile": False},
)
embedding.ov_model.compile()

# Initialize Reranker Model
reranker = OpenVINOReranker(
model_name_or_path=f"{config._CACHE_DIR}/{config.RERANKER_MODEL_ID}",
model_kwargs={"device": config.RERANKER_DEVICE},
top_n=2,
)

# Initialize LLM
llm = HuggingFacePipeline.from_model_id(
model_id=f"{config._CACHE_DIR}/{config.LLM_MODEL_ID}",
task="text-generation",
backend="openvino",
model_kwargs={
"device": config.LLM_DEVICE,
"ov_config": {
"PERFORMANCE_HINT": "LATENCY",
"NUM_STREAMS": "1",
"CACHE_DIR": f"{config._CACHE_DIR}/{config.LLM_MODEL_ID}/model_cache",
},
"trust_remote_code": True,
},
pipeline_kwargs={"max_new_tokens": config.MAX_TOKENS},
)
if llm.pipeline.tokenizer.eos_token_id:
llm.pipeline.tokenizer.pad_token_id = llm.pipeline.tokenizer.eos_token_id
else:
logger.info("Bypassing to mock these functions because RUN_TEST is set to 'True' to run pytest unit test.")
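For context on the `backend_instance.init_models()` call above, here is a minimal sketch of what an Ollama-based backend module could look like, assuming the `langchain_ollama` package. The class name mirrors the `OllamaBackend` used in the dispatch, but the model names and parameters below are placeholders, not values taken from this PR.

```python
# Hypothetical sketch only; the actual app/ollama_backend.py in this PR may differ.
from langchain_ollama import OllamaEmbeddings, OllamaLLM


class OllamaBackend:
    """Initialises embedding and LLM objects served by a local Ollama instance."""

    def init_models(self):
        # Placeholder model names; real values would come from config.EMBEDDING_MODEL_ID / LLM_MODEL_ID.
        embedding = OllamaEmbeddings(model="nomic-embed-text")
        llm = OllamaLLM(
            model="llama3.1",
            num_predict=1024,   # roughly corresponds to MAX_TOKENS
            keep_alive="5m",    # roughly corresponds to KEEP_ALIVE
        )
        # Ollama has no reranker, matching _ENABLE_RERANK = False for this backend.
        reranker = None
        return embedding, llm, reranker
```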
@@ -88,35 +52,35 @@ def default_context(docs):
return ""


def get_retriever(enable_rerank=True, search_method="similarity_score_threshold"):
def get_retriever():
"""
Creates and returns a retriever object with optional reranking capability.

Args:
enable_rerank (bool): If True, enables the reranker to improve retrieval results. Defaults to True.
search_method (str): The method used for searching within the vector store. Defaults to "similarity_score_threshold".

Returns:
retriever: A retriever object, optionally wrapped with a contextual compression reranker.

"""

enable_rerank = config._ENABLE_RERANK
search_method = config._SEARCH_METHOD
fetch_k = config._FETCH_K

if vectorstore == None:
return None

else:
retriever = vectorstore.as_retriever(
search_kwargs={"k": 3, "score_threshold": 0.5}, search_type=search_method
search_kwargs={
"k": 3,
"fetch_k": fetch_k,
},
search_type=search_method
)
if enable_rerank:
logger.info("Enable reranker")

return ContextualCompressionRetriever(
base_compressor=reranker, base_retriever=retriever
)
else:
logger.info("Disable reranker")

return retriever

Review comment on this hunk: Do we have any empirical number on the quality of retrieval with threshold vs. top-k strategy?

Reply: Both are fine, and now I try to align with what modular does to keep the design the same.
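To make the threshold vs. top-k/MMR discussion above concrete, here is a small, self-contained comparison of the two retriever configurations using a toy FAISS store with fake embeddings. The texts and embedding size are made up for illustration and are not part of this PR.

```python
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import FAISS

texts = [
    "Ollama serves LLMs locally over a REST API.",
    "OpenVINO optimises model inference on Intel hardware.",
    "FAISS provides fast vector similarity search.",
]
store = FAISS.from_texts(texts, FakeEmbeddings(size=384))

# Previous behaviour: top-k results filtered by a similarity score threshold.
threshold_retriever = store.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": 3, "score_threshold": 0.5},
)

# New behaviour: MMR fetches fetch_k candidates, then re-selects k diverse ones.
mmr_retriever = store.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 3, "fetch_k": 10},
)

print(mmr_retriever.invoke("How do I run models locally?"))
```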
@@ -1,5 +1,6 @@
from pydantic import PrivateAttr
from pydantic_settings import BaseSettings
from typing import Union
from os.path import dirname, abspath
from .prompt import get_prompt_template
import os
@@ -8,34 +9,45 @@
class Settings(BaseSettings):
"""
Settings class for configuring the Chatqna-Core application.
This class manages application-wide configuration, including model settings, device preferences,
supported file formats, and paths for caching and configuration files. It loads additional
configuration from a YAML file if provided, and updates its attributes accordingly.
This class manages application settings, including model backend selection,
model IDs, device configurations, prompt templates, and various internal paths.
It loads configuration from a YAML file, validates backend-specific requirements,
and ensures prompt templates contain required placeholders.

Attributes:
APP_DISPLAY_NAME (str): Display name of the application.
BASE_DIR (str): Base directory of the application.
SUPPORTED_FORMATS (set): Supported document file formats.
DEBUG (bool): Flag to enable or disable debug mode.
HF_ACCESS_TOKEN (str): Hugging Face access token for model downloads.
EMBEDDING_MODEL_ID (str): Model ID for embeddings.
RERANKER_MODEL_ID (str): Model ID for reranker.
LLM_MODEL_ID (str): Model ID for large language model.
PROMPT_TEMPLATE (str): Prompt template for the LLM.
EMBEDDING_DEVICE (str): Device to run embedding model on.
RERANKER_DEVICE (str): Device to run reranker model on.
LLM_DEVICE (str): Device to run LLM on.
SUPPORTED_FORMATS (set): Supported file formats for input documents.
DEBUG (bool): Debug mode flag.
HF_ACCESS_TOKEN (str): Hugging Face access token.
MODEL_BACKEND (str): Backend to use for models ('openvino' or 'ollama').
EMBEDDING_MODEL_ID (str): Identifier for the embedding model.
RERANKER_MODEL_ID (str): Identifier for the reranker model.
LLM_MODEL_ID (str): Identifier for the large language model.
PROMPT_TEMPLATE (str): Prompt template string for the LLM.
EMBEDDING_DEVICE (str): Device for embedding model ('CPU', etc.).
RERANKER_DEVICE (str): Device for reranker model ('CPU', etc.).
LLM_DEVICE (str): Device for LLM ('CPU', etc.).
MAX_TOKENS (int): Maximum number of tokens for LLM responses.
ENABLE_RERANK (bool): Flag to enable or disable reranking.
_CACHE_DIR (str): Directory for model cache (private).
_HF_DATASETS_CACHE (str): Directory for Hugging Face datasets cache (private).
_TMP_FILE_PATH (str): Temporary file path for documents (private).
_DEFAULT_MODEL_CONFIG (str): Path to default model configuration YAML (private).
_MODEL_CONFIG_PATH (str): Path to user-provided model configuration YAML (private).
KEEP_ALIVE (Union[str, int, None]): Keep-alive setting for the application.

Private Attributes:
_ENABLE_RERANK (bool): Whether reranking is enabled.
_SEARCH_METHOD (str): Search method used for retrieval.
_FETCH_K (int): Number of documents to fetch during retrieval.
_CACHE_DIR (str): Directory for model cache.
_HF_DATASETS_CACHE (str): Directory for Hugging Face datasets cache.
_TMP_FILE_PATH (str): Temporary file path for documents.
_DEFAULT_MODEL_CONFIG (str): Path to the default model configuration YAML.
_MODEL_CONFIG_PATH (str): Path to the user-provided model configuration YAML.

Methods:
__init__(**kwargs): Initializes the Settings object, loads configuration from YAML file,
and updates attributes accordingly.
__init__(**kwargs): Initializes settings, loads configuration from YAML, and validates settings.
_validate_backend_settings(): Validates backend-specific settings and required model IDs.
_check_and_validate_prompt_template(): Ensures the prompt template is set and contains required placeholders.

Raises:
ValueError: If required settings are missing or invalid, or if unsupported backend is specified.
"""

APP_DISPLAY_NAME: str = "Chatqna-Core"
@@ -44,6 +56,7 @@ class Settings(BaseSettings):
DEBUG: bool = False

HF_ACCESS_TOKEN: str = ""
MODEL_BACKEND: str = ""
EMBEDDING_MODEL_ID: str = ""
RERANKER_MODEL_ID: str = ""
LLM_MODEL_ID: str = ""
@@ -52,9 +65,12 @@ class Settings(BaseSettings):
RERANKER_DEVICE: str = "CPU"
LLM_DEVICE: str = "CPU"
MAX_TOKENS: int = 1024
ENABLE_RERANK: bool = True
KEEP_ALIVE: Union[str, int, None] = None

# These fields will not be affected by environment variables
_ENABLE_RERANK: bool = PrivateAttr(True)
_SEARCH_METHOD: str = PrivateAttr("mmr")
_FETCH_K: int = PrivateAttr(10)
_CACHE_DIR: str = PrivateAttr("/tmp/model_cache")
_HF_DATASETS_CACHE: str = PrivateAttr("/tmp/model_cache")
_TMP_FILE_PATH: str = PrivateAttr("/tmp/chatqna/documents")
@@ -65,13 +81,6 @@ class Settings(BaseSettings):
def __init__(self, **kwargs):
super().__init__(**kwargs)

# The RUN_TEST flag is used to bypass the model config loading during pytest unit testing.
# If RUN_TEST is set to "True", the model config loading is skipped.
# This flag is set in the conftest.py file before running the tests.
if os.getenv("RUN_TEST", "").lower() == "true":
print("INFO - Skipping model config loading in test mode.")
return

config_file = self._MODEL_CONFIG_PATH if os.path.isfile(self._MODEL_CONFIG_PATH) else self._DEFAULT_MODEL_CONFIG

if config_file == self._MODEL_CONFIG_PATH:
@@ -89,15 +98,58 @@ def __init__(self, **kwargs):
if hasattr(self, key):
setattr(self, key, value)

self._validate_model_ids()

self._validate_backend_settings()
self._check_and_validate_prompt_template()

def _validate_model_ids(self):
for model_name in ["EMBEDDING_MODEL_ID", "RERANKER_MODEL_ID", "LLM_MODEL_ID"]:
model_id = getattr(self, model_name)
if not model_id:
raise ValueError(f"{model_name} must not be an empty string.")

def _validate_backend_settings(self):

if self.MODEL_BACKEND:
self.MODEL_BACKEND = self.MODEL_BACKEND.lower()
else:
raise ValueError("MODEL_BACKEND must not be an empty string.")

if self.MODEL_BACKEND == "openvino":
self._ENABLE_RERANK = True

# Validate Huggingface token
if not self.HF_ACCESS_TOKEN:
raise ValueError("HF_ACCESS_TOKEN must not be an empty string for 'openvino' backend.")

# Validate required model IDs
for model_name in ["EMBEDDING_MODEL_ID", "RERANKER_MODEL_ID", "LLM_MODEL_ID"]:
model_id = getattr(self, model_name)
if not model_id:
raise ValueError(f"{model_name} must not be an empty string for 'openvino' backend.")

elif self.MODEL_BACKEND == "ollama":
self._ENABLE_RERANK = False

# Validate that all devices are set to "CPU" as ollama currently only enabled for CPU
invalid_devices = [
attr for attr in ["EMBEDDING_DEVICE", "RERANKER_DEVICE", "LLM_DEVICE"]
if getattr(self, attr, "") != "CPU"
]

if invalid_devices:
raise ValueError(
f"When MODEL_BACKEND is 'ollama', the following devices must be set to 'CPU': {', '.join(invalid_devices)}"
)

# Handle RERANKER_MODEL_ID
if self.RERANKER_MODEL_ID:
print("WARNING - RERANKER_MODEL_ID is ignored when MODEL_BACKEND is 'ollama'. Setting it to empty.")
self.RERANKER_MODEL_ID = ""
else:
print("INFO - MODEL_BACKEND is 'ollama'. Reranker model is not supported.")

# Validate required model IDs (excluding reranker)
for model_name in ["EMBEDDING_MODEL_ID", "LLM_MODEL_ID"]:
model_id = getattr(self, model_name)
if not model_id:
raise ValueError(f"{model_name} must not be an empty string for 'ollama' backend.")

else:
raise ValueError(f"Unsupported MODEL_BACKEND '{self.MODEL_BACKEND}'. Only 'openvino' and 'ollama' are supported.")

def _check_and_validate_prompt_template(self):
if not self.PROMPT_TEMPLATE:
@@ -111,4 +163,4 @@ def _check_and_validate_prompt_template(self):
raise ValueError(f"PROMPT_TEMPLATE must include the placeholder {placeholder}.")

config = Settings()
config = Settings()
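For illustration, the following values would satisfy `_validate_backend_settings()` for each backend once merged in from the model-config YAML. The model IDs, token, and keep-alive value are placeholders, not values from this PR.

```python
# Placeholder configuration values only, shown as Python dicts for brevity.
openvino_config = {
    "MODEL_BACKEND": "openvino",
    "HF_ACCESS_TOKEN": "<your-hf-token>",           # required for 'openvino'
    "EMBEDDING_MODEL_ID": "BAAI/bge-small-en-v1.5",
    "RERANKER_MODEL_ID": "BAAI/bge-reranker-base",
    "LLM_MODEL_ID": "microsoft/Phi-3-mini-4k-instruct",
    "LLM_DEVICE": "GPU",                            # device checks only apply to 'ollama'
}

ollama_config = {
    "MODEL_BACKEND": "ollama",
    "EMBEDDING_MODEL_ID": "nomic-embed-text",
    "LLM_MODEL_ID": "llama3.1",
    "EMBEDDING_DEVICE": "CPU",                      # all devices must be "CPU" for 'ollama'
    "RERANKER_DEVICE": "CPU",
    "LLM_DEVICE": "CPU",
    "KEEP_ALIVE": "5m",                             # Union[str, int, None]
    # HF_ACCESS_TOKEN is not required; RERANKER_MODEL_ID is ignored and cleared.
}
```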
Review comment: Thinking out loud. Can we use "runtime" instead of "backend", or a better name? The reason being that Ollama can have an OpenVINO backend too, though that is not the case currently.
Reply: Sure, we can rename that.
Reply: Changes updated.