Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks-data
6 changes: 6 additions & 0 deletions benchmarks/700.image/701.image-captioning/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"timeout": 60,
"memory": 256,
"languages": ["python"]
}

40 changes: 40 additions & 0 deletions benchmarks/700.image/701.image-captioning/input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import glob
import os

def buckets_count():
return (1, 1)

'''
Generate test, small, and large workload for image captioning benchmark.

:param data_dir: Directory where benchmark data is placed
:param size: Workload size
:param benchmarks_bucket: Storage container for the benchmark
:param input_paths: List of input paths
:param output_paths: List of output paths
:param upload_func: Upload function taking three params (bucket_idx, key, filepath)
'''
def generate_input(data_dir, size, benchmarks_bucket, input_paths, output_paths, upload_func):
input_files = glob.glob(os.path.join(data_dir, '*.jpg')) + glob.glob(os.path.join(data_dir, '*.png')) + glob.glob(os.path.join(data_dir, '*.jpeg'))

if not input_files:
raise ValueError("No input files found in the provided directory.")

for file in input_files:
img = os.path.relpath(file, data_dir)
upload_func(0, img, file)

input_config = {
'object': {
'key': img,
'width': 200,
'height': 200
},
'bucket': {
'bucket': benchmarks_bucket,
'input': input_paths[0],
'output': output_paths[0]
}
}

return input_config
67 changes: 67 additions & 0 deletions benchmarks/700.image/701.image-captioning/python/function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import datetime
import io
import os
from urllib.parse import unquote_plus
from PIL import Image
import torch
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from . import storage

# Load the pre-trained ViT-GPT2 model
# Model URL: https://huggingface.co/nlpconnect/vit-gpt2-image-captioning
# License: Apache 2.0 License (https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md)
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

model.eval()

client = storage.storage.get_instance()

def generate_caption(image_bytes):
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

with torch.no_grad():
generated_ids = model.generate(pixel_values, max_length=16, num_beams=4)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

return generated_text

def handler(event):
bucket = event.get('bucket').get('bucket')
input_prefix = event.get('bucket').get('input')
output_prefix = event.get('bucket').get('output')
key = unquote_plus(event.get('object').get('key'))

download_begin = datetime.datetime.now()
img = client.download_stream(bucket, os.path.join(input_prefix, key))
download_end = datetime.datetime.now()

process_begin = datetime.datetime.now()
caption = generate_caption(img)
process_end = datetime.datetime.now()

upload_begin = datetime.datetime.now()
caption_file_name = os.path.splitext(key)[0] + '.txt'
caption_file_path = os.path.join(output_prefix, caption_file_name)
client.upload_stream(bucket, caption_file_path, io.BytesIO(caption.encode('utf-8')))
upload_end = datetime.datetime.now()

download_time = (download_end - download_begin) / datetime.timedelta(microseconds=1)
upload_time = (upload_end - upload_begin) / datetime.timedelta(microseconds=1)
process_time = (process_end - process_begin) / datetime.timedelta(microseconds=1)

return {
'result': {
'bucket': bucket,
'key': caption_file_path
},
'measurement': {
'download_time': download_time,
'download_size': len(img),
'upload_time': upload_time,
'upload_size': len(caption.encode('utf-8')),
'compute_time': process_time
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
transformers==4.44.2
torch==2.4.0
pillow==10.4.0