Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ new_backend/arxivsearch/templates/
.coverage*
coverage.*
htmlcov/
legacy-data/
legacy-data/
.python-version
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ RUN mkdir -p /app/backend
COPY ./backend/poetry.lock ./backend/pyproject.toml ./backend/

WORKDIR /app/backend
RUN poetry install --all-extras --no-interaction --no-root
RUN poetry install --all-extras --no-interaction

COPY ./backend/ .

Expand Down
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: build
.PHONY: deploy

build:
docker compose -f docker-local-redis.yml up
deploy:
docker compose -f docker-local-redis.yml up
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
<a href="https://github.com/redis-developer/redis-arxiv-search"><img src="https://redis.io/wp-content/uploads/2024/04/Logotype.svg?raw=true" width="30%"><img></a>
<br />
<br />
<h1>🔎 arXiv Search API</h1>
<div display="inline-block">
<a href="https://docsearch.redisvl.com"><b>Hosted Demo</b></a>&nbsp;&nbsp;&nbsp;
<a href="https://github.com/redis-developer/redis-arxiv-search"><b>Code</b></a>&nbsp;&nbsp;&nbsp;
Expand All @@ -14,7 +15,7 @@
<br />
</div>

# 🔎 Redis arXiv Search

*This repository is the official codebase for the arxiv paper search app hosted at: **https://docsearch.redisvl.com***


