diff --git a/config.py b/config.py index aa4c903..db242f5 100644 --- a/config.py +++ b/config.py @@ -21,6 +21,8 @@ OPEN_SEARCH_INDEX = 'scan-explorer' OPEN_SEARCH_AGG_BUCKET_LIMIT = 10000 +REDIS_URL = 'redis://redis-backend:6379/4' + ADS_SEARCH_SERVICE_URL = 'https://api.adsabs.harvard.edu/v1/search/query' ADS_SEARCH_SERVICE_TOKEN = '' diff --git a/requirements.txt b/requirements.txt index 3b4d3e5..82529c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ alembic==1.8.0 img2pdf==0.4.4 appmap>=1.1.0.dev0 boto3==1.34.75 +redis==4.6.0 diff --git a/scan_explorer_service/open_search.py b/scan_explorer_service/open_search.py index 039a76c..50c59a7 100644 --- a/scan_explorer_service/open_search.py +++ b/scan_explorer_service/open_search.py @@ -3,7 +3,22 @@ from flask import current_app from scan_explorer_service.utils.search_utils import EsFields, OrderOptions + +def _get_os_client(): + """Return a singleton OpenSearch client for the current Flask app context.""" + if not hasattr(current_app, '_os_client') or current_app._os_client is None: + url = current_app.config.get('OPEN_SEARCH_URL') + current_app._os_client = opensearchpy.OpenSearch( + url, + timeout=30, + max_retries=2, + retry_on_timeout=True, + pool_maxsize=20, + ) + return current_app._os_client + def create_query_string_query(query_string: str): + """Build an OpenSearch query_string query dict with default fields and AND operator.""" query = { "query": { "query_string": { @@ -16,6 +31,7 @@ def create_query_string_query(query_string: str): return query def append_aggregate(query: dict, agg_field: EsFields, page: int, size: int, sort: OrderOptions): + """Add a terms aggregation with bucket sort and pagination to an OpenSearch query.""" from_number = (page - 1) * size query['size'] = 0 if sort == OrderOptions.Bibcode_desc or sort == OrderOptions.Bibcode_asc: @@ -60,6 +76,7 @@ def append_aggregate(query: dict, agg_field: EsFields, page: int, size: int, sor return query def 
append_highlight(query: dict): + """Add text field highlighting to an OpenSearch query.""" query['highlight'] = { "fields": { "text": {} @@ -70,12 +87,14 @@ def append_highlight(query: dict): def es_search(query: dict) -> Iterator[str]: - es = opensearchpy.OpenSearch(current_app.config.get('OPEN_SEARCH_URL')) + """Execute an OpenSearch query against the configured index and return the raw response.""" + es = _get_os_client() resp = es.search(index=current_app.config.get( 'OPEN_SEARCH_INDEX'), body=query) return resp def text_search_highlight(text: str, filter_field: EsFields, filter_value: str): + """Search for text with an optional field filter and yield page IDs with highlight snippets.""" query_string = text if filter_field: query_string += " " + filter_field.value + ":" + str(filter_value) @@ -101,6 +120,7 @@ def text_search_highlight(text: str, filter_field: EsFields, filter_value: str): } def set_page_ocr_fields(query: dict) -> dict: + """Restrict the query's _source to include only the OCR text field.""" if '_source' in query.keys(): query["_source"]["include"].append("text") else: @@ -108,10 +128,12 @@ def set_page_ocr_fields(query: dict) -> dict: return query def set_page_search_fields(query: dict) -> dict: + """Restrict the query's _source to page identification fields only.""" query["_source"] = {"include": ["page_id", "volume_id", "page_label", "page_number"]} return query def page_os_search(qs: str, page, limit, sort): + """Run a paginated page-level search against OpenSearch with sorting.""" qs = qs.replace("&", "+") query = create_query_string_query(qs) query = set_page_search_fields(query) @@ -137,6 +159,7 @@ def page_os_search(qs: str, page, limit, sort): return es_result def page_ocr_os_search(collection_id: str, page_number:int): + """Fetch the OCR text for a specific page within a collection from OpenSearch.""" qs = EsFields.volume_id_lowercase + ":" + collection_id + " " + EsFields.page_number + ":" + str(page_number) query = 
create_query_string_query(qs) query = set_page_ocr_fields(query) @@ -144,6 +167,7 @@ def page_ocr_os_search(collection_id: str, page_number:int): return es_result def aggregate_search(qs: str, aggregate_field, page, limit, sort): + """Run a paginated aggregation search grouped by the specified field.""" qs = qs.replace("&", "+") query = create_query_string_query(qs) query = append_aggregate(query, aggregate_field, page, limit, sort) diff --git a/scan_explorer_service/tests/test_cache.py b/scan_explorer_service/tests/test_cache.py new file mode 100644 index 0000000..2dcb0f1 --- /dev/null +++ b/scan_explorer_service/tests/test_cache.py @@ -0,0 +1,420 @@ +import unittest +import json +from flask import url_for +from unittest.mock import patch, MagicMock, PropertyMock +from scan_explorer_service.tests.base import TestCaseDatabase +from scan_explorer_service.models import Base, Collection, Page, Article + + +class TestRedisReconnection(TestCaseDatabase): + """Verify Redis client resets on ConnectionError and automatically reconnects on next call.""" + + def create_app(self): + from scan_explorer_service.app import create_app + return create_app(**{ + 'SQLALCHEMY_DATABASE_URI': self.postgresql_url, + 'SQLALCHEMY_ECHO': False, + 'TESTING': True, + 'PROPAGATE_EXCEPTIONS': True, + 'TRAP_BAD_REQUEST_ERRORS': True, + 'PRESERVE_CONTEXT_ON_EXCEPTION': False, + 'REDIS_URL': 'redis://localhost:6379/15', + }) + + def setUp(self): + Base.metadata.drop_all(bind=self.app.db.engine) + Base.metadata.create_all(bind=self.app.db.engine) + import scan_explorer_service.views.manifest as m + m._redis_client = None + + def tearDown(self): + import scan_explorer_service.views.manifest as m + m._redis_client = None + super().tearDown() + + @patch('scan_explorer_service.views.manifest.redis.from_url') + def test_cache_get_resets_on_connection_error(self, mock_from_url): + """Verify _redis_client is set to None after a ConnectionError so next call reconnects.""" + import 
scan_explorer_service.views.manifest as m + import redis as redis_lib + mock_client = MagicMock() + mock_client.ping.return_value = True + mock_from_url.return_value = mock_client + + m._cache_get('test') + self.assertIsNotNone(m._redis_client) + + mock_client.get.side_effect = redis_lib.ConnectionError("connection lost") + result = m._cache_get('test') + self.assertIsNone(result) + self.assertIsNone(m._redis_client) + + @patch('scan_explorer_service.views.manifest.redis.from_url') + def test_cache_set_resets_on_connection_error(self, mock_from_url): + """Verify _redis_client resets on ConnectionError during cache writes too.""" + import scan_explorer_service.views.manifest as m + import redis as redis_lib + mock_client = MagicMock() + mock_client.ping.return_value = True + mock_from_url.return_value = mock_client + + m._cache_set('test', '{}') + self.assertIsNotNone(m._redis_client) + + mock_client.setex.side_effect = redis_lib.ConnectionError("connection lost") + m._cache_set('test', '{}') + self.assertIsNone(m._redis_client) + + @patch('scan_explorer_service.views.manifest.redis.from_url') + def test_reconnects_after_reset(self, mock_from_url): + """Verify a new Redis client is created after a previous connection was reset.""" + import scan_explorer_service.views.manifest as m + import redis as redis_lib + mock_client = MagicMock() + mock_client.ping.return_value = True + mock_from_url.return_value = mock_client + + m._cache_get('test') + mock_client.get.side_effect = redis_lib.ConnectionError("lost") + m._cache_get('test') + self.assertIsNone(m._redis_client) + + mock_client2 = MagicMock() + mock_client2.ping.return_value = True + mock_client2.get.return_value = '{"cached": true}' + mock_from_url.return_value = mock_client2 + + result = m._cache_get('test') + self.assertEqual(result, '{"cached": true}') + self.assertIsNotNone(m._redis_client) + + @patch('scan_explorer_service.views.manifest.redis.from_url') + def test_get_redis_uses_lock(self, mock_from_url): + 
"""Verify from_url is only called once even with multiple _get_redis calls (singleton pattern).""" + import scan_explorer_service.views.manifest as m + mock_client = MagicMock() + mock_client.ping.return_value = True + mock_from_url.return_value = mock_client + + m._get_redis() + m._get_redis() + self.assertEqual(mock_from_url.call_count, 1) + + +class TestManifestCaching(TestCaseDatabase): + """Verify manifest and search results are cached in Redis and served on subsequent requests.""" + + def create_app(self): + from scan_explorer_service.app import create_app + return create_app(**{ + 'SQLALCHEMY_DATABASE_URI': self.postgresql_url, + 'SQLALCHEMY_ECHO': False, + 'TESTING': True, + 'PROPAGATE_EXCEPTIONS': True, + 'TRAP_BAD_REQUEST_ERRORS': True, + 'PRESERVE_CONTEXT_ON_EXCEPTION': False, + }) + + def setUp(self): + Base.metadata.drop_all(bind=self.app.db.engine) + Base.metadata.create_all(bind=self.app.db.engine) + import scan_explorer_service.views.manifest as m + m._redis_client = None + + self.collection = Collection(type='type', journal='journal', volume='volume') + self.app.db.session.add(self.collection) + self.app.db.session.commit() + self.app.db.session.refresh(self.collection) + + self.article = Article(bibcode='1988ApJ...333..341R', + collection_id=self.collection.id) + self.app.db.session.add(self.article) + self.app.db.session.commit() + + self.page = Page(name='page', collection_id=self.collection.id) + self.page.width = 1000 + self.page.height = 1000 + self.page.label = 'label' + self.app.db.session.add(self.page) + self.app.db.session.commit() + + self.article.pages.append(self.page) + self.app.db.session.commit() + + def tearDown(self): + import scan_explorer_service.views.manifest as m + m._redis_client = None + super().tearDown() + + @patch('scan_explorer_service.views.manifest._cache_set') + @patch('scan_explorer_service.views.manifest._cache_get') + def test_manifest_serves_from_cache(self, mock_get, mock_set): + """Verify manifest is generated 
on first request, then served from cache on second.""" + mock_get.return_value = None + url = url_for("manifest.get_manifest", id=self.article.id) + r1 = self.client.get(url) + self.assertStatus(r1, 200) + mock_set.assert_called_once() + cached_json = mock_set.call_args[0][1] + + mock_get.return_value = cached_json + r2 = self.client.get(url) + self.assertStatus(r2, 200) + self.assertEqual(r2.content_type, 'application/json') + + @patch('scan_explorer_service.views.manifest._search_cache_set') + @patch('scan_explorer_service.views.manifest._search_cache_get') + @patch('opensearchpy.OpenSearch') + def test_search_cache_key_is_hashed(self, OpenSearch, mock_get, mock_set): + """Verify search cache keys are 32-char hex MD5 hashes, not raw query strings.""" + es = OpenSearch.return_value + es.search.return_value = { + "hits": {"total": {"value": 1}, "hits": [ + {'_source': {'page_id': self.page.id, 'volume_id': self.collection.id, + 'page_label': 'label', 'page_number': 1}, + 'highlight': {'text': ['some text']}} + ]} + } + mock_get.return_value = None + + url = url_for("manifest.search", id=self.article.id, q='test query') + self.client.get(url) + + cache_key = mock_set.call_args[0][0] + self.assertEqual(len(cache_key), 32) + self.assertTrue(all(c in '0123456789abcdef' for c in cache_key)) + + +class TestFetchImagesMemoryLimit(TestCaseDatabase): + """Verify fetch_images respects memory_limit and stops yielding when exceeded.""" + + def create_app(self): + from scan_explorer_service.app import create_app + return create_app(**{ + 'SQLALCHEMY_DATABASE_URI': self.postgresql_url, + 'SQLALCHEMY_ECHO': False, + 'TESTING': True, + 'PROPAGATE_EXCEPTIONS': True, + 'TRAP_BAD_REQUEST_ERRORS': True, + 'PRESERVE_CONTEXT_ON_EXCEPTION': False, + 'IMAGE_PDF_MEMORY_LIMIT': 50, + 'IMAGE_PDF_PAGE_LIMIT': 100, + }) + + def setUp(self): + Base.metadata.drop_all(bind=self.app.db.engine) + Base.metadata.create_all(bind=self.app.db.engine) + + self.collection = Collection(type='type', 
journal='journal', volume='volume') + self.app.db.session.add(self.collection) + self.app.db.session.commit() + + for i in range(5): + p = Page(name=f'page{i}', collection_id=self.collection.id, volume_running_page_num=i+1) + p.width = 100 + p.height = 100 + p.label = str(i+1) + self.app.db.session.add(p) + self.app.db.session.commit() + + @patch('scan_explorer_service.views.image_proxy.S3Provider') + def test_memory_limit_stops_yielding(self, mock_s3_cls): + """Verify fetch_images stops yielding once cumulative image size exceeds memory_limit.""" + import sys + chunk = b'x' * 30 + chunk_size = sys.getsizeof(chunk) + + mock_s3 = MagicMock() + mock_s3.read_object_s3.return_value = chunk + mock_s3_cls.return_value = mock_s3 + + memory_limit = chunk_size * 2 + 1 + from scan_explorer_service.views.image_proxy import fetch_images + images = list(fetch_images( + self.app.db.session, self.collection, 1, 5, 100, memory_limit)) + self.assertEqual(len(images), 2) + self.assertTrue(mock_s3.read_object_s3.call_count <= 5) + + +class TestPdfEarlyLimitCheck(TestCaseDatabase): + """PDF over-limit returns 400 without hitting DB.""" + + def create_app(self): + from scan_explorer_service.app import create_app + return create_app(**{ + 'SQLALCHEMY_DATABASE_URI': self.postgresql_url, + 'SQLALCHEMY_ECHO': False, + 'TESTING': True, + 'PROPAGATE_EXCEPTIONS': True, + 'TRAP_BAD_REQUEST_ERRORS': True, + 'PRESERVE_CONTEXT_ON_EXCEPTION': False, + 'IMAGE_PDF_MEMORY_LIMIT': 100 * 1024 * 1024, + 'IMAGE_PDF_PAGE_LIMIT': 100, + }) + + def setUp(self): + Base.metadata.drop_all(bind=self.app.db.engine) + Base.metadata.create_all(bind=self.app.db.engine) + + def test_over_limit_returns_400(self): + """Verify requesting more pages than IMAGE_PDF_PAGE_LIMIT returns 400 before hitting DB.""" + url = url_for('proxy.pdf_save', id='anything', page_start=1, page_end=200) + r = self.client.get(url) + self.assertStatus(r, 400) + data = json.loads(r.data) + self.assertIn('exceeds limit', data['Message']) + + 
def test_missing_id_returns_400(self): + """Verify missing 'id' parameter returns 400.""" + url = url_for('proxy.pdf_save', page_start=1, page_end=5) + r = self.client.get(url) + self.assertStatus(r, 400) + + +class TestSearchValidationBeforeCache(TestCaseDatabase): + """Verify invalid queries are rejected before any Redis cache lookup occurs.""" + + def create_app(self): + from scan_explorer_service.app import create_app + return create_app(**{ + 'SQLALCHEMY_DATABASE_URI': self.postgresql_url, + 'OPEN_SEARCH_URL': 'http://localhost:1234', + 'OPEN_SEARCH_INDEX': 'test', + 'SQLALCHEMY_ECHO': False, + 'TESTING': True, + 'PROPAGATE_EXCEPTIONS': True, + 'TRAP_BAD_REQUEST_ERRORS': True, + 'PRESERVE_CONTEXT_ON_EXCEPTION': False, + }) + + def setUp(self): + Base.metadata.drop_all(bind=self.app.db.engine) + Base.metadata.create_all(bind=self.app.db.engine) + + @patch('scan_explorer_service.views.metadata._search_cache_get') + def test_empty_query_returns_400_without_cache_lookup(self, mock_cache_get): + """Verify empty search query is rejected before any cache lookup occurs.""" + url = url_for("metadata.article_search", q='') + r = self.client.get(url) + self.assertStatus(r, 400) + mock_cache_get.assert_not_called() + + @patch('scan_explorer_service.views.metadata._search_cache_get') + def test_collection_empty_query_no_cache(self, mock_cache_get): + """Same validation-before-cache check for collection search endpoint.""" + url = url_for("metadata.collection_search", q='') + r = self.client.get(url) + self.assertStatus(r, 400) + mock_cache_get.assert_not_called() + + @patch('scan_explorer_service.views.metadata._search_cache_get') + def test_page_search_empty_query_no_cache(self, mock_cache_get): + """Same validation-before-cache check for page search endpoint.""" + url = url_for("metadata.page_search", q='') + r = self.client.get(url) + self.assertStatus(r, 400) + mock_cache_get.assert_not_called() + + +class TestSearchCacheKeyMultiValue(TestCaseDatabase): + """Verify 
cache keys are distinct when query params have multiple values for the same key.""" + + def create_app(self): + from scan_explorer_service.app import create_app + return create_app(**{ + 'SQLALCHEMY_DATABASE_URI': self.postgresql_url, + 'SQLALCHEMY_ECHO': False, + 'TESTING': True, + }) + + def setUp(self): + Base.metadata.drop_all(bind=self.app.db.engine) + Base.metadata.create_all(bind=self.app.db.engine) + + def test_different_multi_params_produce_different_keys(self): + """Verify multi-valued query params (e.g. field=title&field=abstract) produce distinct cache keys.""" + from scan_explorer_service.views.metadata import _make_search_cache_key + from werkzeug.datastructures import ImmutableMultiDict + + args1 = ImmutableMultiDict([('q', 'star'), ('field', 'title')]) + args2 = ImmutableMultiDict([('q', 'star'), ('field', 'title'), ('field', 'abstract')]) + + key1 = _make_search_cache_key('test', args1) + key2 = _make_search_cache_key('test', args2) + self.assertNotEqual(key1, key2) + + +class TestOcrCaching(TestCaseDatabase): + """Verify OCR text results are cached and served as text/plain on cache hits.""" + + def create_app(self): + from scan_explorer_service.app import create_app + return create_app(**{ + 'SQLALCHEMY_DATABASE_URI': self.postgresql_url, + 'OPEN_SEARCH_URL': 'http://localhost:1234', + 'OPEN_SEARCH_INDEX': 'test', + 'SQLALCHEMY_ECHO': False, + 'TESTING': True, + 'PROPAGATE_EXCEPTIONS': True, + 'TRAP_BAD_REQUEST_ERRORS': True, + 'PRESERVE_CONTEXT_ON_EXCEPTION': False, + }) + + def setUp(self): + Base.metadata.drop_all(bind=self.app.db.engine) + Base.metadata.create_all(bind=self.app.db.engine) + + self.collection = Collection(type='type', journal='journal', volume='volume') + self.app.db.session.add(self.collection) + self.app.db.session.commit() + + self.article = Article(bibcode='1988ApJ...333..341R', + collection_id=self.collection.id) + self.app.db.session.add(self.article) + self.app.db.session.commit() + + self.page = Page(name='page', 
collection_id=self.collection.id) + self.page.width = 1000 + self.page.height = 1000 + self.page.label = 'label' + self.page.volume_running_page_num = 100 + self.app.db.session.add(self.page) + self.app.db.session.commit() + self.article.pages.append(self.page) + self.app.db.session.commit() + + @patch('scan_explorer_service.views.metadata._search_cache_set') + @patch('scan_explorer_service.views.metadata._search_cache_get') + @patch('opensearchpy.OpenSearch') + def test_ocr_result_is_cached(self, OpenSearch, mock_cache_get, mock_cache_set): + """Verify OCR text is stored in cache after first fetch and returned as text/plain.""" + es = OpenSearch.return_value + es.search.return_value = { + "hits": {"total": {"value": 1}, "hits": [ + {"_source": {"text": "Some OCR text here"}} + ]} + } + mock_cache_get.return_value = None + + url = url_for("metadata.get_page_ocr", id=self.article.id) + r = self.client.get(url) + self.assertStatus(r, 200) + self.assertEqual(r.data, b'Some OCR text here') + self.assertIn('text/plain', r.content_type) + mock_cache_set.assert_called_once() + self.assertEqual(mock_cache_set.call_args[0][1], 'Some OCR text here') + + @patch('scan_explorer_service.views.metadata._search_cache_get') + def test_ocr_served_from_cache(self, mock_cache_get): + """Verify cached OCR text is served directly without hitting OpenSearch.""" + mock_cache_get.return_value = 'Cached OCR text' + + url = url_for("metadata.get_page_ocr", id=self.article.id) + r = self.client.get(url) + self.assertStatus(r, 200) + self.assertEqual(r.data, b'Cached OCR text') + self.assertIn('text/plain', r.content_type) + + +if __name__ == '__main__': + unittest.main() diff --git a/scan_explorer_service/tests/test_perf.py b/scan_explorer_service/tests/test_perf.py new file mode 100644 index 0000000..4bed405 --- /dev/null +++ b/scan_explorer_service/tests/test_perf.py @@ -0,0 +1,360 @@ +import unittest +import json +import time +from flask import url_for +from unittest.mock import patch, 
MagicMock +from scan_explorer_service.tests.base import TestCaseDatabase +from scan_explorer_service.models import Article, Base, Collection, Page + + +class TestManifestCache(TestCaseDatabase): + """Tests for Redis-backed manifest caching behavior.""" + + def create_app(self): + from scan_explorer_service.app import create_app + return create_app(**{ + 'SQLALCHEMY_DATABASE_URI': self.postgresql_url, + 'SQLALCHEMY_ECHO': False, + 'TESTING': True, + 'PROPAGATE_EXCEPTIONS': True, + 'TRAP_BAD_REQUEST_ERRORS': True, + 'PRESERVE_CONTEXT_ON_EXCEPTION': False + }) + + def setUp(self): + Base.metadata.drop_all(bind=self.app.db.engine) + Base.metadata.create_all(bind=self.app.db.engine) + + import scan_explorer_service.views.manifest as m + m._redis_client = None + + self.collection = Collection(type='type', journal='cacheJ', volume='0099') + self.app.db.session.add(self.collection) + self.app.db.session.commit() + self.app.db.session.refresh(self.collection) + + self.article = Article(bibcode='2099CacheTest..001A', + collection_id=self.collection.id) + self.app.db.session.add(self.article) + self.app.db.session.commit() + self.app.db.session.refresh(self.article) + + self.page1 = Page(name='cp1', collection_id=self.collection.id, + volume_running_page_num=1) + self.page1.width = 100 + self.page1.height = 100 + self.page1.label = '1' + self.page2 = Page(name='cp2', collection_id=self.collection.id, + volume_running_page_num=2) + self.page2.width = 100 + self.page2.height = 100 + self.page2.label = '2' + self.app.db.session.add_all([self.page1, self.page2]) + self.app.db.session.commit() + + self.article.pages.append(self.page1) + self.article.pages.append(self.page2) + self.app.db.session.commit() + + def tearDown(self): + import scan_explorer_service.views.manifest as m + m._redis_client = None + self.app.db.session.remove() + self.app.db.drop_all() + + def _mock_redis(self): + """Create an in-memory mock Redis client with get/setex/delete and TTL support.""" + mock_r = 
MagicMock() + store = {} + + def mock_get(key): + entry = store.get(key) + if entry is None: + return None + val, exp = entry + if exp and time.monotonic() > exp: + del store[key] + return None + return val + + def mock_setex(key, ttl, val): + store[key] = (val, time.monotonic() + ttl) + + def mock_delete(key): + store.pop(key, None) + + mock_r.get = mock_get + mock_r.setex = mock_setex + mock_r.delete = mock_delete + mock_r.ping.return_value = True + + import scan_explorer_service.views.manifest as m + m._redis_client = mock_r + return mock_r, store + + def test_cache_hit_returns_cached_json(self): + """Verifies that a cached manifest is returned directly without regeneration.""" + mock_r, store = self._mock_redis() + from scan_explorer_service.views.manifest import MANIFEST_CACHE_PREFIX + store[MANIFEST_CACHE_PREFIX + self.article.id] = ('{"@type":"sc:Manifest","cached":true}', time.monotonic() + 3600) + + url = url_for("manifest.get_manifest", id=self.article.id) + r = self.client.get(url) + self.assertStatus(r, 200) + data = json.loads(r.data) + self.assertTrue(data.get('cached')) + + def test_cache_hit_returns_correct_content_type(self): + """Verifies that cached manifest responses have application/json content type.""" + mock_r, store = self._mock_redis() + from scan_explorer_service.views.manifest import MANIFEST_CACHE_PREFIX + store[MANIFEST_CACHE_PREFIX + self.collection.id] = ('{"@type":"sc:Manifest"}', time.monotonic() + 3600) + + url = url_for("manifest.get_manifest", id=self.collection.id) + r = self.client.get(url) + self.assertStatus(r, 200) + self.assertIn('application/json', r.content_type) + + def test_cache_miss_calls_setex(self): + """Verifies that a cache miss triggers a setex call to store the manifest.""" + mock_r, store = self._mock_redis() + original_setex = mock_r.setex + setex_calls = [] + + def tracking_setex(key, ttl, val): + setex_calls.append(key) + return original_setex(key, ttl, val) + + mock_r.setex = tracking_setex + + from 
scan_explorer_service.views.manifest import _cache_set, MANIFEST_CACHE_PREFIX + _cache_set(self.article.id, '{"@type":"sc:Manifest"}') + + self.assertEqual(len(setex_calls), 1) + self.assertEqual(setex_calls[0], MANIFEST_CACHE_PREFIX + self.article.id) + + def test_cached_manifest_skips_manifest_factory(self): + """Verifies that manifest_factory is not called when the manifest is cached.""" + mock_r, store = self._mock_redis() + from scan_explorer_service.views.manifest import MANIFEST_CACHE_PREFIX + store[MANIFEST_CACHE_PREFIX + self.article.id] = ('{"@type":"sc:Manifest"}', time.monotonic() + 3600) + + with patch('scan_explorer_service.views.manifest.manifest_factory') as mock_factory: + url = url_for("manifest.get_manifest", id=self.article.id) + r = self.client.get(url) + self.assertStatus(r, 200) + mock_factory.create_manifest.assert_not_called() + + def test_404_not_cached(self): + """Verifies that 404 responses are not stored in the cache.""" + mock_r, store = self._mock_redis() + from scan_explorer_service.views.manifest import MANIFEST_CACHE_PREFIX + + url = url_for("manifest.get_manifest", id='nonexistent') + r = self.client.get(url) + self.assertStatus(r, 404) + self.assertNotIn(MANIFEST_CACHE_PREFIX + 'nonexistent', store) + + def test_redis_unavailable_falls_through(self): + """Verifies that the endpoint still works when Redis is unavailable.""" + import scan_explorer_service.views.manifest as m + m._redis_client = None + + with patch('scan_explorer_service.views.manifest.redis.from_url', side_effect=Exception("connection refused")): + url = url_for("manifest.get_manifest", id=self.article.id) + r = self.client.get(url) + self.assertStatus(r, 200) + data = json.loads(r.data) + self.assertEqual(data['@type'], 'sc:Manifest') + + +class TestPdfEarlyLimitCheck(TestCaseDatabase): + """Tests for early page-count validation before PDF generation.""" + + def create_app(self): + from scan_explorer_service.app import create_app + return create_app(**{ + 
'SQLALCHEMY_DATABASE_URI': self.postgresql_url, + 'SQLALCHEMY_ECHO': False, + 'TESTING': True, + 'PROPAGATE_EXCEPTIONS': True, + 'TRAP_BAD_REQUEST_ERRORS': True, + 'PRESERVE_CONTEXT_ON_EXCEPTION': False, + 'IMAGE_PDF_MEMORY_LIMIT': 100 * 1024 * 1024, + 'IMAGE_PDF_PAGE_LIMIT': 100, + }) + + def setUp(self): + Base.metadata.drop_all(bind=self.app.db.engine) + Base.metadata.create_all(bind=self.app.db.engine) + + self.collection = Collection(type='type', journal='journal', volume='volume') + self.app.db.session.add(self.collection) + self.app.db.session.commit() + + def test_over_limit_returns_400_immediately(self): + """Verifies that requesting more pages than the limit returns 400 without processing.""" + response = self.client.get(url_for('proxy.pdf_save', + id=self.collection.id, + page_start=1, + page_end=150)) + self.assertEqual(response.status_code, 400) + data = json.loads(response.data) + self.assertIn('exceeds limit', data['Message']) + + def test_exactly_at_limit_passes(self): + """Verifies that requesting exactly the page limit is allowed.""" + with patch('scan_explorer_service.views.image_proxy.fetch_images') as mock_fi, \ + patch('scan_explorer_service.views.image_proxy.img2pdf.convert') as mock_conv: + mock_fi.return_value = [b'data'] + mock_conv.return_value = b'pdf' + response = self.client.get(url_for('proxy.pdf_save', + id=self.collection.id, + page_start=1, + page_end=100)) + self.assertEqual(response.status_code, 200) + + def test_one_over_limit_returns_400(self): + """Verifies that requesting one page over the limit returns 400.""" + response = self.client.get(url_for('proxy.pdf_save', + id=self.collection.id, + page_start=1, + page_end=101)) + self.assertEqual(response.status_code, 400) + + def test_no_page_end_passes_limit_check(self): + """Verifies that omitting page_end bypasses the page limit check.""" + with patch('scan_explorer_service.views.image_proxy.fetch_images') as mock_fi, \ + 
patch('scan_explorer_service.views.image_proxy.img2pdf.convert') as mock_conv: + mock_fi.return_value = [b'data'] + mock_conv.return_value = b'pdf' + response = self.client.get(url_for('proxy.pdf_save', + id=self.collection.id, + page_start=1)) + self.assertEqual(response.status_code, 200) + + @patch('scan_explorer_service.views.image_proxy.get_item') + def test_over_limit_does_not_touch_db(self, mock_get_item): + """Verifies that over-limit requests are rejected before any database access.""" + response = self.client.get(url_for('proxy.pdf_save', + id=self.collection.id, + page_start=1, + page_end=200)) + self.assertEqual(response.status_code, 400) + mock_get_item.assert_not_called() + + def test_inverted_page_range_returns_empty_pdf(self): + """Verifies that an inverted page range (start > end) is handled gracefully.""" + with patch('scan_explorer_service.views.image_proxy.fetch_images') as mock_fi, \ + patch('scan_explorer_service.views.image_proxy.img2pdf.convert') as mock_conv: + mock_fi.return_value = [] + mock_conv.return_value = b'pdf' + response = self.client.get(url_for('proxy.pdf_save', + id=self.collection.id, + page_start=10, + page_end=5)) + self.assertIn(response.status_code, [200, 400]) + + +class TestParallelFetchImages(TestCaseDatabase): + """Tests for parallel image fetching used in PDF generation.""" + + def create_app(self): + from scan_explorer_service.app import create_app + return create_app(**{ + 'SQLALCHEMY_DATABASE_URI': self.postgresql_url, + 'SQLALCHEMY_ECHO': False, + 'TESTING': True, + 'PROPAGATE_EXCEPTIONS': True, + 'TRAP_BAD_REQUEST_ERRORS': True, + 'PRESERVE_CONTEXT_ON_EXCEPTION': False, + 'IMAGE_PDF_MEMORY_LIMIT': 100 * 1024 * 1024, + 'IMAGE_PDF_PAGE_LIMIT': 100, + }) + + def setUp(self): + Base.metadata.drop_all(bind=self.app.db.engine) + Base.metadata.create_all(bind=self.app.db.engine) + + self.collection = Collection(type='type', journal='journal', volume='volume') + self.app.db.session.add(self.collection) + 
self.app.db.session.commit() + + self.article = Article(bibcode='1988ApJ...333..341R', + collection_id=self.collection.id) + self.app.db.session.add(self.article) + self.app.db.session.commit() + + pages = [] + for i in range(5): + p = Page(name=f'page{i}', collection_id=self.collection.id, + volume_running_page_num=i + 1) + p.width = 100 + p.height = 100 + p.label = str(i + 1) + pages.append(p) + self.app.db.session.add_all(pages) + self.app.db.session.commit() + + for p in pages: + self.article.pages.append(p) + self.app.db.session.commit() + + self.pages = pages + + @patch('scan_explorer_service.views.image_proxy.S3Provider') + def test_fetch_images_returns_all_pages(self, mock_s3_cls): + """Verifies that fetch_images returns image data for all pages in the range.""" + mock_s3 = MagicMock() + mock_s3.read_object_s3.return_value = b'image_data' + mock_s3_cls.return_value = mock_s3 + + from scan_explorer_service.views.image_proxy import fetch_images + images = list(fetch_images( + self.app.db.session, self.collection, 1, 5, 100, + 100 * 1024 * 1024)) + self.assertEqual(len(images), 5) + self.assertTrue(all(img == b'image_data' for img in images)) + + @patch('scan_explorer_service.views.image_proxy.S3Provider') + def test_fetch_images_respects_memory_limit(self, mock_s3_cls): + """Verifies that fetch_images stops fetching when the memory limit is reached.""" + mock_s3 = MagicMock() + mock_s3.read_object_s3.return_value = b'x' * 1000 + mock_s3_cls.return_value = mock_s3 + + from scan_explorer_service.views.image_proxy import fetch_images + images = list(fetch_images( + self.app.db.session, self.collection, 1, 5, 100, + 500)) + self.assertLess(len(images), 5) + + @patch('scan_explorer_service.views.image_proxy.S3Provider') + def test_fetch_images_skips_none_results(self, mock_s3_cls): + """Verifies that fetch_images filters out None results from S3.""" + mock_s3 = MagicMock() + mock_s3.read_object_s3.side_effect = [b'data1', None, b'data3', b'data4', b'data5'] + 
mock_s3_cls.return_value = mock_s3 + + from scan_explorer_service.views.image_proxy import fetch_images + images = list(fetch_images( + self.app.db.session, self.collection, 1, 5, 100, + 100 * 1024 * 1024)) + self.assertEqual(len(images), 4) + + @patch('scan_explorer_service.views.image_proxy.S3Provider') + def test_single_s3provider_instance(self, mock_s3_cls): + """Verifies that fetch_images reuses a single S3Provider instance across all pages.""" + mock_s3 = MagicMock() + mock_s3.read_object_s3.return_value = b'image_data' + mock_s3_cls.return_value = mock_s3 + + from scan_explorer_service.views.image_proxy import fetch_images + list(fetch_images( + self.app.db.session, self.collection, 1, 5, 100, + 100 * 1024 * 1024)) + self.assertEqual(mock_s3_cls.call_count, 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/scan_explorer_service/tests/test_proxy.py b/scan_explorer_service/tests/test_proxy.py index 27a9f03..8270fd3 100644 --- a/scan_explorer_service/tests/test_proxy.py +++ b/scan_explorer_service/tests/test_proxy.py @@ -8,6 +8,7 @@ from scan_explorer_service.views.image_proxy import img2pdf, fetch_images, fetch_object class TestProxy(TestCaseDatabase): + """Tests for image proxy, thumbnail, PDF, and S3 fetch endpoints.""" def create_app(self): '''Start the wsgi application''' @@ -72,6 +73,7 @@ def setUp(self): self.app.db.session.refresh(self.article2) def mocked_request(*args, **kwargs): + """Return mock HTTP responses based on URL path keywords.""" class Raw: def __init__(self, data): self.data = data @@ -99,7 +101,7 @@ def close(self): @patch('requests.request', side_effect=mocked_request) def test_get_image(self, mock_request): - + """Verifies that image proxy forwards requests and returns correct status codes.""" url = url_for('proxy.image_proxy', path='valid-~image-~path') response = self.client.get(url) @@ -116,6 +118,7 @@ def test_get_image(self, mock_request): @patch('scan_explorer_service.views.image_proxy.requests.request') def 
test_image_proxy_closes_upstream_response(self, mock_request): + """Verifies that the upstream response is closed after the streamed response completes.""" mock_response = MagicMock() mock_response.status_code = 200 mock_response.headers = {} @@ -130,7 +133,7 @@ def test_image_proxy_closes_upstream_response(self, mock_request): @patch('requests.request', side_effect=mocked_request) def test_get_thumbnail(self, mock_request): - + """Verifies that thumbnail proxy returns a streamed 200 response for a valid article.""" data = { 'id': '1988ApJ...333..341R', 'type': 'article' @@ -157,9 +160,12 @@ def test_get_item(self): get_item(self.app.db.session, 'non-existent-id') assert("ID: non-existent-id not found" in str(context.exception)) - @patch('scan_explorer_service.views.image_proxy.fetch_object') - def test_fetch_images(self, mock_fetch_object): - mock_fetch_object.return_value = b'image_data' + @patch('scan_explorer_service.views.image_proxy.S3Provider') + def test_fetch_images(self, mock_s3_cls): + """Verifies that fetch_images yields image bytes for each page in the range.""" + mock_s3 = MagicMock() + mock_s3.read_object_s3.return_value = b'image_data' + mock_s3_cls.return_value = mock_s3 item = self.article page_start = 1 page_end = 2 @@ -169,10 +175,11 @@ def test_fetch_images(self, mock_fetch_object): gen = fetch_images(self.app.db.session, item, page_start, page_end, page_limit, memory_limit) images = list(gen) self.assertEqual(images, [b'image_data', b'image_data']) - mock_fetch_object.assert_called() + mock_s3.read_object_s3.assert_called() @patch('scan_explorer_service.utils.s3_utils.S3Provider.read_object_s3') def test_fetch_object(self, mock_read_object_s3): + """Verifies that fetch_object reads the correct S3 object and returns its bytes.""" object_name = 'bitmaps/type/journal/volume/600/page' mock_read_object_s3.return_value = b'image-data' @@ -185,6 +192,7 @@ def test_fetch_object(self, mock_read_object_s3): 
@patch('scan_explorer_service.views.image_proxy.fetch_object') def test_pdf_save_success_article(self, mock_fetch_object): + """Verifies that PDF download for an article returns 200 with application/pdf content type.""" mock_fetch_object.return_value = b'my_image_name' data = { @@ -201,6 +209,7 @@ def test_pdf_save_success_article(self, mock_fetch_object): @patch('scan_explorer_service.views.image_proxy.img2pdf.convert') @patch('scan_explorer_service.views.image_proxy.fetch_images') def test_pdf_save_success_collection(self, mock_fetch_images, mock_img2pdf_convert): + """Verifies that PDF download for a collection page range returns converted PDF data.""" mock_fetch_images.return_value = [b'image_data_1', b'image_data_2', b'image_data_3'] mock_img2pdf_convert.return_value = b'pdf_data' @@ -240,6 +249,7 @@ def setUp(self): Base.metadata.create_all(bind=self.app.db.engine) def _make_mock_response(self, data, status_code, headers=None): + """Build a mock HTTP response with streamable raw data.""" class Raw: def __init__(self, d): self.data = d @@ -256,6 +266,7 @@ def close(self): @patch('requests.request') def test_retry_on_cold_cache_400(self, mock_request): + """Verifies that a 400 from Cantaloupe triggers a retry that succeeds.""" fail = self._make_mock_response([b'error'], 400) success = self._make_mock_response([b'ok'], 200) mock_request.side_effect = [fail, success] @@ -268,6 +279,7 @@ def test_retry_on_cold_cache_400(self, mock_request): @patch('requests.request') def test_retry_on_cold_cache_500(self, mock_request): + """Verifies that a 500 from Cantaloupe triggers a retry that succeeds.""" fail = self._make_mock_response([b'error'], 500) success = self._make_mock_response([b'ok'], 200) mock_request.side_effect = [fail, success] @@ -280,6 +292,7 @@ def test_retry_on_cold_cache_500(self, mock_request): @patch('requests.request') def test_no_retry_on_success(self, mock_request): + """Verifies that a successful response does not trigger any retries.""" success = 
self._make_mock_response([b'ok'], 200) mock_request.return_value = success @@ -291,6 +304,7 @@ def test_no_retry_on_success(self, mock_request): @patch('requests.request') def test_returns_error_after_exhausted_retries(self, mock_request): + """Verifies that the error response is returned after all retries are exhausted.""" fail = self._make_mock_response([b'error'], 400) mock_request.return_value = fail @@ -336,7 +350,7 @@ def setUp(self): @patch('scan_explorer_service.views.image_proxy.fetch_object') def test_pdf_save_article_no_pages_returns_400(self, mock_fetch_object): - """S4: PDF endpoint returns 400 when article has no pages and no pre-built PDF.""" + """Verifies that PDF download returns 400 when article has no pages and no pre-built PDF.""" mock_fetch_object.side_effect = ValueError("File content is empty") response = self.client.get(url_for('proxy.pdf_save', id=self.article_no_pages_id)) @@ -345,7 +359,7 @@ def test_pdf_save_article_no_pages_returns_400(self, mock_fetch_object): self.assertIn('No pages found', data['Message']) def test_get_pages_article_no_pages_raises(self): - """S4: get_pages raises Exception for article with no pages.""" + """Verifies that get_pages raises an exception for an article with no pages.""" from scan_explorer_service.views.image_proxy import get_pages with self.app.app_context(): with self.assertRaises(Exception) as ctx: @@ -354,7 +368,7 @@ def test_get_pages_article_no_pages_raises(self): @patch('scan_explorer_service.views.image_proxy.fetch_object') def test_fetch_article_exception_no_unbound_local(self, mock_fetch_object): - """S6: fetch_article logs correctly even when fetch_object raises.""" + """Verifies that fetch_article handles S3 exceptions without an UnboundLocalError.""" mock_fetch_object.side_effect = ValueError("S3 error") from scan_explorer_service.views.image_proxy import fetch_article @@ -363,7 +377,7 @@ def test_fetch_article_exception_no_unbound_local(self, mock_fetch_object): 
@patch('scan_explorer_service.views.image_proxy.fetch_object') def test_thumbnail_empty_collection_returns_400(self, mock_fetch_object): - """S2 (HTTP layer): thumbnail endpoint returns 400 for empty collection.""" + """Verifies that thumbnail endpoint returns 400 for a collection with no pages.""" response = self.client.get(url_for('proxy.image_proxy_thumbnail', id=self.collection.id, type='collection')) self.assertEqual(response.status_code, 400) diff --git a/scan_explorer_service/utils/search_utils.py b/scan_explorer_service/utils/search_utils.py index d46acb0..631485c 100644 --- a/scan_explorer_service/utils/search_utils.py +++ b/scan_explorer_service/utils/search_utils.py @@ -17,6 +17,7 @@ class SearchOptions(enum.Enum): Volume = 'volume' class EsFields(str, enum.Enum): + """OpenSearch field name mappings for indexed scan documents.""" article_id = 'article_bibcodes' article_id_lowercase = 'article_bibcodes_lowercase' volume_id = 'volume_id' @@ -44,6 +45,7 @@ class EsFields(str, enum.Enum): }) class OrderOptions(str, enum.Enum): + """Sort order options for search results.""" Relevance_desc = 'relevance_desc' Relevance_asc = 'relevance_asc' Bibcode_desc = 'bibcode_desc' @@ -52,7 +54,10 @@ class OrderOptions(str, enum.Enum): Collection_asc = 'collection_asc' def parse_query_args(args): + """Parse HTTP request args into a search query string, field dict, pagination, and sort order.""" qs = re.sub(':\s*', ':', args.get('q', '', str)) + if not qs or not qs.strip(): + raise ValueError('Query string is required') qs, qs_dict = parse_query_string(qs) @@ -65,6 +70,7 @@ def parse_query_args(args): return qs, qs_dict, page, limit, sort def parse_query_string(qs): + """Split a raw query string into an OpenSearch query and a dict of field:value filters.""" qs_to_split = qs.replace('[', '"[').replace(']',']"') qs_arr = [q for q in shlex.split(qs_to_split) if ':' in q] qs_dict = {} @@ -96,6 +102,7 @@ def parse_query_string(qs): return qs, qs_dict def 
parse_sorting_option(sort_input: str): + """Convert a sort string to an OrderOptions enum, defaulting to Bibcode_desc.""" sort = OrderOptions.Bibcode_desc if sort_input: for sort_opt in OrderOptions: @@ -114,7 +121,9 @@ def check_query(qs_dict: dict): check_page_color(qs_dict) check_project(qs_dict) -def check_page_type(qs_dict: dict): +def check_page_type(qs_dict: dict): + """Validate and normalize the pagetype filter value to match the PageType enum.""" + if SearchOptions.PageType.value in qs_dict.keys(): page_type = qs_dict[SearchOptions.PageType.value] valid_types = [p.name for p in PageType] @@ -127,7 +136,9 @@ def check_page_type(qs_dict: dict): return raise Exception("%s is not a valid page type, %s is possible choices"% (page_type, str(valid_types))) -def check_page_color(qs_dict: dict): +def check_page_color(qs_dict: dict): + """Validate and normalize the pagecolor filter value to match the PageColor enum.""" + if SearchOptions.PageColor.value in qs_dict.keys(): page_color = qs_dict[SearchOptions.PageColor.value] valid_types = [p.name for p in PageColor] @@ -140,7 +151,9 @@ def check_page_color(qs_dict: dict): return raise Exception("%s is not a valid page color, %s is possible choices"% (page_color, str(valid_types))) -def check_project(qs_dict: dict): +def check_project(qs_dict: dict): + """Validate and normalize the project filter value against known project names.""" + if SearchOptions.Project.value in qs_dict.keys(): project = qs_dict[SearchOptions.Project.value] valid_types = ['PHaEDRA', 'Historical Literature', 'Microfilm Scanning'] @@ -155,6 +168,7 @@ def check_project(qs_dict: dict): raise Exception("%s is not a valid project, %s is possible choices"% (project, str(valid_types))) def serialize_os_agg_page_bucket(bucket: dict): + """Convert an OpenSearch page hit into a page result dict with collection and label info.""" id = bucket['_source']['page_id'] volume_id = bucket['_source']['volume_id'] label = bucket['_source']['page_label'] @@ -164,6 
+178,7 @@ def serialize_os_agg_page_bucket(bucket: dict): return {'id': id, 'collection_id':volume_id, 'journal': journal, 'volume': volume, 'label':label, 'volume_page_num': page_number} def serialize_os_page_result(result: dict, page: int, limit: int, contentQuery): + """Serialize an OpenSearch page search response into a paginated result dict.""" total_count = result['hits']['total']['value'] page_count = int(math.ceil(min(total_count,10000) / limit)) es_buckets = result['hits']['hits'] @@ -172,18 +187,21 @@ def serialize_os_page_result(result: dict, page: int, limit: int, contentQuery): 'items': [serialize_os_agg_page_bucket(b) for b in es_buckets]} def serialize_os_page_ocr_result(result: dict): + """Extract the OCR text from an OpenSearch page result, raising if no page is found.""" es_buckets = result['hits']['hits'] if len(es_buckets) < 1: raise Exception("No page with those parameters found") return es_buckets[0]['_source']['text'] def serialize_os_agg_collection_bucket(bucket: dict): + """Convert an OpenSearch collection aggregation bucket into a collection result dict.""" id = bucket['key'] journal = id[0:5] volume = id[5:9] return {'id': id, 'journal': journal, 'volume': volume, 'pages': bucket['doc_count']} def serialize_os_collection_result(result: dict, page: int, limit: int, contentQuery, agg_bucket_limit: int = 10000): + """Serialize an OpenSearch collection aggregation response into a paginated result dict.""" total_count = result['aggregations']['total_count']['value'] page_count = int(math.ceil(min(total_count, agg_bucket_limit) / limit)) es_buckets = result['aggregations']['ids']['buckets'] @@ -192,10 +210,12 @@ def serialize_os_collection_result(result: dict, page: int, limit: int, contentQ 'items': [serialize_os_agg_collection_bucket(b) for b in es_buckets]} def serialize_os_agg_article_bucket(bucket: dict): + """Convert an OpenSearch article aggregation bucket into an article result dict.""" id = bucket['key'] return {'id': id, 'bibcode': 
id, 'pages': bucket['doc_count']} def serialize_os_article_result(result: dict, page: int, limit: int, contentQuery = '', extra_col_count = 0, extra_page_count = 0, agg_bucket_limit: int = 10000): + """Serialize an OpenSearch article aggregation response into a paginated result dict.""" total_count = result['aggregations']['total_count']['value'] page_count = int(math.ceil(min(total_count, agg_bucket_limit) / limit)) es_buckets = result['aggregations']['ids']['buckets'] diff --git a/scan_explorer_service/views/image_proxy.py b/scan_explorer_service/views/image_proxy.py index 658450d..a288f56 100644 --- a/scan_explorer_service/views/image_proxy.py +++ b/scan_explorer_service/views/image_proxy.py @@ -14,6 +14,11 @@ import sys import time +try: + from gevent.pool import Pool as GeventPool +except ImportError: + GeventPool = None + bp_proxy = Blueprint('proxy', __name__, url_prefix='/image') @@ -51,6 +56,7 @@ def image_proxy(path): @stream_with_context def generate(): + """Stream the upstream response body chunk by chunk to avoid buffering the full image in memory.""" for chunk in r.raw.stream(decode_content=False): yield chunk @@ -80,16 +86,19 @@ def image_proxy_thumbnail(): return jsonify(Message=str(e)), 400 def get_item(session, id): + """Look up an Article or Collection by ID, raising if neither exists.""" item: Union[Article, Collection] = ( session.query(Article).filter(Article.id == id).one_or_none() or session.query(Collection).filter(Collection.id == id).one_or_none()) if not item: - raise Exception("ID: " + id + " not found") + raise Exception("ID: " + str(id) + " not found") return item def get_pages(item, session, page_start, page_end, page_limit): + """Query pages for an Article or Collection within the given page range. 
+ For articles, page numbers are relative to the article's first page in the volume.""" if isinstance(item, Article): first_page = item.pages.first() if first_page is None: @@ -105,32 +114,51 @@ def get_pages(item, session, page_start, page_end, page_limit): return query -@stream_with_context def fetch_images(session, item, page_start, page_end, page_limit, memory_limit): - n_pages = 0 - memory_sum = 0 + """Yield page images from S3, stopping when memory_limit is exceeded. + Uses gevent pool for parallel fetching when available.""" query = get_pages(item, session, page_start, page_end, page_limit) - for page in query.all(): + pages = query.all() - n_pages += 1 + page_objects = [] + for page in pages[:page_limit]: + image_path, fmt = page.image_path_basic + object_name = '/'.join(image_path) + fmt + page_objects.append(object_name) - current_app.logger.debug(f"Generating image for page: {n_pages}") - if n_pages > page_limit: - break - if memory_sum > memory_limit: - current_app.logger.error(f"Memory limit reached: {memory_sum} > {memory_limit}") - break - image_path, format = page.image_path_basic - object_name = '/'.join(image_path) - object_name += format + config = current_app.config + app_logger = current_app.logger + s3 = S3Provider(config, 'AWS_BUCKET_NAME_IMAGE') - im_data = fetch_object(object_name, 'AWS_BUCKET_NAME_IMAGE') - memory_sum += sys.getsizeof(im_data) + def _fetch(obj_name): + return s3.read_object_s3(obj_name) - yield im_data + pool = None + if GeventPool is not None: + pool = GeventPool(size=20) + # maxsize=4 limits prefetch buffering to avoid holding too many images in memory at once + results = pool.imap(_fetch, page_objects, maxsize=4) + else: + results = (_fetch(obj) for obj in page_objects) + + try: + memory_sum = 0 + for im_data in results: + if not im_data: + continue + memory_sum += sys.getsizeof(im_data) + if memory_sum > memory_limit: + app_logger.error(f"Memory limit reached: {memory_sum} > {memory_limit}") + break + yield im_data + 
finally: + if pool is not None: + pool.join(timeout=5) + pool.kill(block=False) def fetch_object(object_name, bucket_name): + """Download a single object from S3, raising ValueError if the content is empty.""" file_content = S3Provider(current_app.config, bucket_name).read_object_s3(object_name) if not file_content: current_app.logger.error(f"Failed to fetch content for {object_name}. File might be empty.") @@ -139,6 +167,7 @@ def fetch_object(object_name, bucket_name): def fetch_article(item, memory_limit): + """Try to fetch a pre-rendered PDF for an article from the ads-classic-pdf S3 bucket.""" object_name = f'{item.id}.pdf'.lower() try: full_path = f'pdfs/{object_name}' @@ -160,6 +189,7 @@ def fetch_article(item, memory_limit): def generate_pdf(item, session, page_start, page_end, page_limit, memory_limit): + """Return a pre-rendered PDF for articles if available, otherwise generate one from page images.""" if isinstance(item, Article): response = fetch_article(item, memory_limit) if response: @@ -182,6 +212,12 @@ def pdf_save(): memory_limit = current_app.config.get("IMAGE_PDF_MEMORY_LIMIT") page_limit = current_app.config.get("IMAGE_PDF_PAGE_LIMIT") + if page_end != math.inf and (page_end - page_start + 1) > page_limit: + return jsonify(Message=f"Requested {page_end - page_start + 1} pages exceeds limit of {page_limit}"), 400 + + if not id: + return jsonify(Message="Missing required parameter: id"), 400 + with current_app.session_scope() as session: item = get_item(session, id) diff --git a/scan_explorer_service/views/manifest.py b/scan_explorer_service/views/manifest.py index 05e4e49..f1bfef6 100644 --- a/scan_explorer_service/views/manifest.py +++ b/scan_explorer_service/views/manifest.py @@ -1,5 +1,5 @@ -from flask import Blueprint, current_app, jsonify, request +from flask import Blueprint, current_app, jsonify, request, Response from flask_restful import abort from scan_explorer_service.extensions import manifest_factory from scan_explorer_service.models 
import Article, Page, Collection @@ -8,12 +8,122 @@ from scan_explorer_service.utils.utils import proxy_url, url_for_proxy from sqlalchemy.orm import selectinload from typing import Union +import json as json_lib +import hashlib +import redis +import logging +import threading + +logger = logging.getLogger(__name__) + +MANIFEST_CACHE_TTL = 3600 +MANIFEST_CACHE_PREFIX = 'scan:manifest:' +SEARCH_CACHE_TTL = 60 +SEARCH_CACHE_PREFIX = 'scan:search:' + +_redis_client = None +_redis_lock = threading.Lock() + + +def _get_redis(): + """Return the singleton Redis client, creating it on first call with double-checked locking.""" + global _redis_client + if _redis_client is not None: + return _redis_client + with _redis_lock: + # Double-checked locking: another thread may have connected while we waited for the lock + if _redis_client is not None: + return _redis_client + try: + url = current_app.config.get('REDIS_URL', 'redis://redis-backend:6379/4') + client = redis.from_url(url, decode_responses=True, socket_timeout=2, socket_connect_timeout=2) + client.ping() + _redis_client = client + return _redis_client + except Exception: + logger.warning("Redis unavailable, manifest caching disabled") + return None + + +def _reset_redis(): + """Clear the cached Redis client so the next call to _get_redis reconnects.""" + global _redis_client + _redis_client = None + + +def _cache_get(key): + """Fetch a cached manifest JSON string by key, returning None on miss or Redis failure.""" + r = _get_redis() + if r is None: + return None + try: + return r.get(MANIFEST_CACHE_PREFIX + key) + except redis.ConnectionError: + _reset_redis() + return None + except Exception: + return None + + +def _cache_set(key, json_str): + """Store a manifest JSON string in Redis with TTL. 
Silently resets the cached Redis client on connection failure.
session.query(Article).filter(Article.id == id).one_or_none() @@ -34,7 +148,10 @@ def get_manifest(id: str): manifest = manifest_factory.create_manifest(item) search_url = url_for_proxy('manifest.search', id=id) manifest_factory.add_search_service(manifest, search_url) - return manifest.toJSON(top=True) + result = manifest.toJSON(top=True) + result_json = json_lib.dumps(result) if isinstance(result, dict) else result + _cache_set(id, result_json) + return result collection = session.query(Collection).filter(Collection.id == id).one_or_none() @@ -58,7 +175,10 @@ def get_manifest(id: str): collection, pages, articles, article_pages) search_url = url_for_proxy('manifest.search', id=id) manifest_factory.add_search_service(manifest, search_url) - return manifest.toJSON(top=True) + result = manifest.toJSON(top=True) + result_json = json_lib.dumps(result) if isinstance(result, dict) else result + _cache_set(id, result_json) + return result return jsonify(exception='Article not found'), 404 @@ -85,6 +205,11 @@ def search(id: str): if not query or len(query) <= 0: return jsonify(exception='No search query specified'), 400 + cache_key = hashlib.md5(f"{id}:{query}".encode()).hexdigest() + cached = _search_cache_get(cache_key) + if cached is not None: + return Response(cached, content_type='application/json') + with current_app.session_scope() as session: item: Union[Article, Collection] = ( session.query(Article).filter(Article.id == id).one_or_none() @@ -92,7 +217,7 @@ def search(id: str): if item: annotation_list = manifest_factory.annotationList(request.url) annotation_list.resources = [] - + es_field = EsFields.article_id if isinstance(item, Article) else EsFields.volume_id results = text_search_highlight(query, es_field, item.id) @@ -103,8 +228,10 @@ def search(id: str): highlight_text = "

".join(res['highlight']).replace("em>", "b>") annotation.text(highlight_text, format="text/html") - return annotation_list.toJSON(top=True) - + result = annotation_list.toJSON(top=True) + result_json = json_lib.dumps(result) if isinstance(result, dict) else result + _search_cache_set(cache_key, result_json) + return result else: return jsonify(exception='Article or volume not found'), 404 diff --git a/scan_explorer_service/views/metadata.py b/scan_explorer_service/views/metadata.py index abcf065..964c7f7 100644 --- a/scan_explorer_service/views/metadata.py +++ b/scan_explorer_service/views/metadata.py @@ -7,8 +7,12 @@ from flask_discoverer import advertise from scan_explorer_service.utils.search_utils import * from scan_explorer_service.views.view_utils import ApiErrors +from scan_explorer_service.views.manifest import _cache_delete, _search_cache_get, _search_cache_set from scan_explorer_service.open_search import EsFields, page_os_search, aggregate_search, page_ocr_os_search +import opensearchpy import requests +import hashlib +import json as json_lib bp_metadata = Blueprint('metadata', __name__, url_prefix='/metadata') @@ -128,6 +132,7 @@ def put_collection(): pg_insert(page_article_association_table).values(page_article_data).on_conflict_do_nothing() ) session.commit() + _cache_delete(collection.id) return jsonify({'id': collection.id}), 200 except Exception: @@ -163,12 +168,23 @@ def put_page(): return jsonify(message='Invalid page json'), 400 +def _make_search_cache_key(prefix, args): + """Build an MD5 cache key from the search type prefix and all query params (including multi-valued).""" + raw = prefix + str(sorted(args.items(multi=True))) + return hashlib.md5(raw.encode()).hexdigest() + + @advertise(scopes=['api'], rate_limit=[300, 3600*24]) @bp_metadata.route('/article/search', methods=['GET']) def article_search(): """Search for an article using one or some of the available keywords""" try: qs, qs_dict, page, limit, sort = parse_query_args(request.args) + 
+ cache_key = _make_search_cache_key('article', request.args) + cached = _search_cache_get(cache_key) + if cached is not None: + return current_app.response_class(cached, content_type='application/json') result = aggregate_search(qs, EsFields.article_id, page, limit, sort) text_query = '' if SearchOptions.FullText.value in qs_dict.keys(): @@ -180,7 +196,12 @@ def article_search(): collection_count = aggregate_search(qs, EsFields.volume_id, page, limit, sort)['aggregations']['total_count']['value'] page_count = page_os_search(qs, page, limit, sort)['hits']['total']['value'] agg_limit = current_app.config.get("OPEN_SEARCH_AGG_BUCKET_LIMIT", 10000) - return jsonify(serialize_os_article_result(result, page, limit, text_query, collection_count, page_count, agg_limit)) + response_data = serialize_os_article_result(result, page, limit, text_query, collection_count, page_count, agg_limit) + _search_cache_set(cache_key, json_lib.dumps(response_data)) + return jsonify(response_data) + except (opensearchpy.exceptions.ConnectionError, opensearchpy.exceptions.ConnectionTimeout, opensearchpy.exceptions.TransportError) as e: + current_app.logger.exception(f"OpenSearch error: {e}") + return jsonify(message='Search service temporarily unavailable', type=ApiErrors.SearchError.value), 503 except Exception as e: current_app.logger.exception(f"An exception has occurred: {e}") return jsonify(message=str(e), type=ApiErrors.SearchError.value), 400 @@ -192,12 +213,22 @@ def collection_search(): """Search for a collection using one or some of the available keywords""" try: qs, qs_dict, page, limit, sort = parse_query_args(request.args) + + cache_key = _make_search_cache_key('collection', request.args) + cached = _search_cache_get(cache_key) + if cached is not None: + return current_app.response_class(cached, content_type='application/json') result = aggregate_search(qs, EsFields.volume_id, page, limit, sort) text_query = '' if SearchOptions.FullText.value in qs_dict.keys(): text_query = 
qs_dict[SearchOptions.FullText.value] agg_limit = current_app.config.get("OPEN_SEARCH_AGG_BUCKET_LIMIT", 10000) - return jsonify(serialize_os_collection_result(result, page, limit, text_query, agg_limit)) + response_data = serialize_os_collection_result(result, page, limit, text_query, agg_limit) + _search_cache_set(cache_key, json_lib.dumps(response_data)) + return jsonify(response_data) + except (opensearchpy.exceptions.ConnectionError, opensearchpy.exceptions.ConnectionTimeout, opensearchpy.exceptions.TransportError) as e: + current_app.logger.exception(f"OpenSearch error: {e}") + return jsonify(message='Search service temporarily unavailable', type=ApiErrors.SearchError.value), 503 except Exception as e: return jsonify(message=str(e), type=ApiErrors.SearchError.value), 400 @@ -207,11 +238,21 @@ def page_search(): """Search for a page using one or some of the available keywords""" try: qs, qs_dict, page, limit, sort = parse_query_args(request.args) + + cache_key = _make_search_cache_key('page', request.args) + cached = _search_cache_get(cache_key) + if cached is not None: + return current_app.response_class(cached, content_type='application/json') result = page_os_search(qs, page, limit, sort) text_query = '' if SearchOptions.FullText.value in qs_dict.keys(): text_query = qs_dict[SearchOptions.FullText.value] - return jsonify(serialize_os_page_result(result, page, limit, text_query)) + response_data = serialize_os_page_result(result, page, limit, text_query) + _search_cache_set(cache_key, json_lib.dumps(response_data)) + return jsonify(response_data) + except (opensearchpy.exceptions.ConnectionError, opensearchpy.exceptions.ConnectionTimeout, opensearchpy.exceptions.TransportError) as e: + current_app.logger.exception(f"OpenSearch error: {e}") + return jsonify(message='Search service temporarily unavailable', type=ApiErrors.SearchError.value), 503 except Exception as e: return jsonify(message=str(e), type=ApiErrors.SearchError.value), 400 @@ -223,6 +264,11 @@ 
def get_page_ocr(): id = request.args.get('id') page_number = request.args.get('page_number', 1, int) + cache_key = _make_search_cache_key('ocr', request.args) + cached = _search_cache_get(cache_key) + if cached is not None: + return current_app.response_class(cached, content_type='text/plain') + with current_app.session_scope() as session: item: Union[Article, Collection] = ( session.query(Article).filter(Article.id == id).one_or_none() @@ -240,7 +286,12 @@ def get_page_ocr(): collection_id = item.id result = page_ocr_os_search(collection_id, page_number) - return serialize_os_page_ocr_result(result) + ocr_text = serialize_os_page_ocr_result(result) + _search_cache_set(cache_key, ocr_text) + return current_app.response_class(ocr_text, content_type='text/plain') + except (opensearchpy.exceptions.ConnectionError, opensearchpy.exceptions.ConnectionTimeout, opensearchpy.exceptions.TransportError) as e: + current_app.logger.exception(f"OpenSearch error: {e}") + return jsonify(message='Search service temporarily unavailable', type=ApiErrors.SearchError.value), 503 except Exception as e: return jsonify(message=str(e), type=ApiErrors.SearchError.value), 400