backends/python/server/text_embeddings_server/models/__init__.py (42 additions, 8 deletions)
@@ -1,3 +1,4 @@
+import os
 import torch
 
 from loguru import logger
@@ -13,6 +14,7 @@
 
 __all__ = ["Model"]
 
+TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
 # Disable gradients
 torch.set_grad_enabled(False)
 
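As a side note on the hunk above: the flag is parsed once at import time, and only the strings "true" and "1" (case-insensitive) enable it. A minimal sketch of that behaviour, with _parse_trust_remote_code as a hypothetical stand-in for the inline expression:

def _parse_trust_remote_code(value: str) -> bool:
    # Same test as the module-level line added above.
    return value.lower() in ["true", "1"]

assert _parse_trust_remote_code("True") is True
assert _parse_trust_remote_code("1") is True
assert _parse_trust_remote_code("yes") is False  # not an accepted value
assert _parse_trust_remote_code("") is False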
@@ -40,7 +42,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
     device = get_device()
     logger.info(f"backend device: {device}")
 
-    config = AutoConfig.from_pretrained(model_path)
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
     if config.model_type == "bert":
         config: BertConfig
         if (
@@ -51,12 +53,22 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
             and FLASH_ATTENTION
         ):
             if pool != "cls":
-                return DefaultModel(model_path, device, datatype, pool)
+                return DefaultModel(
+                    model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
+                )
             return FlashBert(model_path, device, datatype)
         if config.architectures[0].endswith("Classification"):
-            return ClassificationModel(model_path, device, datatype)
+            return ClassificationModel(
+                model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
+            )
         else:
-            return DefaultModel(model_path, device, datatype, pool)
+            return DefaultModel(
+                model_path,
+                device,
+                datatype,
+                pool,
+                trust_remote=TRUST_REMOTE_CODE,
+            )
     else:
         if device.type == "hpu":
             from habana_frameworks.torch.hpu import wrap_in_hpu_graph
@@ -66,13 +78,35 @@
 
             adapt_transformers_to_gaudi()
             if config.architectures[0].endswith("Classification"):
-                model_handle = ClassificationModel(model_path, device, datatype)
+                model_handle = ClassificationModel(
+                    model_path,
+                    device,
+                    datatype,
+                    trust_remote=TRUST_REMOTE_CODE,
+                )
             else:
-                model_handle = DefaultModel(model_path, device, datatype, pool)
+                model_handle = DefaultModel(
+                    model_path,
+                    device,
+                    datatype,
+                    pool,
+                    trust_remote=TRUST_REMOTE_CODE,
+                )
             model_handle.model = wrap_in_hpu_graph(model_handle.model)
             return model_handle
         elif use_ipex():
             if config.architectures[0].endswith("Classification"):
-                return ClassificationModel(model_path, device, datatype)
+                return ClassificationModel(
+                    model_path,
+                    device,
+                    datatype,
+                    trust_remote=TRUST_REMOTE_CODE,
+                )
             else:
-                return DefaultModel(model_path, device, datatype, pool)
+                return DefaultModel(
+                    model_path,
+                    device,
+                    datatype,
+                    pool,
+                    trust_remote=TRUST_REMOTE_CODE,
+                )
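For context, a hedged usage sketch of the updated factory; the model directory, dtype, and pool values are illustrative, and TRUST_REMOTE_CODE has to be set before this module is imported because the flag is read at import time:

import os

os.environ["TRUST_REMOTE_CODE"] = "1"  # illustrative; normally exported in the server environment

from pathlib import Path

from text_embeddings_server.models import get_model

# Hypothetical local snapshot of a model repo that ships custom modeling code.
model = get_model(Path("/data/my-custom-embedding-model"), dtype="float16", pool="cls")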
@@ -13,8 +13,16 @@
 
 
 class ClassificationModel(Model):
-    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
-        model = AutoModelForSequenceClassification.from_pretrained(model_path)
+    def __init__(
+        self,
+        model_path: Path,
+        device: torch.device,
+        dtype: torch.dtype,
+        trust_remote: bool = False,
+    ):
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_path, trust_remote_code=trust_remote
+        )
         model = model.to(dtype).to(device)
 
         self.hidden_size = model.config.hidden_size
@@ -15,9 +15,18 @@
 
 class DefaultModel(Model):
     def __init__(
-        self, model_path: Path, device: torch.device, dtype: torch.dtype, pool: str
+        self,
+        model_path: Path,
+        device: torch.device,
+        dtype: torch.dtype,
+        pool: str,
+        trust_remote: bool = False,
     ):
-        model = AutoModel.from_pretrained(model_path).to(dtype).to(device)
+        model = (
+            AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote)
+            .to(dtype)
+            .to(device)
+        )
         self.hidden_size = model.config.hidden_size
         self.pooling = Pooling(self.hidden_size, pooling_mode=pool)
 
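Similarly, a hedged sketch of constructing the model classes directly; the module path and model directory are assumptions, and trust_remote defaults to False, so existing callers keep the previous behaviour (ClassificationModel accepts the same keyword, minus pool):

import torch
from pathlib import Path

# Assumed module path for the class shown above.
from text_embeddings_server.models.default_model import DefaultModel

model = DefaultModel(
    Path("/data/my-custom-embedding-model"),  # illustrative local snapshot
    torch.device("cpu"),
    torch.float32,
    pool="cls",
    trust_remote=True,  # forwarded to AutoModel.from_pretrained(..., trust_remote_code=True)
)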