Merged
6 changes: 2 additions & 4 deletions src/emd/models/embeddings/bge_vl.py
@@ -57,9 +57,7 @@
         local_instance,
     ],
     supported_services=[
-        sagemaker_service,
-        ecs_service,
-        local_service
+        ecs_service
     ],
     supported_frameworks=[
         fastapi_framework
@@ -73,4 +71,4 @@
     model_series=BGE_SERIES,
     description="BGE-VL-large is a larger multimodal embedding model that supports text, image, and text-image pair inputs for high-performance multimodal representation learning and cross-modal retrieval tasks."
 )
-)
+)
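
The net effect of this first file's change is that BGE-VL-large now advertises only ECS as a deployment target. As a minimal sketch, assuming the Model.register(dict(...)) wrapper implied by the two closing parentheses (the wrapper itself is not visible in this diff), the resulting block reads roughly:

Model.register(
    dict(
        # ...fields not shown in the diff are elided...
        supported_instances=[
            local_instance,
        ],
        supported_services=[
            ecs_service,  # sagemaker_service and local_service removed by this PR
        ],
        supported_frameworks=[
            fastapi_framework,
        ],
        model_series=BGE_SERIES,
        description="BGE-VL-large is a larger multimodal embedding model ...",
    )
)

The hunks below belong to the second file in this PR: the Python backend that loads and serves these embedding models.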
@@ -65,15 +65,15 @@ def start(self):
            device_map="cuda",
            **self.pretrained_model_init_kwargs
        )

        # BGE-VL specific initialization
        if self.is_bge_vl:
            try:
                self.model.set_processor(model_abs_path)
                logger.info(f"BGE-VL processor set successfully for model: {self.model_id}")
            except Exception as e:
                logger.warning(f"Failed to set BGE-VL processor: {e}")

        logger.info(f"model: {self.model}")
        # TODO add tokenizer init args from model's definition
        # self.tokenizer = AutoTokenizer.from_pretrained(
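
For context, a minimal sketch of how a BGE-VL checkpoint is typically loaded, mirroring the set_processor call that this hunk wraps in try/except; the model id and the encode call are illustrative (taken from the upstream BGE-VL usage pattern, not from this repo):

import torch
from transformers import AutoModel

# BGE-VL checkpoints ship custom modeling code, hence trust_remote_code=True;
# the processor is attached in a separate step, which is the call start()
# wraps in try/except above.
model = AutoModel.from_pretrained("BAAI/BGE-VL-large", trust_remote_code=True)
model.set_processor("BAAI/BGE-VL-large")
model.eval()

with torch.no_grad():
    emb = model.encode(text=["a photo of a cat"])  # illustrative call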
@@ -106,20 +106,20 @@ def _process_base64_image(self, image_data: str) -> Image.Image:
            # Handle data URL format
            if image_data.startswith('data:image'):
                image_data = image_data.split(',')[1]

            # Decode base64
            image_bytes = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_bytes))

            # Convert to RGB if needed
            if image.mode != 'RGB':
                image = image.convert('RGB')

            return image
        except Exception as e:
            logger.error(f"Failed to process base64 image: {e}")
            raise ValueError(f"Invalid image data: {e}")

    def _convert_pil_to_bytesio(self, pil_image: Image.Image) -> io.BytesIO:
        """Convert PIL Image to BytesIO object for BGE-VL compatibility"""
        try:
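
Before the diff continues into _convert_pil_to_bytesio, here is a self-contained round-trip of the _process_base64_image path above; a sketch assuming only Pillow and the standard library:

import base64
import io
from PIL import Image

# Build a tiny image, encode it as a data URL, then decode it exactly the
# way _process_base64_image does.
img = Image.new("RGB", (4, 4), "red")
buf = io.BytesIO()
img.save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

payload = data_url.split(',')[1] if data_url.startswith('data:image') else data_url
decoded = Image.open(io.BytesIO(base64.b64decode(payload)))
if decoded.mode != 'RGB':
    decoded = decoded.convert('RGB')
assert decoded.size == (4, 4)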
@@ -131,13 +131,13 @@ def _convert_pil_to_bytesio(self, pil_image: Image.Image) -> io.BytesIO:
        except Exception as e:
            logger.error(f"Failed to convert PIL image to BytesIO: {e}")
            raise ValueError(f"Image conversion failed: {e}")

    def _parse_multimodal_inputs(self, inputs):
        """Parse and categorize multimodal inputs for BGE-VL"""
        text_inputs = []
        image_inputs = []
        multimodal_inputs = []

        for inp in inputs:
            if isinstance(inp, str):
                # Simple text input
@@ -162,14 +162,14 @@ def _parse_multimodal_inputs(self, inputs):
                    # Convert PIL Image to BytesIO for BGE-VL compatibility
                    bytesio_image = self._convert_pil_to_bytesio(pil_image)
                    multimodal_inputs.append((text, bytesio_image))

        return text_inputs, image_inputs, multimodal_inputs

    def _generate_bge_vl_embeddings(self, inputs):
        """Generate embeddings using BGE-VL model"""
        text_inputs, image_inputs, multimodal_inputs = self._parse_multimodal_inputs(inputs)
        all_embeddings = []

        # Process text-only inputs
        if text_inputs:
            try:
@@ -182,7 +182,7 @@
            except Exception as e:
                logger.error(f"Failed to encode text inputs: {e}")
                raise ValueError(f"BGE-VL text encoding failed: {e}")

        # Process image-only inputs
        if image_inputs:
            try:
@@ -195,7 +195,7 @@
            except Exception as e:
                logger.error(f"Failed to encode image inputs: {e}")
                raise ValueError(f"BGE-VL image encoding failed: {e}")

        # Process multimodal inputs (text + image)
        if multimodal_inputs:
            for text, bytesio_image in multimodal_inputs:
@@ -209,7 +209,7 @@ def _generate_bge_vl_embeddings(self, inputs):
                except Exception as e:
                    logger.error(f"Failed to encode multimodal input: {e}")
                    raise ValueError(f"BGE-VL multimodal encoding failed: {e}")

        return all_embeddings

    def invoke(self, request:dict):
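
The three branches of _generate_bge_vl_embeddings map onto three call shapes of the BGE-VL encode API. A hedged sketch of those shapes (argument names follow the upstream BGE-VL model card; the actual calls sit in lines elided from this diff):

# Assumed call shapes, one per input category; text-only and image-only
# inputs are batched, while each (text, image) pair is encoded individually.
text_embs  = model.encode(text=["a photo of a cat", "red running shoes"])
image_embs = model.encode(images=[bytesio_a, bytesio_b])  # file-like objects
pair_emb   = model.encode(text="make the background dark", images=bytesio_a)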
@@ -219,7 +219,7 @@ def invoke(self, request:dict):

        logger.info(f'request: {request}')
        t0 = time.time()

        if self.is_bge_vl:
            # Use BGE-VL multimodal processing
            embeddings_list = self._generate_bge_vl_embeddings(inputs)
@@ -229,10 +229,10 @@
            truncate_dim = request.get('truncate_dim', None)
            embeddings = self.model.encode(inputs, task=task, truncate_dim=truncate_dim)
            embeddings_list = embeddings.tolist()

        logger.info(f'embeddings generated, count: {len(embeddings_list)}, elapsed time: {time.time()-t0}')
        return self.format_openai_response(embeddings_list)

    async def ainvoke(self, request: dict):
        """Async version of invoke method"""
        return self.invoke(request)
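
Note that ainvoke simply delegates to the synchronous invoke, so the call blocks the event loop for the duration of the embedding computation. For reference, a hypothetical request in the shape invoke() consumes; only task and truncate_dim are visible in this diff, so the inputs key and the per-item dict keys below are assumptions:

# Hypothetical payload; key names other than task/truncate_dim are assumed.
request = {
    "inputs": [
        "a photo of a cat",                       # text-only item
        {"image": "data:image/png;base64,..."},   # image-only item (assumed key)
        {"text": "similar shoes",
         "image": "data:image/png;base64,..."},   # text + image (assumed keys)
    ],
    "task": None,           # consumed only by the non-BGE-VL encode path
    "truncate_dim": None,   # likewise
}
response = backend.invoke(request)  # returns an OpenAI-style embeddings response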