Expand Down Expand Up @@ -111,17 +112,17 @@ Embeddings represent the semantic properies of the raw text and enable vector si
- Add your `OPENAI_API_KEY` to the `.env` file. **Need one?** [Get an API key](https://platform.openai.com)
- Add you `COHERE_API_KEY` to the `.env` file. **Need one?** [Get an API key](https://cohere.ai)

### Redis Stack Docker (Local) with make
### Run locally with Redis 8 CE
```bash
make build
make deploy
```


## Customizing (optional)

### Run local redis with Docker
```bash
docker run -d --name redis -p 6379:6379 -p 8001:8001 redis/redis-stack:latest
docker run -d --name redis -p 6379:6379 -p 8001:8001 redis:8.0-M03
```

### FastApi with poetry
Expand Down
22 changes: 11 additions & 11 deletions backend/arxivsearch/api/routes/papers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@
import logging

import numpy as np
from fastapi import APIRouter, Depends, Query
from redisvl.index import AsyncSearchIndex
from redisvl.query import CountQuery, FilterQuery, VectorQuery

from arxivsearch import config
from arxivsearch.db import redis_helpers
from arxivsearch.db import utils
from arxivsearch.schema.models import (
PaperSimilarityRequest,
SearchResponse,
UserTextSimilarityRequest,
VectorSearchResponse,
)
from arxivsearch.utils.embeddings import Embeddings
from fastapi import APIRouter, Depends, Query
from redisvl.index import AsyncSearchIndex
from redisvl.query import CountQuery, FilterQuery, VectorQuery

logger = logging.getLogger(__name__)

Expand All @@ -26,7 +27,7 @@

@router.get("/", response_model=SearchResponse)
async def get_papers(
index: AsyncSearchIndex = Depends(redis_helpers.get_async_index),
index: AsyncSearchIndex = Depends(utils.get_async_index),
limit: int = Query(default=20, description="Maximum number of papers to return."),
skip: int = Query(
default=0, description="Number of papers to skip for pagination."
Expand All @@ -53,9 +54,8 @@ async def get_papers(
Returns:
SearchResponse: Pydantic model containing papers and total count.
"""

# Build queries
filter_expression = redis_helpers.build_filter_expression(
filter_expression = utils.build_filter_expression(
years.split(","), categories.split(",")
)
filter_query = FilterQuery(return_fields=[], filter_expression=filter_expression)
Expand All @@ -72,7 +72,7 @@ async def get_papers(
@router.post("/vector_search/by_paper", response_model=VectorSearchResponse)
async def find_papers_by_paper(
similarity_request: PaperSimilarityRequest,
index: AsyncSearchIndex = Depends(redis_helpers.get_async_index),
index: AsyncSearchIndex = Depends(utils.get_async_index),
):
"""
Find and return papers similar to a given paper based on vector
Expand All @@ -93,7 +93,7 @@ async def find_papers_by_paper(
paper[similarity_request.provider.value], dtype=np.float32
)
# Build filter expression
filter_expression = redis_helpers.build_filter_expression(
filter_expression = utils.build_filter_expression(
similarity_request.years, similarity_request.categories
)
# Create queries
Expand All @@ -115,7 +115,7 @@ async def find_papers_by_paper(
@router.post("/vector_search/by_text", response_model=VectorSearchResponse)
async def find_papers_by_text(
similarity_request: UserTextSimilarityRequest,
index: AsyncSearchIndex = Depends(redis_helpers.get_async_index),
index: AsyncSearchIndex = Depends(utils.get_async_index),
):
"""
Find and return papers similar to user-provided text based on
Expand All @@ -131,7 +131,7 @@ async def find_papers_by_text(
"""

# Build filter expression
filter_expression = redis_helpers.build_filter_expression(
filter_expression = utils.build_filter_expression(
similarity_request.years, similarity_request.categories
)
# Check available paper count and create vector from user text
Expand Down
5 changes: 2 additions & 3 deletions backend/arxivsearch/db/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from redisvl.index import AsyncSearchIndex

from arxivsearch import config
from arxivsearch.db import redis_helpers
from arxivsearch.db.utils import get_async_index, get_schema
from arxivsearch.schema.models import Provider

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -74,8 +74,7 @@ async def preprocess_paper(paper: dict) -> dict:

async def load_data():
# Load schema specs and create index in Redis
index = AsyncSearchIndex(redis_helpers.schema)
await index.set_client(redis_helpers.client)
index = await get_async_index()

# Load dataset and create index
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,30 @@
import os
from typing import List

from arxivsearch import config
from redis.asyncio import Redis
from redisvl.index import AsyncSearchIndex, SearchIndex
from redisvl.index import AsyncSearchIndex
from redisvl.query.filter import FilterExpression, Tag
from redisvl.schema import IndexSchema

logger = logging.getLogger(__name__)


dir_path = os.path.dirname(os.path.realpath(__file__))
schema_path = os.path.join(dir_path, "index.yaml")
schema = IndexSchema.from_yaml(schema_path)
client = Redis.from_url(config.REDIS_URL)
global_index = None

from arxivsearch import config

def get_schema():
return IndexSchema.from_yaml(schema_path)
logger = logging.getLogger(__name__)


def get_test_index():
index = SearchIndex.from_yaml(schema_path)
index.connect(redis_url=config.REDIS_URL)
# global search index
_global_index = None

if not index.exists():
index.create(overwrite=True)

return index
def get_schema() -> IndexSchema:
dir_path = os.path.dirname(os.path.realpath(__file__)) + "/schema"
file_path = os.path.join(dir_path, "index.yaml")
return IndexSchema.from_yaml(file_path)


async def get_async_index():
global global_index
if not global_index:
global_index = AsyncSearchIndex.from_yaml(schema_path)
await global_index.set_client(client)
yield global_index
global _global_index
if not _global_index:
_global_index = AsyncSearchIndex(get_schema(), redis_url=config.REDIS_URL)
return _global_index


def build_filter_expression(
Expand Down
22 changes: 18 additions & 4 deletions backend/arxivsearch/main.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,34 @@
import logging
from contextlib import asynccontextmanager
from pathlib import Path

import uvicorn
from arxivsearch import config
from arxivsearch.api.main import api_router
from arxivsearch.spa import SinglePageApplication
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from starlette.middleware.cors import CORSMiddleware

from arxivsearch import config
from arxivsearch.api.main import api_router
from arxivsearch.db.utils import get_async_index
from arxivsearch.spa import SinglePageApplication

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)


@asynccontextmanager
async def lifespan(app: FastAPI):
index = await get_async_index()
async with index:
yield


app = FastAPI(
title=config.PROJECT_NAME, docs_url=config.API_DOCS, openapi_url=config.OPENAPI_DOCS
title=config.PROJECT_NAME,
docs_url=config.API_DOCS,
openapi_url=config.OPENAPI_DOCS,
lifespan=lifespan,
)

app.add_middleware(
Expand Down
1 change: 1 addition & 0 deletions backend/arxivsearch/schema/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum

from pydantic import BaseModel


Expand Down
31 changes: 15 additions & 16 deletions backend/arxivsearch/tests/api/routes/test_papers.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
import pytest
from httpx import AsyncClient

from arxivsearch.main import app
from arxivsearch.schema.models import (
PaperSimilarityRequest,
UserTextSimilarityRequest,
)
from arxivsearch.schema.models import PaperSimilarityRequest, UserTextSimilarityRequest


@pytest.fixture
def years(papers):
return papers[0]["year"]
@pytest.fixture(scope="module")
def years(test_data):
return test_data[0]["year"]


@pytest.fixture
def categories(papers):
return papers[0]["categories"]
@pytest.fixture(scope="module")
def categories(test_data):
return test_data[0]["categories"]


@pytest.fixture
@pytest.fixture(scope="module")
def bad_req_json():
return {"not": "valid"}


@pytest.fixture
@pytest.fixture(scope="module")
def text_req(years, categories):
return UserTextSimilarityRequest(
categories=[categories],
Expand All @@ -33,10 +29,13 @@ def text_req(years, categories):
)


@pytest.fixture
def paper_req(papers):
@pytest.fixture(scope="module")
def paper_req(test_data):
return PaperSimilarityRequest(
categories=[], years=[], provider="huggingface", paper_id=papers[0]["paper_id"]
categories=[],
years=[],
provider="huggingface",
paper_id=test_data[0]["paper_id"],
)


Expand Down
57 changes: 37 additions & 20 deletions backend/arxivsearch/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,49 @@
import json
import os

import httpx
import numpy as np
import pytest
import pytest_asyncio
from asgi_lifespan import LifespanManager
from httpx import AsyncClient
from redisvl.index import SearchIndex

from arxivsearch import config
from arxivsearch.db.utils import get_async_index, get_schema
from arxivsearch.main import app
from arxivsearch.tests.utils.seed import seed_test_db
from httpx import AsyncClient
from redis.asyncio import Redis


@pytest.fixture(scope="module")
def papers():
papers = seed_test_db()
return papers
@pytest.fixture(scope="session")
def index():
index = SearchIndex(schema=get_schema(), redis_url=config.REDIS_URL)
index.create()
yield index
index.disconnect()


@pytest.fixture
async def client():
client = await Redis.from_url(config.REDIS_URL)
yield client
try:
await client.aclose()
except RuntimeError as e:
if "Event loop is closed" not in str(e):
raise
@pytest.fixture(scope="session", autouse=True)
def test_data(index):
cwd = os.getcwd()
with open(f"{cwd}/arxivsearch/tests/test_vectors.json", "r") as f:
papers = json.load(f)

# convert to bytes
for paper in papers:
paper["huggingface"] = np.array(
paper["huggingface"], dtype=np.float32
).tobytes()
paper["openai"] = np.array(paper["openai"], dtype=np.float32).tobytes()
paper["cohere"] = np.array(paper["cohere"], dtype=np.float32).tobytes()

@pytest_asyncio.fixture(scope="session")
async def async_client():
_ = index.load(data=papers, id_field="paper_id")
return papers

async with AsyncClient(app=app, base_url="http://test/api/v1/") as client:

yield client
@pytest_asyncio.fixture(scope="session")
async def async_client():
async with LifespanManager(app=app) as lifespan:
async with AsyncClient(
transport=httpx.ASGITransport(app=app), base_url="http://test/api/v1/" # type: ignore
) as client:
yield client
Empty file.
Loading