From b1276c76d9ffbc91aedaf752e15f02d812d06109 Mon Sep 17 00:00:00 2001 From: Jingyuan Zhang Date: Thu, 2 Oct 2025 14:26:06 -0700 Subject: [PATCH 01/11] Simplified openai batch llm worker implementation without component integration Signed-off-by: Jingyuan Zhang --- build/container/Dockerfile.runtime | 3 +- development/app/app.py | 9 +- python/aibrix/aibrix/batch/README.md | 348 +++++ python/aibrix/aibrix/batch/constant.py | 5 +- python/aibrix/aibrix/batch/driver.py | 152 ++- python/aibrix/aibrix/batch/job_driver.py | 432 ++++++ .../aibrix/batch/job_entity/__init__.py | 4 + .../aibrix/batch/job_entity/batch_job.py | 256 +++- .../batch/job_entity/job_entity_manager.py | 119 +- .../batch/job_entity/k8s_transformer.py | 604 +++++++++ python/aibrix/aibrix/batch/job_manager.py | 835 ++++++++---- .../aibrix/batch/job_progress_manager.py | 137 ++ python/aibrix/aibrix/batch/request_proxy.py | 126 -- python/aibrix/aibrix/batch/scheduler.py | 126 +- .../aibrix/aibrix/batch/storage/__init__.py | 4 + python/aibrix/aibrix/batch/storage/adapter.py | 249 +++- .../aibrix/batch/storage/batch_metastore.py | 129 +- .../aibrix/batch/storage/batch_storage.py | 52 +- python/aibrix/aibrix/batch/worker.py | 610 +++++++++ python/aibrix/aibrix/config.py | 23 + python/aibrix/aibrix/logger.py | 53 +- python/aibrix/aibrix/metadata/api/v1/batch.py | 134 +- python/aibrix/aibrix/metadata/api/v1/files.py | 13 +- python/aibrix/aibrix/metadata/app.py | 113 +- .../aibrix/aibrix/metadata/cache/__init__.py | 19 + python/aibrix/aibrix/metadata/cache/job.py | 1055 +++++++++++++++ python/aibrix/aibrix/metadata/cache/utils.py | 78 ++ .../aibrix/aibrix/metadata/core/__init__.py | 18 + .../aibrix/metadata/core/asyncio_thread.py | 113 ++ .../aibrix/metadata/core/kopf_operator.py | 264 ++++ python/aibrix/aibrix/metadata/logger.py | 58 - python/aibrix/aibrix/metadata/secret_gen.py | 408 ++++++ .../aibrix/aibrix/metadata/setting/config.py | 27 +- .../aibrix/metadata/setting/k8s_job_rbac.yaml | 30 + .../metadata/setting/k8s_job_redis_patch.yaml | 15 + .../metadata/setting/k8s_job_s3_patch.yaml | 29 + .../metadata/setting/k8s_job_template.yaml | 89 ++ .../metadata/setting/k8s_job_tos_patch.yaml | 34 + .../metadata/setting/s3_secret_template.yaml | 14 + .../metadata/setting/tos_secret_template.yaml | 15 + python/aibrix/aibrix/storage/base.py | 129 +- python/aibrix/aibrix/storage/factory.py | 4 +- python/aibrix/aibrix/storage/local.py | 23 +- python/aibrix/aibrix/storage/redis.py | 82 +- python/aibrix/aibrix/storage/s3.py | 25 +- python/aibrix/aibrix/storage/tos.py | 25 +- python/aibrix/aibrix/storage/utils.py | 46 +- python/aibrix/poetry.lock | 368 ++++- python/aibrix/pyproject.toml | 7 + python/aibrix/scripts/generate_secrets.py | 222 ++++ python/aibrix/tests/batch/conftest.py | 377 ++++++ .../aibrix/tests/batch/sample_job_input.jsonl | 3 - .../tests/batch/test_batch_storage_adapter.py | 491 +++++++ python/aibrix/tests/batch/test_driver.py | 260 +++- .../batch/test_e2e_abnormal_job_behavior.py | 932 +++++++++++++ .../tests/batch/test_e2e_openai_batch_api.py | 242 +++- .../test_inference_client_integration.py | 90 ++ python/aibrix/tests/batch/test_job_cache.py | 84 ++ python/aibrix/tests/batch/test_job_entity.py | 344 +++++ python/aibrix/tests/batch/test_job_manager.py | 146 +- .../tests/batch/test_k8s_job_persistence.py | 291 ++++ .../tests/batch/test_k8s_job_transformer.py | 1178 +++++++++++++++++ python/aibrix/tests/batch/test_rbac_setup.py | 98 ++ .../tests/batch/test_worker_s3_integration.py | 536 ++++++++ 
.../aibrix/tests/batch/testdata/job_rbac.yaml | 37 + .../testdata/k8s_job_patch_unittest.yaml | 14 + .../tests/batch/testdata/s3_secret.yaml | 14 + .../batch/testdata/sample_job_input.jsonl | 10 + .../tests/metadata/test_app_integration.py | 176 +++ .../tests/metadata/test_kopf_integration.py | 236 ++++ .../aibrix/tests/metadata/test_secret_gen.py | 212 +++ python/aibrix/tests/storage/test_reader.py | 4 +- .../tests/storage/test_redis_storage.py | 207 +++ python/aibrix/tests/storage/test_storage.py | 90 +- ...test_metadata_logger.py => test_logger.py} | 2 +- 75 files changed, 12869 insertions(+), 938 deletions(-) create mode 100644 python/aibrix/aibrix/batch/README.md create mode 100644 python/aibrix/aibrix/batch/job_driver.py create mode 100644 python/aibrix/aibrix/batch/job_entity/k8s_transformer.py create mode 100644 python/aibrix/aibrix/batch/job_progress_manager.py delete mode 100644 python/aibrix/aibrix/batch/request_proxy.py create mode 100644 python/aibrix/aibrix/batch/worker.py create mode 100644 python/aibrix/aibrix/metadata/cache/__init__.py create mode 100644 python/aibrix/aibrix/metadata/cache/job.py create mode 100644 python/aibrix/aibrix/metadata/cache/utils.py create mode 100644 python/aibrix/aibrix/metadata/core/__init__.py create mode 100644 python/aibrix/aibrix/metadata/core/asyncio_thread.py create mode 100644 python/aibrix/aibrix/metadata/core/kopf_operator.py delete mode 100644 python/aibrix/aibrix/metadata/logger.py create mode 100644 python/aibrix/aibrix/metadata/secret_gen.py create mode 100644 python/aibrix/aibrix/metadata/setting/k8s_job_rbac.yaml create mode 100644 python/aibrix/aibrix/metadata/setting/k8s_job_redis_patch.yaml create mode 100644 python/aibrix/aibrix/metadata/setting/k8s_job_s3_patch.yaml create mode 100644 python/aibrix/aibrix/metadata/setting/k8s_job_template.yaml create mode 100644 python/aibrix/aibrix/metadata/setting/k8s_job_tos_patch.yaml create mode 100644 python/aibrix/aibrix/metadata/setting/s3_secret_template.yaml create mode 100644 python/aibrix/aibrix/metadata/setting/tos_secret_template.yaml create mode 100644 python/aibrix/scripts/generate_secrets.py create mode 100644 python/aibrix/tests/batch/conftest.py delete mode 100644 python/aibrix/tests/batch/sample_job_input.jsonl create mode 100644 python/aibrix/tests/batch/test_batch_storage_adapter.py create mode 100644 python/aibrix/tests/batch/test_e2e_abnormal_job_behavior.py create mode 100644 python/aibrix/tests/batch/test_inference_client_integration.py create mode 100644 python/aibrix/tests/batch/test_job_cache.py create mode 100644 python/aibrix/tests/batch/test_job_entity.py create mode 100644 python/aibrix/tests/batch/test_k8s_job_persistence.py create mode 100644 python/aibrix/tests/batch/test_k8s_job_transformer.py create mode 100644 python/aibrix/tests/batch/test_rbac_setup.py create mode 100644 python/aibrix/tests/batch/test_worker_s3_integration.py create mode 100644 python/aibrix/tests/batch/testdata/job_rbac.yaml create mode 100644 python/aibrix/tests/batch/testdata/k8s_job_patch_unittest.yaml create mode 100644 python/aibrix/tests/batch/testdata/s3_secret.yaml create mode 100644 python/aibrix/tests/batch/testdata/sample_job_input.jsonl create mode 100644 python/aibrix/tests/metadata/test_app_integration.py create mode 100644 python/aibrix/tests/metadata/test_kopf_integration.py create mode 100644 python/aibrix/tests/metadata/test_secret_gen.py rename python/aibrix/tests/{test_metadata_logger.py => test_logger.py} (99%) diff --git a/build/container/Dockerfile.runtime 
b/build/container/Dockerfile.runtime index dbbdb3b00..747b3ffcb 100644 --- a/build/container/Dockerfile.runtime +++ b/build/container/Dockerfile.runtime @@ -34,8 +34,9 @@ WORKDIR /app COPY --from=builder /app/dist/*.whl ./ # Install build dependencies and clean up in one step (avoiding creating another new layer) +# procps grep mawk are for batch worker to kill llm engine RUN apt-get update \ - && apt-get install -y --no-install-recommends gcc python3-dev mawk \ + && apt-get install -y --no-install-recommends gcc python3-dev mawk procps grep \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* \ && pip install --no-cache-dir ./*.whl \ diff --git a/development/app/app.py b/development/app/app.py index 81c1eb966..0f74916b3 100644 --- a/development/app/app.py +++ b/development/app/app.py @@ -162,7 +162,7 @@ def get_token_count(text): # the metrics and results in lots of meaningless requests that we do not want to log. def disable_endpoint_logs(): """Disable logs for requests to specific endpoints.""" - disabled_endpoints = ('/', '/healthz', '/metrics') + disabled_endpoints = ('/', '/health', '/ready', '/metrics') parent_log_request = serving.WSGIRequestHandler.log_request def log_request(self, *args, **kwargs): @@ -175,6 +175,13 @@ def log_request(self, *args, **kwargs): app = Flask(__name__) disable_endpoint_logs() +@app.route('/health', methods=['GET']) +def health(): + return {"status": "ok"}, 200 + +@app.route('/ready', methods=['GET']) +def ready(): + return {"status": "ready"}, 200 @app.route('/v1/models', methods=['GET']) @auth.login_required diff --git a/python/aibrix/aibrix/batch/README.md b/python/aibrix/aibrix/batch/README.md new file mode 100644 index 000000000..d8c33d02f --- /dev/null +++ b/python/aibrix/aibrix/batch/README.md @@ -0,0 +1,348 @@ +# AIBrix Batch Processing + +This directory contains the batch processing system for AIBrix, designed to handle large-scale batch inference jobs using a sidecar pattern with vLLM. + +## Architecture Overview + +The batch processing system consists of several key components: + +- **BatchDriver**: Main orchestrator that manages job lifecycle +- **JobManager**: Handles job state management and tracking +- **RequestProxy**: Bridge between job management and inference execution +- **BatchWorker**: Script-based worker that runs in Kubernetes jobs + +## Worker Architecture + +The `worker.py` script implements a **sidecar pattern** where: + +1. **batch-worker container**: Runs the worker script (exits when complete) +2. **llm-engine container**: Runs vLLM as a persistent service + +### Workflow + +1. Both containers start simultaneously in the same pod +2. Worker waits for vLLM health check on `localhost:8000` +3. Worker loads its parent Kubernetes Job spec using the Kubernetes API +4. Worker transforms the K8s Job to BatchJob using BatchJobTransformer +5. Worker creates BatchDriver and executes the job +6. Worker tracks job until FINALIZING state +7. Worker exits with status code 0 +8. Kubernetes sends SIGTERM to vLLM container +9. Pod completes successfully + +## Testing in Kubernetes + +### Prerequisites + +1. **Docker Images**: Ensure you have the required images built: + ```bash + # Build the AIBrix runtime image with batch worker + make docker-build-runtime + + # Or use development images + docker build -t aibrix/runtime:nightly . + docker build -t aibrix/vllm-mock:nightly ./vllm-mock + ``` + +2. **Kubernetes Cluster**: Have a running cluster (minikube, kind, or full cluster) + +3. 
**Storage Setup**: Configure storage backend (S3, TOS, or local storage) + +### Basic Testing + +#### 1. Prepare Test Input File and Job Annotations + +Create a JSONL file with batch requests and prepare job annotations: + +```bash +# Create test input file +cat > /tmp/batch_input.jsonl << 'EOF' +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello world"}]}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "What is AI?"}]}} +EOF + +# Upload to your storage backend and get the file ID +INPUT_FILE_ID="your-uploaded-file-id" +``` + +#### 2. Create and Run Batch Job + +```bash +# Copy the job template +cp aibrix/metadata/setting/k8s_job_template.yaml /tmp/test-batch-job.yaml + +# Edit the template with your specific values +JOB_NAME="test-batch-job-$(date +%s)" +sed -i "s/name: batch-job-template/name: ${JOB_NAME}/" /tmp/test-batch-job.yaml + +# Add batch job annotations with your job specification +cat >> /tmp/test-batch-job.yaml << EOF + annotations: + batch.job.aibrix.ai/input-file-id: "${INPUT_FILE_ID}" + batch.job.aibrix.ai/endpoint: "/v1/chat/completions" + batch.job.aibrix.ai/completion-window: "24h" + batch.job.aibrix.ai/session-id: "test-session-$(date +%s)" +EOF + +# Apply the job +kubectl apply -f /tmp/test-batch-job.yaml +``` + +#### 3. Monitor Job Progress + +```bash +# Watch job status +kubectl get jobs -w + +# Check pod logs +POD_NAME=$(kubectl get pods -l app=aibrix-batch --no-headers -o custom-columns=":metadata.name" | head -1) + +# Watch worker logs +kubectl logs -f $POD_NAME -c batch-worker + +# Watch vLLM logs +kubectl logs -f $POD_NAME -c llm-engine +``` + +#### 4. 
Verify Job Completion + +```bash +# Check job completion +kubectl get job test-batch-job-* -o wide + +# Expected output should show: +# - COMPLETIONS: 1/1 +# - STATUS: Complete + +# Check pod status +kubectl get pods -l app=aibrix-batch + +# Expected output should show: +# - STATUS: Completed +``` + +### Advanced Testing + +#### Job Annotation Configuration + +The worker loads job specifications from Kubernetes Job annotations using the `batch.job.aibrix.ai/` prefix: + +```yaml +metadata: + annotations: + # Required annotations + batch.job.aibrix.ai/input-file-id: "file-abc123" # Input file with batch requests + batch.job.aibrix.ai/endpoint: "/v1/chat/completions" # API endpoint + + # Optional annotations + batch.job.aibrix.ai/completion-window: "24h" # Job completion window (default: 24h) + batch.job.aibrix.ai/session-id: "session-123" # Session identifier + batch.job.aibrix.ai/metadata.key1: "value1" # Custom metadata (up to 16 pairs) + batch.job.aibrix.ai/metadata.key2: "value2" + +# Environment variables for worker configuration +spec: + template: + spec: + containers: + - name: batch-worker + env: + - name: STORAGE_TYPE + value: "S3" # Optional: Storage backend (AUTO, LOCAL, S3, TOS) + - name: METASTORE_TYPE + value: "S3" # Optional: Metastore backend (AUTO, LOCAL, S3, TOS) +``` + +#### Storage Backend Testing + +For **S3 Backend**: +```yaml +env: +- name: STORAGE_TYPE + value: "S3" +- name: AWS_ACCESS_KEY_ID + value: "your-access-key" +- name: AWS_SECRET_ACCESS_KEY + value: "your-secret-key" +- name: AWS_REGION + value: "us-west-2" +- name: S3_BUCKET + value: "your-batch-bucket" +``` + +For **TOS Backend**: +```yaml +env: +- name: STORAGE_TYPE + value: "TOS" +- name: TOS_ACCESS_KEY + value: "your-tos-key" +- name: TOS_SECRET_KEY + value: "your-tos-secret" +- name: TOS_ENDPOINT + value: "https://tos-s3-cn-beijing.volces.com" +- name: TOS_REGION + value: "cn-beijing" +``` + +#### Custom vLLM Configuration + +Modify the vLLM container configuration: + +```yaml +- name: llm-engine + image: aibrix/vllm:nightly + args: + - --model + - microsoft/DialoGPT-medium + - --port + - "8000" + - --served-model-name + - gpt-3.5-turbo + resources: + requests: + nvidia.com/gpu: 1 + limits: + nvidia.com/gpu: 1 +``` + +### Debugging + +#### Common Issues + +1. **Worker fails to start**: Check job annotations and environment variables + ```bash + kubectl describe pod $POD_NAME + kubectl logs $POD_NAME -c batch-worker + + # Check job annotations + kubectl get job $JOB_NAME -o yaml | grep -A 10 annotations + ``` + +2. **vLLM health check fails**: Check vLLM container logs + ```bash + kubectl logs $POD_NAME -c llm-engine + ``` + +3. **Storage access issues**: Verify credentials and permissions + ```bash + kubectl logs $POD_NAME -c batch-worker | grep -i storage + ``` + +4. 
**Job hangs**: Check both container logs and job events + ```bash + kubectl describe job test-batch-job-* + kubectl get events --sort-by=.metadata.creationTimestamp + ``` + +#### Debug Mode + +Enable debug logging by adding: +```yaml +env: +- name: LOG_LEVEL + value: "DEBUG" +``` + +### Cleanup + +```bash +# Delete test job and pods +kubectl delete job test-batch-job-* +kubectl delete pod -l app=aibrix-batch + +# Clean up test files +rm -f /tmp/test-batch-job.yaml /tmp/batch_input.jsonl +``` + +## Development + +### Running Tests + +```bash +# Install dependencies +poetry install --no-root --with dev + +# Run batch-specific tests +pytest tests/ -k batch + +# Run formatting and linting +bash ./scripts/format.sh +``` + +### Code Structure + +``` +aibrix/batch/ +├── README.md # This file +├── worker.py # Main worker script +├── driver.py # BatchDriver orchestrator +├── job_manager.py # Job state management +├── job_driver.py # Execution bridge +├── scheduler.py # Job scheduling +├── job_entity/ # Job models and states +│ ├── batch_job.py # BatchJob schema +│ └── k8s_transformer.py # Kubernetes integration +└── storage/ # Storage abstractions + ├── adapter.py # Storage bridge + └── batch_storage.py # Batch storage interface +``` + +The worker script follows the sidecar pattern and integrates with the existing batch processing architecture to provide scalable, Kubernetes-native batch inference capabilities. + +## Architecture Changes + +### New Kubernetes Integration + +The worker now uses **Kubernetes-native job specification** instead of environment variables: + +1. **Job Annotations**: Job specifications are stored as Kubernetes annotations using the `batch.job.aibrix.ai/` prefix +2. **Kubernetes API**: Worker connects to the K8s API server using in-cluster configuration +3. **BatchJobTransformer**: Automatically transforms K8s Job objects to internal BatchJob models +4. **Service Account**: Requires appropriate RBAC permissions to read Job objects + +### Required RBAC + +The batch worker requires a service account with permissions to read Job objects, as defined in the aibrix/metadata/setting/k8s_job_rbac.yaml + +```yaml +apiVersion: v1 +kind: ServiceAccount +metadata: + name: job-reader-sa + namespace: default +--- +# Service Account for job pod +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: job-reader-role + namespace: default +rules: +- apiGroups: ["batch"] + resources: ["jobs"] + verbs: ["get"] # Get permissions only +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: job-reader-binding + namespace: default +subjects: +- kind: ServiceAccount + name: job-reader-sa + namespace: default +roleRef: + kind: Role + name: job-reader-role + apiGroup: rbac.authorization.k8s.io +``` + +Add the service account to your job template: + +```yaml +spec: + template: + spec: + serviceAccountName: job-reader-sa +``` \ No newline at end of file diff --git a/python/aibrix/aibrix/batch/constant.py b/python/aibrix/aibrix/batch/constant.py index 98cbce337..e3c411933 100644 --- a/python/aibrix/aibrix/batch/constant.py +++ b/python/aibrix/aibrix/batch/constant.py @@ -15,7 +15,10 @@ # The following are all constants. # This is the time interval for the sliding window to check. -EXPIRE_INTERVAL = 1 +EXPIRE_INTERVAL: float = 1 # This is the job pool size in job scheduler. # It should be proportional to resource size in the backend. DEFAULT_JOB_POOL_SIZE = 1 + +# Job opts are for testing purpose. 
+BATCH_OPTS_FAIL_AFTER_N_REQUESTS = "fail_after_n_requests" diff --git a/python/aibrix/aibrix/batch/driver.py b/python/aibrix/aibrix/batch/driver.py index 9dfdf2b05..5d67bf0a9 100644 --- a/python/aibrix/aibrix/batch/driver.py +++ b/python/aibrix/aibrix/batch/driver.py @@ -13,16 +13,20 @@ # limitations under the License. import asyncio -from typing import Any, Dict, List, Optional +from typing import Any, Coroutine, Dict, List, Optional import aibrix.batch.storage as _storage from aibrix.batch.constant import DEFAULT_JOB_POOL_SIZE +from aibrix.batch.job_driver import InferenceEngineClient, ProxyInferenceEngineClient from aibrix.batch.job_entity import JobEntityManager from aibrix.batch.job_manager import JobManager -from aibrix.batch.request_proxy import RequestProxy from aibrix.batch.scheduler import JobScheduler -from aibrix.batch.storage.batch_metastore import initialize_batch_metastore -from aibrix.metadata.logger import init_logger +from aibrix.batch.storage.batch_metastore import ( + get_metastore_type, + initialize_batch_metastore, +) +from aibrix.logger import init_logger +from aibrix.metadata.core import AsyncLoopThread, T from aibrix.storage import StorageType logger = init_logger(__name__) @@ -34,80 +38,116 @@ def __init__( job_entity_manager: Optional[JobEntityManager] = None, storage_type: StorageType = StorageType.AUTO, metastore_type: StorageType = StorageType.AUTO, + llm_engine_endpoint: Optional[str] = None, + stand_alone: bool = False, + params={}, ): """ This is main entrance to bind all components to serve job requests. + + Args: + stand_alone: Set to true to start a new thread for job management. """ - _storage.initialize_storage(storage_type) - initialize_batch_metastore(metastore_type) + _storage.initialize_storage(storage_type, params) + initialize_batch_metastore(metastore_type, params) + self._async_thread_loop: Optional[AsyncLoopThread] = None + if stand_alone: + self._async_thread_loop = AsyncLoopThread("BatchDriver") self._storage = _storage - self._job_manager: JobManager = JobManager(job_entity_manager) + self._job_entity_manager: Optional[JobEntityManager] = job_entity_manager + self._job_manager: JobManager = JobManager() self._scheduler: Optional[JobScheduler] = None - self._scheduling_task: Optional[asyncio.Task] = None - self._proxy: RequestProxy = RequestProxy(self._job_manager) - # Only create jobs_running_loop if JobEntityManager does not have its own sched - if not job_entity_manager or not job_entity_manager.is_scheduler_enabled(): + # Only initiate scheduler if JobEntityManager does not have its own sched + if ( + not self._job_entity_manager + or not self._job_entity_manager.is_scheduler_enabled() + ): self._scheduler = JobScheduler(self._job_manager, DEFAULT_JOB_POOL_SIZE) self._job_manager.set_scheduler(self._scheduler) - self._scheduling_task = asyncio.create_task(self.jobs_running_loop()) + + # Initialize inference client with optional LLM engine endpoint + self._inference_client: Optional[InferenceEngineClient] = None + if llm_engine_endpoint is not None: + self._inference_client = ProxyInferenceEngineClient(llm_engine_endpoint) + + # Track jobs with fail_after_n_requests for stop() validation + self._jobs_with_fail_after: set[str] = set() + + logger.info( + "Batch driver initialized", + job_entity_manager=True if job_entity_manager else False, + job_scheduler=True if self._scheduler else False, + storage=_storage.get_storage_type().value, + metastore=get_metastore_type().value, + ) # type: ignore[call-arg] @property def job_manager(self) -> 
JobManager: return self._job_manager + async def start(self): + # Start thread + if self._async_thread_loop is not None: + self._async_thread_loop.start() + logger.info("Batch driver stand alone thread started") # type: ignore[call-arg] + else: + # name the loop + asyncio.get_running_loop().name = "default" + + if self._job_entity_manager is not None: + logger.info("Registering job entity manager handlers") + await self.run_coroutine( + self.job_manager.set_job_entity_manager(self._job_entity_manager) + ) + + if self._scheduler is not None: + logger.info("starting scheduler") + await self.run_coroutine(self._scheduler.start(self._inference_client)) + async def upload_job_data(self, input_file_name) -> str: - return await self._storage.upload_input_data(input_file_name) + return await self.run_coroutine( + self._storage.upload_input_data(input_file_name) + ) async def retrieve_job_result(self, file_id) -> List[Dict[str, Any]]: - return await self._storage.download_output_data(file_id) + return await self.run_coroutine(self._storage.download_output_data(file_id)) - async def jobs_running_loop(self): - """ - This loop is going through all active jobs in scheduler. - For now, the executing unit is one request. Later if necessary, - we can support a batch size of request per execution. - """ - logger.info("Starting scheduling...") - while True: - one_job = await self._scheduler.round_robin_get_job() - if one_job: - try: - await self._proxy.execute_queries(one_job) - except Exception as e: - job = self._job_manager.mark_job_failed(one_job) - logger.error( - "Failed to execute job", - job_id=one_job, - status=job.status.state.value, - error=e, - ) - raise - await asyncio.sleep(0) - - async def close(self): + async def stop(self): """Properly shutdown the driver and cancel running tasks""" - if self._scheduling_task and not self._scheduling_task.done(): - self._scheduling_task.cancel() - try: - await self._scheduling_task - except (asyncio.CancelledError, RuntimeError) as e: - if isinstance(e, RuntimeError) and "different loop" in str(e): - logger.warning( - "Task cancellation from different event loop, forcing cancellation" - ) - pass - if self._scheduler: - await self._scheduler.close() + if self._scheduler is not None: + await self.run_coroutine(self._scheduler.stop()) + + if self._async_thread_loop is not None: + self._async_thread_loop.stop() + logger.info("Batch driver stand alone thread stopped") # type: ignore[call-arg] async def clear_job(self, job_id): - job = self._job_manager.get_job(job_id) + """Clear job related data for testing""" + if ( + self._async_thread_loop is not None + and self._async_thread_loop.loop != asyncio.get_running_loop() + ): + return await self._async_thread_loop.run_coroutine(self.clear_job(job_id)) + + job = await self._job_manager.get_job(job_id) if job is None: return - self._job_manager.job_deleted_handler(job) - if self._job_manager.get_job(job_id) is None: - await self._storage.remove_job_data(job.spec.input_file_id) + if await self._job_manager.delete_job(job_id): + tasks = [self._storage.remove_job_data(job.spec.input_file_id)] if job.status.output_file_id is not None: - await self._storage.remove_job_data(job.status.output_file_id) + tasks.append(self._storage.remove_job_data(job.status.output_file_id)) if job.status.error_file_id is not None: - await self._storage.remove_job_data(job.status.error_file_id) + tasks.append(self._storage.remove_job_data(job.status.error_file_id)) + + await asyncio.gather(*tasks) + + async def run_coroutine(self, coro: 
Coroutine[Any, Any, T]) -> T: + """ + Submits a coroutine to the event loop and returns an awaitable Future. + This method itself MUST be awaited. (For use from async code) + """ + if self._async_thread_loop is not None: + return await self._async_thread_loop.run_coroutine(coro) + + return await coro diff --git a/python/aibrix/aibrix/batch/job_driver.py b/python/aibrix/aibrix/batch/job_driver.py new file mode 100644 index 000000000..36e5c294a --- /dev/null +++ b/python/aibrix/aibrix/batch/job_driver.py @@ -0,0 +1,432 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import uuid +from typing import Any, Optional +from urllib.parse import urljoin + +import httpx + +import aibrix.batch.constant as constant +import aibrix.batch.storage as storage +from aibrix.batch.job_entity import ( + BatchJob, + BatchJobError, + BatchJobErrorCode, + BatchJobState, + ConditionType, +) +from aibrix.batch.job_progress_manager import JobProgressManager +from aibrix.logger import init_logger + +logger = init_logger(__name__) + + +class InferenceEngineClient: + async def inference_request(self, endpoint: str, request_data): + """Send inference request to the LLM engine.""" + await asyncio.sleep(constant.EXPIRE_INTERVAL) # Simulate processing time + return request_data + + +class ProxyInferenceEngineClient(InferenceEngineClient): + def __init__(self, base_url: str): + """ + Initialize a client to the inference engine. + """ + self.base_url = base_url + + async def inference_request(self, endpoint: str, request_data): + """Real inference request to LLM engine.""" + url = urljoin(self.base_url, endpoint) + + logger.debug("requesting inference", url=url, body=request_data) # type: ignore[call-arg] + + async with httpx.AsyncClient() as client: + response = await client.post(url, json=request_data, timeout=30.0) + response.raise_for_status() + return response.json() + + +class JobDriver: + def __init__( + self, + progress_manager: JobProgressManager, + inference_client: Optional[InferenceEngineClient] = None, + ) -> None: + """ + JobDriver drives job progress after a job is started. The progress experiences three phases: + 1. Job preparing: job output file and error file are prepared. + 2. Job executing: tasks in the job are read and executed, possibly in parallel, without order reservation. + 3. Job finalizing: aggregate job outputs and errors. + + Usage: + * Call execute_job() to execute all phases. This is usually the case when running in the API server with the scheduler enabled. + * Call prepare_job() for Job preparing. An API server running without the scheduler will call this to prepare the job. + * Call execute_worker() for Job executing, supporting parallel execution. The LLM-colocated worker will call this. + * Call finalize_job() for Job finalizing. An API server running without the scheduler will call this to aggregate outputs.
+ """ + self._progress_manager = progress_manager + if inference_client is None: + self._inference_client = InferenceEngineClient() + else: + self._inference_client = inference_client + + async def execute_job(self, job_id): + """ + Execute complete job workflow: prepare -> execute -> finalize. + This function executes all three steps. + """ + job = await self._progress_manager.get_job(job_id) + if job is None: + logger.warning("Job not found", job_id=job_id) + return + + # Check if temp file IDs exist to determine if we should skip steps 1 and 3 + has_temp_files = ( + job.status.temp_output_file_id and job.status.temp_error_file_id + ) + + if not has_temp_files: + # Step 1: Prepare job output files + logger.debug("Temp files not created, creating...", job_id=job_id) + job = await storage.prepare_job_ouput_files(job) + + logger.debug( + "Confirmed temp files", + job_id=job_id, + temp_output_file_id=job.status.temp_output_file_id, + temp_error_file_id=job.status.temp_error_file_id, + ) + + # Step 2: Execute worker (core execution) + try: + job = await self.execute_worker(job_id) + except Exception as ex: + # Handle exception here, so we can execute finalizing if necessary. + job = await self._progress_manager.mark_job_failed( + job_id, + BatchJobError(code=BatchJobErrorCode.INFERENCE_FAILED, message=str(ex)), + ) + + # Step 3: Aggregate outputs + if job.status.state == BatchJobState.FINALIZING: + if not has_temp_files: + await storage.finalize_job_output_data(job) + + logger.debug("Completed job", job_id=job_id) + job = await self._sync_job_status(job_id) + + if job.status.failed: + raise RuntimeError(job.status.get_condition(ConditionType.FAILED).message) + + async def prepare_job(self, job: BatchJob) -> BatchJob: + """ + Prepare job output files by creating multipart uploads. + This is called by metadata server when a new job is committed. + """ + logger.debug("Preparing job output files") # type: ignore[call-arg] + job = await storage.prepare_job_ouput_files(job) + logger.debug("Job output files prepared") # type: ignore[call-arg] + return job + + async def execute_worker(self, job_id) -> BatchJob: + """ + Execute worker logic: process requests without file preparation or finalization. + This function only executes step 2 (the core execution loop). 
+ """ + # Verify job status and get minimum unfinished request id + job, line_no = await self._get_next_request(job_id) + if line_no < 0: + logger.warning( + "Job has something wrong with metadata in job manager, nothing left to execute", + job_id=job_id, + ) # type: ignore[call-arg] + return job + + # [TODO][NOW] find a quick way to decide where to start testing using metastore + if line_no == 0: + logger.debug("Start processing job", job_id=job_id, opts=job.spec.opts) # type: ignore[call-arg] + else: + logger.debug( + "Resuming job", job_id=job_id, request_id=line_no, opts=job.spec.opts + ) # type: ignore[call-arg] + + # Check for fail_after_n_requests option + fail_after_n_requests = None + if job.spec.opts and constant.BATCH_OPTS_FAIL_AFTER_N_REQUESTS in job.spec.opts: + try: + fail_after_n_requests = int( + job.spec.opts[constant.BATCH_OPTS_FAIL_AFTER_N_REQUESTS] + ) + logger.debug( + "Detected fail_after_n_requests option", + job_id=job_id, + fail_after_n_requests=fail_after_n_requests, + ) # type: ignore[call-arg] + except (ValueError, TypeError): + logger.warning( + "Invalid fail_after_n_requests value, ignoring", + job_id=job_id, + value=job.spec.opts["fail_after_n_requests"], + ) # type: ignore[call-arg] + + # Step 2: Execute requests, resumable. + processed_requests = 0 + last_line_no = line_no + while line_no >= 0: + async for request_input in storage.read_job_next_request(job, line_no): + # Extract the request index from the locked request + next_line_no = request_input.pop("_request_index", last_line_no) + # Valid status of skipped requests. + while last_line_no < next_line_no: + if await storage.is_request_done(job, last_line_no): + # Mark the skipped request done + logger.debug( + "Mark skipped request as done locally", + job_id=job_id, + request_id=last_line_no, + ) # type: ignore[call-arg] + job, line_no = await self._sync_job_status_and_get_next_request( + job_id, last_line_no + ) + else: + # Simply skipped the request and get next request id + job, line_no = await self._get_next_request(job_id) + logger.debug( + "Will test next request", + job_id=job_id, + next_unexecuted=line_no, + next_executable=next_line_no, + last_line_no=last_line_no, + ) # type: ignore[call-arg] + if line_no < last_line_no: + # Start next round or stop if no more requests + break + last_line_no = line_no + + # Start next round or stop if no more requests + if line_no < last_line_no: + break + + if line_no != next_line_no: + raise RuntimeError( + f"Metastore inconsistency: expected request index {line_no} but got {next_line_no}" + ) + # Or global status maintained by metastore is not consistent with local status + + custom_id = request_input.get("custom_id", "") + logger.debug( + "Executing job request", + job_id=job_id, + line=line_no, + request_id=line_no, + custom_id=custom_id, + ) # type: ignore[call-arg] + + # Retry inference request up to 3 times with exponential backoff + request_output, last_error = await self._retry_inference_request( + job.spec.endpoint, request_input["body"], job_id, line_no + ) + + # Build standardized response + response = self._build_response( + custom_id, job_id, line_no, request_output, last_error + ) + + logger.debug( + "Got request response", + job_id=job_id, + request_id=line_no, + response=response, + ) # type: ignore[call-arg] + # Write single output and unlock the request + await storage.write_job_output_data(job, line_no, response) + + assert last_line_no == line_no + logger.debug("Job request executed", job_id=job_id, request_id=line_no) # type: 
ignore[call-arg] + + # Check for fail_after_n_requests condition + if fail_after_n_requests is not None: + processed_requests += 1 + if processed_requests >= fail_after_n_requests: + logger.info( + "Triggering artificial failure due to fail_after_n_requests", + job_id=job_id, + processed_requests=processed_requests, + fail_after_n_requests=fail_after_n_requests, + ) # type: ignore[call-arg] + raise RuntimeError( + f"Artificial failure triggered after processing {processed_requests} requests " + f"(fail_after_n_requests={fail_after_n_requests})" + ) + job, line_no = await self._sync_job_status_and_get_next_request( + job_id, last_line_no + ) + logger.debug( + "Confirmed next request", + job_id=job_id, + next_unexecuted=line_no, + last_line_no=last_line_no, + ) # type: ignore[call-arg] + if line_no < last_line_no: + break + last_line_no = line_no + + # For the first round, this shows we read end of input and we now know the total + if last_line_no == line_no: + job = await self._sync_job_status( + job_id, total=line_no + ) # Now that total == request_id + # We need to confirm that all is execute by try starting next round + job, line_no = await self._get_next_request(job_id) + logger.debug( + "Confirmed total requests", + job_id=job_id, + total=job.status.request_counts.total, + next_unexecuted=line_no, + ) # type: ignore[call-arg] + + # Now we'll testing if we really finished, or start another round. + + # Now that all finished. + logger.debug( + "Worker completed, job state:", + job_id=job_id, + total=job.status.request_counts.total if job else None, + state=job.status.state.value if job else None, + ) # type: ignore[call-arg] + return job + + async def finalize_job(self, job: BatchJob) -> BatchJob: + """ + Finalize the job by removing all data. + """ + assert job.status.state == BatchJobState.FINALIZING + + await storage.finalize_job_output_data(job) + + logger.debug("Finalized job", job_id=job.job_id) # type: ignore[call-arg] + return await self._sync_job_status(job.job_id) + + async def _retry_inference_request( + self, + endpoint: str, + request_data: dict, + job_id: str, + request_id: int, + max_retries: int = 3, + ) -> tuple[Any, Optional[Exception]]: + """ + Retry inference request with exponential backoff. + + Returns: + tuple: (request_output, last_error) - output on success, error on failure + """ + request_output = None + last_error = None + + for attempt in range(max_retries): + try: + request_output = await self._inference_client.inference_request( + endpoint, request_data + ) + break # Success, exit retry loop + except Exception as e: + last_error = e + logger.warning( + f"Inference request failed (attempt {attempt + 1}/{max_retries}): {e}", + job_id=job_id, + request_id=request_id, + ) # type: ignore[call-arg] + if attempt < max_retries - 1: # Don't sleep on last attempt + await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff + + return request_output, last_error + + def _build_response( + self, + custom_id: str, + job_id: str, + request_id: int, + request_output: Any = None, + error: Optional[Exception] = None, + ) -> dict[str, Any]: + """ + Build a standardized response object for job requests. 
+ + Args: + custom_id: Custom identifier for the request + job_id: Job identifier + request_id: Request identifier + request_output: Successful response data (if any) + error: Error that occurred (if any) + + Returns: + dict: Standardized response object + """ + response: dict[str, Any] = { + "id": uuid.uuid4().hex[:5], + "error": None, + "response": None, + "custom_id": custom_id, + } + + if error is not None: + logger.error( + f"All inference attempts failed after retries: {error}", + job_id=job_id, + request_id=request_id, + ) # type: ignore[call-arg] + response["error"] = BatchJobError( + code=BatchJobErrorCode.INFERENCE_FAILED, message=str(error) + ) + else: + response["response"] = { + "status_code": 200, + "request_id": f"{job_id}-{request_id}", + "body": request_output, + } + + return response + + async def _sync_job_status(self, job_id, request_id=-1, total=0) -> BatchJob: + """ + Update job's status back to job manager. + """ + if total > 0: + return await self._progress_manager.mark_job_total(job_id, total) + elif request_id < 0: + return await self._progress_manager.mark_job_done(job_id) + else: + return await self._progress_manager.mark_jobs_progresses( + job_id, [request_id] + ) + + async def _get_next_request(self, job_id: str) -> tuple[BatchJob, int]: + """ + Get next request id from job manager. + """ + return await self._progress_manager.get_job_next_request(job_id) + + async def _sync_job_status_and_get_next_request( + self, job_id: str, request_id: int + ) -> tuple[BatchJob, int]: + """ + Sync job status and get next request, with None checking. + """ + return await self._progress_manager.mark_job_progress_and_get_next_request( + job_id, request_id + ) diff --git a/python/aibrix/aibrix/batch/job_entity/__init__.py b/python/aibrix/aibrix/batch/job_entity/__init__.py index 3eaf490e7..bcaa7d73a 100644 --- a/python/aibrix/aibrix/batch/job_entity/__init__.py +++ b/python/aibrix/aibrix/batch/job_entity/__init__.py @@ -29,6 +29,7 @@ TypeMeta, ) from .job_entity_manager import JobEntityManager +from .k8s_transformer import BatchJobTransformer, JobAnnotationKey, k8s_job_to_batch_job __all__ = [ "BatchJob", @@ -46,4 +47,7 @@ "ObjectMeta", "RequestCountStats", "TypeMeta", + "BatchJobTransformer", + "JobAnnotationKey", + "k8s_job_to_batch_job", ] diff --git a/python/aibrix/aibrix/batch/job_entity/batch_job.py b/python/aibrix/aibrix/batch/job_entity/batch_job.py index cda9b9978..9a7c240bc 100644 --- a/python/aibrix/aibrix/batch/job_entity/batch_job.py +++ b/python/aibrix/aibrix/batch/job_entity/batch_job.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License.
+import copy import uuid from datetime import datetime, timezone from enum import Enum -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, ConfigDict, Field from pydantic_core import core_schema @@ -40,9 +41,9 @@ class CompletionWindow(str, Enum): TWENTY_FOUR_HOURS = "24h" - def expires_at(self) -> float: + def expires_at(self) -> int: """Returns the expiration time of the completion window.""" - return 86400.0 # Return default value + return 86400 # Return default value class BatchJobState(str, Enum): @@ -51,20 +52,9 @@ class BatchJobState(str, Enum): CREATED = "created" VALIDATING = "validating" IN_PROGRESS = "in_progress" - FINALIZING = "finalizing" - COMPLETED = "completed" - EXPIRED = "expired" - FAILED = "failed" CANCELLING = "cancelling" - CANCELED = "canceled" - - def is_finished(self): - return self in [ - BatchJobState.COMPLETED, - BatchJobState.FAILED, - BatchJobState.CANCELED, - BatchJobState.EXPIRED, - ] + FINALIZING = "finalizing" + FINALIZED = "finalized" class BatchJobErrorCode(str, Enum): @@ -75,15 +65,19 @@ class BatchJobErrorCode(str, Enum): INVALID_COMPLETION_WINDOW = "invalid_completion_window" INVALID_METADATA = "invalid_metadata" AUTHENTICATION_ERROR = "authentication_error" + INFERENCE_FAILED = "inference_failed" + PREPARE_OUTPUT_ERROR = "prepare_output_failed" + FINALIZING_ERROR = "finalizing_failed" UNKNOWN_ERROR = "unknown_error" class ConditionType(str, Enum): """Types of conditions for batch job status.""" - READY = "Ready" - PROCESSING = "Processing" - FAILED = "Failed" + COMPLETED = "completed" + EXPIRED = "expired" + FAILED = "failed" + CANCELLED = "cancelled" class ConditionStatus(str, Enum): @@ -126,22 +120,25 @@ class Condition(NoExtraBaseModel): class BatchJobSpec(NoExtraBaseModel): - """Defines the desired state of a Batch job, which is OpenAI batch compatible.""" + """Defines the specification of a Batch job input.""" input_file_id: str = Field( description="The ID of an uploaded file that contains the requests for the batch", ) - endpoint: BatchJobEndpoint = Field( + endpoint: str = Field( description="The API endpoint to be used for all requests in the batch" ) - completion_window: CompletionWindow = Field( - default=CompletionWindow.TWENTY_FOUR_HOURS, + completion_window: int = Field( + default=CompletionWindow.TWENTY_FOUR_HOURS.expires_at(), description="The time window for completion", ) metadata: Optional[Dict[str, str]] = Field( default=None, description="Set of up to 16 key-value pairs to attach to the batch object", - max_length=16, + ) + opts: Optional[Dict[str, str]] = Field( + default=None, + description="System-only options for internal use (e.g., fail_after_n_requests)", ) @classmethod @@ -151,6 +148,7 @@ def from_strings( endpoint: str, completion_window: str = CompletionWindow.TWENTY_FOUR_HOURS.value, metadata: Optional[Dict[str, str]] = None, + opts: Optional[Dict[str, str]] = None, ) -> "BatchJobSpec": """Create BatchJobSpec from string parameters with validation. 
@@ -159,6 +157,7 @@ def from_strings( endpoint: The API endpoint as string completion_window: The completion window as string metadata: Optional metadata dictionary + opts: Optional system options dictionary Returns: BatchJobSpec instance @@ -178,9 +177,10 @@ def from_strings( return cls( input_file_id=input_file_id, - endpoint=validated_endpoint, - completion_window=validated_completion_window, + endpoint=validated_endpoint.value, + completion_window=validated_completion_window.expires_at(), metadata=metadata, + opts=opts, ) @staticmethod @@ -276,26 +276,72 @@ def __init__( def __get_pydantic_core_schema__(cls, source, handler) -> core_schema.CoreSchema: """ Returns the pydantic-core schema for this class, allowing it to be - used directly within Pydantic models. + used directly within Pydantic models for both validation and serialization. """ - # This defines the schema for the arguments to __init__ - arguments_schema = core_schema.model_fields_schema( - { - "code": core_schema.model_field(core_schema.str_schema()), - "message": core_schema.model_field(core_schema.str_schema()), - "param": core_schema.model_field( - core_schema.nullable_schema(core_schema.str_schema()) - ), - "line": core_schema.model_field( - core_schema.nullable_schema(core_schema.str_schema()) - ), - } + + # def serialize_batch_job_error(instance: "BatchJobError") -> Dict[str, Any]: + # """Custom serializer for BatchJobError.""" + # return { + # "code": instance.code, + # "message": instance.message, + # "param": instance.param, + # "line": instance.line, + # } + + def validate_batch_job_error(value) -> "BatchJobError": + """Custom validator for BatchJobError.""" + if isinstance(value, cls): + return value + elif isinstance(value, dict): + return cls( + code=BatchJobErrorCode(value["code"]), + message=value["message"], + param=value.get("param"), + line=value.get("line"), + ) + else: + raise ValueError(f"Cannot convert {type(value)} to BatchJobError") + + return core_schema.no_info_plain_validator_function( + function=validate_batch_job_error, + serialization=core_schema.plain_serializer_function_ser_schema( + function=cls.json_serializer, + return_schema=core_schema.dict_schema(), + ), ) - return core_schema.call_schema( - arguments_schema, - function=cls, + + @classmethod + def json_serializer(cls, obj: Any): + """Handles types that the default JSON serializer doesn't know.""" + if isinstance(obj, cls): + return { + "code": obj.code, + "message": obj.message, + "param": obj.param, + "line": obj.line, + } + + return obj + + def __deepcopy__(self, memo): + """ + Provides a custom implementation for deep copying this object. + """ + # Create a new instance by calling __init__ with the current object's data. + # This correctly provides all the required arguments. + new_copy = self.__class__( + code=BatchJobErrorCode(self.code), + message=self.message, + param=self.param, + line=self.line, ) + # Standard practice: store the new object in the memo dictionary + # to handle potential circular references during the copy. 
+ memo[id(self)] = new_copy + + return new_copy + class BatchJobStatus(NoExtraBaseModel): """Defines the observed state of BatchJobSpec.""" @@ -352,6 +398,11 @@ class BatchJobStatus(NoExtraBaseModel): alias="finalizingAt", description="Timestamp of when the batch job started finalizing", ) + finalized_at: Optional[datetime] = Field( + default=None, + alias="finalizedAt", + description="Timestamp of when the batch job was finalized, will be copied to completed_at, failed_at, expired_at, and cancelled_at based on condition", + ) completed_at: Optional[datetime] = Field( default=None, alias="completedAt", @@ -378,11 +429,76 @@ class BatchJobStatus(NoExtraBaseModel): description="Timestamp of when the batch job get cancelled", ) - conditions: Optional[Condition] = Field( + conditions: Optional[List[Condition]] = Field( default=None, description="Conditions represent the latest available observations of the batch job's state", ) + @property + def finished(self) -> bool: + return self.state == BatchJobState.FINALIZED + + @property + def completed(self) -> bool: + return self.finished and self.check_condition(ConditionType.COMPLETED) + + @property + def failed(self) -> bool: + return ( + self.finished + and self.check_condition(ConditionType.FAILED) + and not self.check_condition(ConditionType.EXPIRED) + ) + + @property + def expired(self) -> bool: + return self.finished and self.check_condition(ConditionType.EXPIRED) + + @property + def cancelled(self) -> bool: + return self.finished and self.check_condition(ConditionType.CANCELLED) + + @property + def condition(self) -> Optional[ConditionType]: + """If mutiple conditions exists, expired > failed > cancelled > completed""" + if self.conditions is None: + return None + elif self.check_condition(ConditionType.EXPIRED): + return ConditionType.EXPIRED + elif self.check_condition(ConditionType.FAILED): + return ConditionType.FAILED + elif self.check_condition(ConditionType.CANCELLED): + return ConditionType.CANCELLED + elif self.check_condition(ConditionType.COMPLETED): + return ConditionType.COMPLETED + else: + return None + + def check_condition(self, type: ConditionType) -> bool: + if self.conditions is None: + return False + + for condition in self.conditions: + if condition.type == type: + return True + + return False + + def get_condition(self, type: ConditionType) -> Optional[Condition]: + if self.conditions is None: + return None + + for condition in self.conditions: + if condition.type == type: + return condition + + return None + + def add_condition(self, condition: Condition): + if self.conditions is None: + self.conditions = [] + self.conditions.append(condition) + class BatchJob(NoExtraBaseModel): """Schema for the BatchJob API - Kubernetes Custom Resource equivalent.""" @@ -397,8 +513,17 @@ class BatchJob(NoExtraBaseModel): spec: BatchJobSpec = Field(description="Desired state of the batch job") status: BatchJobStatus = Field(description="Observed state of the batch job") + def copy(self): + return BatchJob( + sessionID=self.session_id, + typeMeta=self.type_meta, + metadata=self.metadata, + spec=self.spec, + status=copy.deepcopy(self.status), + ) + @classmethod - def create_new( + def new( cls, name: str, namespace: str, @@ -408,6 +533,24 @@ def create_new( metadata: Optional[Dict[str, str]] = None, ) -> "BatchJob": """Create a new BatchJob with default values.""" + return cls.new_from_spec( + name, + namespace, + spec=BatchJobSpec( + input_file_id=input_file_id, + endpoint=endpoint.value, + 
completion_window=completion_window.expires_at(), + metadata=metadata, + ), + ) + + @classmethod + def new_from_spec( + cls, + name: str, + namespace: str, + spec: BatchJobSpec, + ) -> "BatchJob": return cls( typeMeta=TypeMeta(apiVersion="batch.aibrix.ai/v1alpha1", kind="BatchJob"), metadata=ObjectMeta( @@ -417,12 +560,27 @@ def create_new( resourceVersion=None, deletionTimestamp=None, ), - spec=BatchJobSpec( - input_file_id=input_file_id, - endpoint=endpoint, - completion_window=completion_window, - metadata=metadata, + spec=spec, + status=BatchJobStatus( + jobID=str(uuid.uuid4()), + state=BatchJobState.CREATED, + createdAt=datetime.now(timezone.utc), + ), + ) + + @classmethod + def new_local( + cls, + spec: BatchJobSpec, + ) -> "BatchJob": + return cls( + typeMeta=TypeMeta(apiVersion="", kind="LocalBatchJob"), + metadata=ObjectMeta( + creationTimestamp=datetime.now(timezone.utc), + resourceVersion=None, + deletionTimestamp=None, ), + spec=spec, status=BatchJobStatus( jobID=str(uuid.uuid4()), state=BatchJobState.CREATED, diff --git a/python/aibrix/aibrix/batch/job_entity/job_entity_manager.py b/python/aibrix/aibrix/batch/job_entity/job_entity_manager.py index 98f98b630..0624cd5b9 100644 --- a/python/aibrix/aibrix/batch/job_entity/job_entity_manager.py +++ b/python/aibrix/aibrix/batch/job_entity/job_entity_manager.py @@ -11,10 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from abc import ABC, abstractmethod -from typing import Callable, Optional +from typing import Any, Callable, Coroutine, Optional from aibrix.batch.job_entity.batch_job import BatchJob, BatchJobSpec +from aibrix.logger import init_logger + +logger = init_logger(__name__) class JobEntityManager(ABC): @@ -25,59 +29,144 @@ class JobEntityManager(ABC): Any storage implementation are transparent to external components. """ - def on_job_committed(self, handler: Callable[[BatchJob], None]): + def on_job_committed( + self, handler: Callable[[BatchJob], Coroutine[Any, Any, bool]] + ): """Register a job committed callback. Args: - handler: (Callable[[BatchJob], None]) + handler: (async Callable[[BatchJob], bool]) The callback function. It should accept a single `BatchJob` object representing the committed job and return `None`. """ - self._job_committed_handler: Optional[Callable[[BatchJob], None]] = handler + self._job_committed_loop = asyncio.get_running_loop() + logger.debug( + "job committed handler registered", + loop=getattr(self._job_committed_loop, "name", "unknown"), + ) # type: ignore[call-arg] + self._job_committed_handler: Optional[ + Callable[[BatchJob], Coroutine[Any, Any, bool]] + ] = handler + + async def job_committed(self, committed: BatchJob) -> bool: + if self._job_committed_handler is None: + return True + + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe( + self._job_committed_handler(committed), self._job_committed_loop + ) + ) - def on_job_updated(self, handler: Callable[[BatchJob, BatchJob], None]): + def on_job_updated( + self, handler: Callable[[BatchJob, BatchJob], Coroutine[Any, Any, bool]] + ): """Register a job updated callback. Args: - handler: (Callable[[BatchJob, BatchJob], None]) + handler: (async Callable[[BatchJob, BatchJob], bool]) The callback function. It should accept two `BatchJob` objects representing the old job and new job and return `None`. 
Example: `lambda old_job, new_job: logger.info("Job updated", old_id=old_job.id, new_id=new_job.id)` """ - self._job_updated_handler: Optional[Callable[[BatchJob, BatchJob], None]] = ( - handler + self._job_updated_loop = asyncio.get_running_loop() + logger.debug( + "job updated handler registered", + loop=getattr(self._job_updated_loop, "name", "unknown"), + ) # type: ignore[call-arg] + self._job_updated_handler: Optional[ + Callable[[BatchJob, BatchJob], Coroutine[Any, Any, bool]] + ] = handler + + async def job_updated(self, old: BatchJob, new: BatchJob) -> bool: + if self._job_updated_handler is None: + return True + + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe( + self._job_updated_handler(old, new), self._job_updated_loop + ) ) - def on_job_deleted(self, handler: Callable[[BatchJob], None]): + def on_job_deleted(self, handler: Callable[[BatchJob], Coroutine[Any, Any, bool]]): """Register a job deleted callback. Args: - handler: (Callable[[BatchJob], None]) + handler: (async Callable[[BatchJob], bool]) The callback function. It should accept a single `BatchJob` object representing the deleted job and return `None`. Example: `lambda deleted_job: logger.info("Job deleted", job_id=deleted_job.id)` """ - self._job_deleted_handler: Optional[Callable[[BatchJob], None]] = handler + self._job_deleted_loop = asyncio.get_running_loop() + logger.debug( + "job deleted handler registered", + loop=getattr(self._job_deleted_loop, "name", "unknown"), + ) # type: ignore[call-arg] + self._job_deleted_handler: Optional[ + Callable[[BatchJob], Coroutine[Any, Any, bool]] + ] = handler + + async def job_deleted(self, deleted: BatchJob) -> bool: + if self._job_deleted_handler is None: + return True + + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe( + self._job_deleted_handler(deleted), self._job_deleted_loop + ) + ) def is_scheduler_enabled(self) -> bool: """Check if JobEntityManager has own scheduler enabled.""" return False @abstractmethod - def submit_job(self, session_id: str, job: BatchJobSpec): + async def submit_job(self, session_id: str, job: BatchJobSpec): """Submit job by submiting job to the persist store. Args: + session_id (str): id identifying the job submission session job (BatchJob): Job to add. """ pass @abstractmethod - def cancel_job(self, job_id: str): - """Cancel job by notifing the persist store. + async def update_job_ready(self, job: BatchJob): + """Update job by marking job ready with required information. + + Args: + job (BatchJob): Job to update. + """ + + @abstractmethod + async def update_job_status(self, job: BatchJob): + """Update job status by persisting status information as annotations. + + Args: + job (BatchJob): Job with updated status to persist. + + This method persists critical job status information including: + - Finalized state + - Conditions (completed, failed, cancelled) + - Request counts + - Timestamps (completed_at, cancelling_at, etc.) + """ + + @abstractmethod + async def cancel_job(self, job: BatchJob): + """Cancel job by notifying the persist store on job cancelling or failure. + + Args: + job (BatchJob): Job to cancel or that has failed + """ + pass + + @abstractmethod + async def delete_job(self, job: BatchJob): + """Delete job from the persist store. Args: - job_id (str): Job id. + job (BatchJob): Job to delete.
""" pass diff --git a/python/aibrix/aibrix/batch/job_entity/k8s_transformer.py b/python/aibrix/aibrix/batch/job_entity/k8s_transformer.py new file mode 100644 index 000000000..d87bd13bd --- /dev/null +++ b/python/aibrix/aibrix/batch/job_entity/k8s_transformer.py @@ -0,0 +1,604 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections.abc +import json +import uuid +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +from aibrix.logger import init_logger + +from .batch_job import ( + BatchJob, + BatchJobError, + BatchJobErrorCode, + BatchJobSpec, + BatchJobState, + BatchJobStatus, + CompletionWindow, + Condition, + ConditionStatus, + ConditionType, + ObjectMeta, + RequestCountStats, + TypeMeta, +) + +# Annotation prefix for batch job specifications +JOB_ANNOTATION_PREFIX = "batch.job.aibrix.ai/" + +logger = init_logger(__name__) + + +class JobAnnotationKey(str, Enum): + """Valid annotation keys for job specifications.""" + + SESSION_ID = f"{JOB_ANNOTATION_PREFIX}session-id" + INPUT_FILE_ID = f"{JOB_ANNOTATION_PREFIX}input-file-id" + ENDPOINT = f"{JOB_ANNOTATION_PREFIX}endpoint" + METADATA_PREFIX = f"{JOB_ANNOTATION_PREFIX}metadata." + OPTS_PREFIX = f"{JOB_ANNOTATION_PREFIX}opts." + OUTPUT_FILE_ID = f"{JOB_ANNOTATION_PREFIX}output-file-id" + TEMP_OUTPUT_FILE_ID = f"{JOB_ANNOTATION_PREFIX}temp-output-file-id" + ERROR_FILE_ID = f"{JOB_ANNOTATION_PREFIX}error-file-id" + TEMP_ERROR_FILE_ID = f"{JOB_ANNOTATION_PREFIX}temp-error-file-id" + + # Status persistence annotations + JOB_STATE = f"{JOB_ANNOTATION_PREFIX}state" + CONDITION = f"{JOB_ANNOTATION_PREFIX}condition" + REQUEST_COUNTS = f"{JOB_ANNOTATION_PREFIX}request-counts" + IN_PROGRESS_AT = f"{JOB_ANNOTATION_PREFIX}in-progress-at" + CANCELLING_AT = f"{JOB_ANNOTATION_PREFIX}cancelling-at" + FINALIZING_AT = f"{JOB_ANNOTATION_PREFIX}finalizing-at" + FINALIZED_AT = f"{JOB_ANNOTATION_PREFIX}finalized-at" + ERRORS = f"{JOB_ANNOTATION_PREFIX}errors" + + +class BatchJobTransformer: + """Helper class to transform Kubernetes Job objects to BatchJob instances.""" + + @classmethod + def from_k8s_job(cls, k8s_job: Any) -> BatchJob: + """ + Transform a Kubernetes Job object to a BatchJob instance. 
+ + Args: + k8s_job: Kubernetes Job object (from kubernetes.client.V1Job or kopf body) + + Returns: + BatchJob: Internal BatchJob model instance + + Raises: + ValueError: If required annotations are missing or invalid + """ + # Extract metadata with null safety + metadata = cls._safe_get_attr(k8s_job, "metadata", {}) + annotations: Dict[str, str] = cls._safe_get_attr(metadata, "annotations", {}) + + # Extract pod annotations from pod template (where we now store the batch job metadata) + pod_spec = cls._safe_get_attr(k8s_job, "spec", {}) + pod_template = cls._safe_get_attr(pod_spec, "template", {}) + pod_metadata = cls._safe_get_attr(pod_template, "metadata", {}) + pod_annotations: Dict[str, str] = cls._safe_get_attr( + pod_metadata, "annotations", {} + ) + + # Extract SessionID from pod annotations + session_id = pod_annotations.get(JobAnnotationKey.SESSION_ID.value) + + # Extract BatchJobSpec from pod annotations + spec = cls._extract_batch_job_spec(pod_annotations, pod_spec) + + # Extract ObjectMeta from job metadata (not pod metadata) + object_meta: ObjectMeta = cls._extract_object_meta(metadata) + + # Extract TypeMeta from Kubernetes Job + type_meta = cls._extract_type_meta(k8s_job) + + # Extract or create BatchJobStatus + status = cls._extract_batch_job_status( + k8s_job, object_meta.resource_version, pod_annotations, annotations + ) + + return BatchJob( + sessionID=session_id, + typeMeta=type_meta, + metadata=object_meta, + spec=spec, + status=status, + ) + + @classmethod + def _extract_batch_job_spec( + cls, annotations: Dict[str, str], pod_spec: Any + ) -> BatchJobSpec: + """Extract BatchJobSpec from Kubernetes job annotations.""" + # Extract required fields + input_file_id = annotations.get(JobAnnotationKey.INPUT_FILE_ID.value) + if not input_file_id: + raise ValueError( + f"Required annotation '{JobAnnotationKey.INPUT_FILE_ID.value}' not found" + ) + + endpoint = annotations.get(JobAnnotationKey.ENDPOINT.value) + if not endpoint: + raise ValueError( + f"Required annotation '{JobAnnotationKey.ENDPOINT.value}' not found" + ) + + # Extract batch metadata (key-value pairs with prefix) + batch_metadata = {} + batch_opts = {} + for key, value in annotations.items(): + if key.startswith(JobAnnotationKey.METADATA_PREFIX.value): + # Remove prefix to get the actual metadata key + metadata_key = key[len(JobAnnotationKey.METADATA_PREFIX.value) :] + batch_metadata[metadata_key] = value + elif key.startswith(JobAnnotationKey.OPTS_PREFIX.value): + # Remove prefix to get the actual opts key + opts_key = key[len(JobAnnotationKey.OPTS_PREFIX.value) :] + batch_opts[opts_key] = value + + # Use BatchJobSpec.from_strings for validation and creation + return BatchJobSpec( + input_file_id=input_file_id, + endpoint=endpoint, + completion_window=cls._safe_get_attr( + pod_spec, + "activeDeadlineSeconds", + CompletionWindow.TWENTY_FOUR_HOURS.expires_at(), + ), + metadata=batch_metadata if batch_metadata else None, + opts=batch_opts if batch_opts else None, + ) + + @classmethod + def _extract_object_meta(cls, k8s_metadata: Any) -> ObjectMeta: + """Extract ObjectMeta from Kubernetes metadata.""" + # Handle both attribute access and dict-like access + name = cls._safe_get_attr(k8s_metadata, "name") + namespace = cls._safe_get_attr(k8s_metadata, "namespace") + uid = cls._safe_get_attr(k8s_metadata, "uid") + resource_version = cls._safe_get_attr( + k8s_metadata, "resource_version" + ) or cls._safe_get_attr(k8s_metadata, "resourceVersion") + generation = cls._safe_get_attr(k8s_metadata, "generation") + + # 
Handle timestamp conversion + creation_timestamp = cls._convert_timestamp( + cls._safe_get_attr(k8s_metadata, "creation_timestamp") + or cls._safe_get_attr(k8s_metadata, "creationTimestamp") + ) + deletion_timestamp = cls._convert_timestamp( + cls._safe_get_attr(k8s_metadata, "deletion_timestamp") + or cls._safe_get_attr(k8s_metadata, "deletionTimestamp") + ) + + labels = cls._safe_get_attr(k8s_metadata, "labels") + annotations = cls._safe_get_attr(k8s_metadata, "annotations") + + return ObjectMeta( + name=name, + namespace=namespace, + uid=uid, + resourceVersion=resource_version, + generation=generation, + creationTimestamp=creation_timestamp, + deletionTimestamp=deletion_timestamp, + labels=labels, + annotations=annotations, + ) + + @classmethod + def _extract_type_meta(cls, k8s_job: Any) -> TypeMeta: + """Extract TypeMeta from Kubernetes Job.""" + # Extract apiVersion and kind from the Kubernetes job + api_version = cls._safe_get_attr(k8s_job, "api_version") or cls._safe_get_attr( + k8s_job, "apiVersion", "batch/v1" + ) + kind = cls._safe_get_attr(k8s_job, "kind", "Job") + + return TypeMeta(apiVersion=api_version, kind=kind) + + @classmethod + def _extract_batch_job_status( + cls, + k8s_job: Any, + resource_version: Optional[str], + podAnnotations: Dict[str, str], + annotations: Dict[str, str], + ) -> BatchJobStatus: + """Extract or create BatchJobStatus from Kubernetes job.""" + # Extract job status information + k8s_status = cls._safe_get_attr(k8s_job, "status", {}) + metadata = cls._safe_get_attr(k8s_job, "metadata", {}) + + # Generate or extract batch ID + job_id = cls._safe_get_attr(metadata, "uid") or str(uuid.uuid4()) + + # Map file ids + output_file_id = podAnnotations.get(JobAnnotationKey.OUTPUT_FILE_ID.value) + temp_output_file_id = podAnnotations.get( + JobAnnotationKey.TEMP_OUTPUT_FILE_ID.value + ) + error_file_id = podAnnotations.get(JobAnnotationKey.ERROR_FILE_ID.value) + temp_error_file_id = podAnnotations.get( + JobAnnotationKey.TEMP_ERROR_FILE_ID.value + ) + + # Extract conditions from Kubernetes job + conditions = cls._extract_conditions(k8s_status, annotations) + + # Map Kubernetes job phase to BatchJobState + state, finalizing_time = cls._map_k8s_phase_to_batch_state( + annotations, conditions + ) + + # Extract creation timestamp + creation_timestamp = cls._convert_timestamp( + cls._safe_get_attr(metadata, "creation_timestamp") + or cls._safe_get_attr(metadata, "creationTimestamp") + ) + if not creation_timestamp: + creation_timestamp = datetime.now(timezone.utc) + + status = BatchJobStatus( + jobID=job_id, + state=state, + outputFileID=output_file_id, + tempOutputFileID=temp_output_file_id, + errorFileID=error_file_id, + tempErrorFileID=temp_error_file_id, + createdAt=creation_timestamp, + finalizingAt=finalizing_time, + conditions=conditions, + ) + + # Update with persisted annotations if available + status = cls.update_status_from_annotations(status, annotations) + + logger.debug( + "Extracted batch job status", + jobID=job_id, + resource_version=resource_version, + state=status.state, + errors=status.errors, + k8s_status=k8s_status, + annotations=annotations, + status=status, + ) # type:ignore[call-arg] + + return status + + @classmethod + def _extract_conditions( + cls, k8s_status: Any, annotations: Dict[str, str] + ) -> Optional[List[Condition]]: + """Extract and convert Kubernetes conditions to AIBrix Condition objects.""" + k8s_conditions = cls._safe_get_attr(k8s_status, "conditions") + if k8s_conditions is None: + return None + + conditions = [] + 
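A sketch of feeding a kopf-style, dict-shaped Job body through the transformer. All values are made up, and the real BatchJobSpec/ObjectMeta models may enforce additional fields beyond what is shown here.

from aibrix.batch.job_entity.k8s_transformer import k8s_job_to_batch_job

fake_k8s_job = {
    "apiVersion": "batch/v1",
    "kind": "Job",
    "metadata": {
        "name": "batch-abc",
        "namespace": "default",
        "uid": "3f1c9f2e-0000-0000-0000-000000000000",
        "creationTimestamp": "2025-10-02T14:26:06Z",
    },
    "spec": {
        "activeDeadlineSeconds": 86400,
        "template": {
            "metadata": {
                "annotations": {
                    "batch.job.aibrix.ai/input-file-id": "file-abc",
                    "batch.job.aibrix.ai/endpoint": "/v1/chat/completions",
                }
            }
        },
    },
    "status": {},
}
# With no state annotation and no conditions, the resulting state resolves to CREATED.
batch_job = k8s_job_to_batch_job(fake_k8s_job)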
has_failure = False + suspend_condition = annotations.get(JobAnnotationKey.CONDITION.value) + for k8s_condition in k8s_conditions: + condition_type = cls._safe_get_attr(k8s_condition, "type") + condition_status = cls._safe_get_attr(k8s_condition, "status") + condition_reason = cls._safe_get_attr(k8s_condition, "reason") + condition_message = cls._safe_get_attr(k8s_condition, "message") + + # Extract and convert timestamp + last_transition_time = cls._convert_timestamp( + cls._safe_get_attr(k8s_condition, "lastTransitionTime") + ) + if not last_transition_time: + last_transition_time = datetime.now(timezone.utc) + + # Map Kubernetes condition types to AIBrix ConditionType + aibrix_condition_type = None + if condition_type == "Complete" and condition_status == "True": + aibrix_condition_type = ConditionType.COMPLETED + elif condition_type == "Failed" and condition_status == "True": + if condition_reason == "DeadlineExceeded": + aibrix_condition_type = ConditionType.EXPIRED + else: + aibrix_condition_type = ConditionType.FAILED + has_failure = True + elif ( + condition_type == "Suspended" + and condition_status == "True" + and suspend_condition is not None + ): + aibrix_condition_type = ConditionType(suspend_condition) + if aibrix_condition_type == ConditionType.FAILED: + has_failure = True + + # Only add conditions that map to our types + if aibrix_condition_type: + conditions.append( + Condition( + type=aibrix_condition_type, + status=ConditionStatus.TRUE, # We only add True conditions + lastTransitionTime=last_transition_time, + reason=condition_reason, + message=condition_message, + ) + ) + + # Handle failure during finalizing. + if ( + not has_failure + and annotations.get(JobAnnotationKey.JOB_STATE.value) + == BatchJobState.FINALIZED.value + and suspend_condition == ConditionType.FAILED.value + ): + last_transition_time = cls._convert_timestamp( + annotations.get(JobAnnotationKey.FINALIZED_AT.value) + ) + if last_transition_time is None: + last_transition_time = datetime.now(timezone.utc) + + conditions.append( + Condition( + type=ConditionType.FAILED, + status=ConditionStatus.TRUE, # We only add True conditions + lastTransitionTime=last_transition_time, + ) + ) + + logger.debug( + "conditions check", conditions=len(conditions) if conditions else 0 + ) # type: ignore[call-arg] + + return conditions if len(conditions) > 0 else None + + @classmethod + def _map_k8s_phase_to_batch_state( + cls, annotations: Dict[str, str], conditions: Optional[List[Condition]] + ) -> Tuple[BatchJobState, Optional[datetime]]: + """ + Map Kubernetes job phase to BatchJobState. Most states can be identified using annotation except: + 1. Job first time created, which could created by the 3rd party. + 2. Job previously in progress and finished that need finalizing, which controlled by the 3rd party. + A special case is cancelling in progress, where state is finalizing, but we need to confirm the + finalizing time by check the time the job is suspended. + + Returns: + state: BatchJobState + finalizing_time: datetime, optional + """ + # If state available, respect it. + state_value = annotations.get(JobAnnotationKey.JOB_STATE.value) + if state_value: + state = BatchJobState(state_value) + if state not in [BatchJobState.IN_PROGRESS, BatchJobState.FINALIZING]: + return state, None + else: + state = BatchJobState.CREATED + return state, None + + # 1. 
If ConditionTypes are available, the state should always be FINALIZING + if conditions and len(conditions) > 0: + return BatchJobState.FINALIZING, conditions[0].last_transition_time + + return BatchJobState.IN_PROGRESS, None + + @classmethod + def _safe_get_attr(cls, obj: Any, attr: str, default: Any = None) -> Any: + """Safely get attribute from object, supporting both attr access and dict access.""" + if obj is None: + return default + + # Try dict-like access, use collections.abc.Mapping to support kopf.body + if isinstance(obj, collections.abc.Mapping): + val = obj.get(attr, None) + else: + val = getattr(obj, attr, None) + + return default if val is None else val + + @classmethod + def _convert_timestamp(cls, timestamp: Any) -> Optional[datetime]: + """Convert various timestamp formats to datetime.""" + if timestamp is None: + return None + + # If already a datetime object + if isinstance(timestamp, datetime): + return timestamp + + # If it's a string, try to parse it + if isinstance(timestamp, str): + try: + # Handle ISO format timestamps + return datetime.fromisoformat(timestamp.replace("Z", "+00:00")) + except ValueError: + # Try other common formats + try: + return datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%SZ") + except ValueError: + return None + + # If it has a timestamp attribute (Kubernetes Time object) + if hasattr(timestamp, "timestamp"): + return timestamp.timestamp() + + return None + + @classmethod + def create_status_annotations(cls, job_status: BatchJobStatus) -> Dict[str, str]: + """Create pod template annotations from BatchJobStatus for persistence. + + Args: + job_status: BatchJobStatus to persist + + Returns: + Dict of annotations to add to pod template + """ + annotations = {} + + # Persist batch job state + annotations[JobAnnotationKey.JOB_STATE.value] = job_status.state.value + + # Persist conditions (failed, cancelled) + if job_status.check_condition(ConditionType.CANCELLED): + annotations[JobAnnotationKey.CONDITION.value] = ( + ConditionType.CANCELLED.value + ) + elif job_status.check_condition(ConditionType.FAILED): + annotations[JobAnnotationKey.CONDITION.value] = ConditionType.FAILED.value + + # Persist errors + if job_status.errors is not None and len(job_status.errors) > 0: + annotations[JobAnnotationKey.ERRORS.value] = json.dumps( + job_status.errors, default=BatchJobError.json_serializer + ) + + # Persist request counts (only if they contain meaningful data) + if job_status.request_counts.total > 0: + request_counts_data = { + "total": job_status.request_counts.total, + "launched": job_status.request_counts.launched, + "completed": job_status.request_counts.completed, + "failed": job_status.request_counts.failed, + } + annotations[JobAnnotationKey.REQUEST_COUNTS.value] = json.dumps( + request_counts_data + ) + + # Persist timestamps (only if they exist) + timestamp_mappings = [ + (job_status.in_progress_at, JobAnnotationKey.IN_PROGRESS_AT), + (job_status.finalizing_at, JobAnnotationKey.FINALIZING_AT), + (job_status.finalized_at, JobAnnotationKey.FINALIZED_AT), + (job_status.cancelling_at, JobAnnotationKey.CANCELLING_AT), + ] + + for timestamp, annotation_key in timestamp_mappings: + if timestamp is not None: + annotations[annotation_key.value] = timestamp.isoformat() + + return annotations + + @classmethod + def update_status_from_annotations( + cls, job_status: BatchJobStatus, annotations: Dict[str, str] + ) -> BatchJobStatus: + """Update BatchJobStatus with data from persisted annotations. 
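For reference, a few illustrative inputs for the timestamp coercion above (a private helper, shown only to document the shapes it accepts):

from datetime import datetime, timezone
from aibrix.batch.job_entity.k8s_transformer import BatchJobTransformer

BatchJobTransformer._convert_timestamp("2025-10-02T14:26:06Z")       # ISO string with Z -> aware datetime
BatchJobTransformer._convert_timestamp(datetime.now(timezone.utc))   # datetime -> returned as-is
BatchJobTransformer._convert_timestamp(None)                         # missing value -> None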
+ + Args: + job_status: Existing BatchJobStatus to update + annotations: Pod template annotations containing persisted status + + Returns: + Updated BatchJobStatus + """ + # Update errors if persisted + if ( + persisted_errors := annotations.get(JobAnnotationKey.ERRORS.value) + ) is not None: + try: + errors: list[dict] = json.loads(persisted_errors) + job_status.errors = [] + for error in errors: + job_status.errors.append( + BatchJobError( + code=BatchJobErrorCode( + error.get("code", BatchJobErrorCode.UNKNOWN_ERROR.value) + ), + message=str(error.get("message")), + param=str(error.get("message")), + line=error.get("line"), # type: ignore[arg-type] + ) + ) + except (json.JSONDecodeError, KeyError) as e: + logger.warning("Failed to parse persisted errors", error=str(e)) # type: ignore[call-arg] + + # Update request counts if persisted + if ( + persisted_counts := annotations.get(JobAnnotationKey.REQUEST_COUNTS.value) + ) is not None: + try: + counts_data = json.loads(persisted_counts) + job_status.request_counts = RequestCountStats( + total=counts_data.get("total", 0), + launched=counts_data.get("launched", 0), + completed=counts_data.get("completed", 0), + failed=counts_data.get("failed", 0), + ) + except (json.JSONDecodeError, KeyError) as e: + logger.warning("Failed to parse persisted request counts", error=str(e)) # type: ignore[call-arg] + + # Update timestamps if persisted + timestamp_mappings = [ + (JobAnnotationKey.IN_PROGRESS_AT, "in_progress_at"), + (JobAnnotationKey.FINALIZING_AT, "finalizing_at"), + (JobAnnotationKey.FINALIZED_AT, "finalized_at"), + (JobAnnotationKey.CANCELLING_AT, "cancelling_at"), + ] + + for annotation_key, attr_name in timestamp_mappings: + if ( + persisted_timestamp := annotations.get(annotation_key.value) + ) is not None and ( + converted_timestamp := cls._convert_timestamp(persisted_timestamp) + ) is not None: + setattr(job_status, attr_name, converted_timestamp) + + if job_status.state == BatchJobState.FINALIZED: + if ( + condition := job_status.get_condition(ConditionType.FAILED) + ) is not None: + job_status.failed_at = ( + job_status.finalized_at or condition.last_transition_time + ) + elif ( + condition := job_status.get_condition(ConditionType.CANCELLED) + ) is not None: + job_status.cancelled_at = ( + job_status.finalized_at or condition.last_transition_time + ) + elif ( + condition := job_status.get_condition(ConditionType.EXPIRED) + ) is not None: + job_status.expired_at = ( + job_status.finalized_at or condition.last_transition_time + ) + elif ( + condition := job_status.get_condition(ConditionType.COMPLETED) + ) is not None: + job_status.completed_at = ( + job_status.finalized_at or condition.last_transition_time + ) + + return job_status + + +def k8s_job_to_batch_job(k8s_job: Any) -> BatchJob: + """ + Convenience function to transform a Kubernetes Job object to a BatchJob. + + Args: + k8s_job: Kubernetes Job object (from kubernetes.client.V1Job or kopf body) + + Returns: + BatchJob: Internal BatchJob model instance + + Raises: + ValueError: If required annotations are missing or invalid + """ + return BatchJobTransformer.from_k8s_job(k8s_job) diff --git a/python/aibrix/aibrix/batch/job_manager.py b/python/aibrix/aibrix/batch/job_manager.py index dc05b5821..5a94685de 100644 --- a/python/aibrix/aibrix/batch/job_manager.py +++ b/python/aibrix/aibrix/batch/job_manager.py @@ -13,12 +13,11 @@ # limitations under the License. 
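Assuming an existing BatchJob named `job` and BatchJobTransformer imported as above, the status-persistence pair defined in the transformer round-trips like this; the middle step of patching the Job's pod template via the Kubernetes API is elided.

annotations = BatchJobTransformer.create_status_annotations(job.status)
# ... persist `annotations` onto the Job's pod template through the Kubernetes API ...
restored_status = BatchJobTransformer.update_status_from_annotations(job.status, annotations)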
import asyncio -import copy -import uuid from dataclasses import dataclass from datetime import datetime, timezone -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple +from aibrix.batch.job_driver import JobDriver from aibrix.batch.job_entity import ( BatchJob, BatchJobError, @@ -26,26 +25,30 @@ BatchJobSpec, BatchJobState, BatchJobStatus, + Condition, + ConditionStatus, + ConditionType, JobEntityManager, - ObjectMeta, - TypeMeta, ) +from aibrix.batch.job_progress_manager import JobProgressManager from aibrix.batch.scheduler import JobScheduler from aibrix.batch.storage import read_job_input_info -from aibrix.metadata.logger import init_logger +from aibrix.logger import init_logger -# Custom exceptions for job creation -class JobCreationError(Exception): - """Base exception for job creation errors.""" +# Custom exceptions for job manager +class JobManagerError(Exception): + """Base exception for job manager errors.""" pass -class JobCreationTimeoutError(JobCreationError): - """Exception raised when job creation times out.""" +class JobUnexpectedStateError(JobManagerError): + """Job in unexpcted status""" - pass + def __init__(self, message: str, state: Optional[BatchJobState]): + super().__init__(message) + self.state = state @dataclass @@ -81,8 +84,9 @@ def __init__(self, job: BatchJob): status=job.status, ) self._async_lock = asyncio.Lock() - self._current_request_id: int = ( - 0 # request_id < _current_request_id are all completed. + self._next_request_id: int = 0 + self._min_unexecuted_id: int = ( + 0 # request_id < _min_unexecuted_id are all completed. ) self._no_total: bool = job.status.request_counts.total == 0 # Initialize progress bits based on total request count @@ -90,47 +94,67 @@ def __init__(self, job: BatchJob): False ] * job.status.request_counts.total + @property + def batch_job(self) -> BatchJob: + return BatchJob( + typeMeta=self.type_meta, + metadata=self.metadata, + spec=self.spec, + status=self.status, + ) + def set_request_executed(self, req_id): # This marks the request successfully executed. self._request_progress_bits[req_id] = True + # Check if self._min_unexecuted_id need to be updated + if req_id != self._min_unexecuted_id: + return + # Update self._min_unexecuted_id + for i in range(self._min_unexecuted_id, self.status.request_counts.total): + if self._request_progress_bits[i]: + self._min_unexecuted_id = i + 1 + else: + break def get_request_bit(self, req_id): return self._request_progress_bits[req_id] - def get_job_status(self) -> BatchJobState: - return self.status.state - def complete_one_request(self, req_id, failed: bool = False): """ This is called after an inference call. If all requests are done, we need to update its status to be completed. """ if req_id == self.status.request_counts.total: - self.status.request_counts.total -= 1 # Fix total count - self.status.finalizing_at = datetime.now(timezone.utc) - self.status.state = BatchJobState.FINALIZING - return - - if not self._request_progress_bits[req_id]: + # Fix total count and launched count on total decided. 
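A standalone toy of the progress bookkeeping above (illustrative only): progress bits record completed requests, and the minimum-unexecuted pointer only advances over a contiguous completed prefix.

bits = [False] * 5
min_unexecuted = 0

def mark_done(req_id: int) -> None:
    global min_unexecuted
    bits[req_id] = True
    if req_id != min_unexecuted:
        return
    while min_unexecuted < len(bits) and bits[min_unexecuted]:
        min_unexecuted += 1

mark_done(2)                 # out of order: pointer stays at 0
mark_done(0); mark_done(1)
assert min_unexecuted == 3   # 0..2 are done; 3 is the next candidate for a retry round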
+ self.status.request_counts.total -= 1 + if self.status.request_counts.launched > self.status.request_counts.total: + self.status.request_counts.launched = self.status.request_counts.total + self._no_total = False + elif not self._request_progress_bits[req_id]: self.set_request_executed(req_id) if failed: self.status.request_counts.failed += 1 else: self.status.request_counts.completed += 1 - if ( - not self._no_total - and self.status.request_counts.completed - + self.status.request_counts.failed - == self.status.request_counts.total - ): - self.status.finalizing_at = datetime.now(timezone.utc) - self.status.state = BatchJobState.FINALIZING - def next_request_id(self): + # Test all done + if ( + not self._no_total + and self.status.request_counts.completed + self.status.request_counts.failed + == self.status.request_counts.total + ): + self.status.finalizing_at = datetime.now(timezone.utc) + self.status.state = BatchJobState.FINALIZING + + def next_request_id(self) -> int: """ - Returns the next request for inference. Due to the propobility + Returns the next request_id for inference. Due to the propobility that some requests are failed, this returns a request that - are not marked as executed. + are not marked as executed. We used round robin touch all requests + first and then start another round. + + Returns: + int: next_request_id or -1 if job is done """ if ( not self._no_total @@ -139,21 +163,25 @@ def next_request_id(self): ): return -1 - req_id = self._current_request_id - assert self._no_total or req_id != self.status.request_counts.total + req_id = self._next_request_id + # If total has confirmed and not all request executed, start next round. + if not self._no_total and req_id == self.status.request_counts.total: + req_id = self._min_unexecuted_id + # In case total has not confirmed, expland _request_progress_bits if necessary. if req_id >= len(self._request_progress_bits): self._request_progress_bits.append(False) + + # Skip executed requests. while self._request_progress_bits[req_id]: req_id += 1 if not self._no_total and req_id == self.status.request_counts.total: - return -1 + req_id = self._min_unexecuted_id if req_id >= len(self._request_progress_bits): self._request_progress_bits.append(False) - # Mark self._current_request_id, requests before self._current_request_id are all completed - # and don't need to retry. - self._current_request_id = req_id + # Update _next_request_id + self._next_request_id = req_id # Update launched request count if req_id >= self.status.request_counts.launched: self.status.request_counts.launched = req_id + 1 @@ -188,8 +216,35 @@ def job_authentication(self): return True -class JobManager: - def __init__(self, job_entity_manager: Optional[JobEntityManager] = None): +class JobManager(JobProgressManager): + # Valid state transitions are defined as: + # 1. Started -> Validating -> In_progress -> Finalizing -> Finalzed(condition: completed) + # 2. Started/Validating -> Finalzed (condition: failed) + # 3. In_progress -> Finalizing -> Finalized (condition: failed) + # 4. Started/Validating -> Cancelling -> Finalized (condition: cancelled) + # 5. In_progress -> Cancelling -> Finalizing -> Finalized (condition: cancelled) + # 6. Started/Validating -> Finalized (condition: expired) + # 7. 
In_progress -> Finalizing -> Finalized (condition: expired) + VALID_STATE_TRANSITIONS = { + BatchJobState.CREATED: [BatchJobState.VALIDATING], + BatchJobState.VALIDATING: [ + BatchJobState.IN_PROGRESS, + BatchJobState.FINALIZED, # For failed/expired conditions + BatchJobState.CANCELLING, # For cancellation + ], + BatchJobState.IN_PROGRESS: [ + BatchJobState.FINALIZING, + BatchJobState.CANCELLING, # For cancellation + ], + BatchJobState.FINALIZING: [BatchJobState.FINALIZED], + BatchJobState.CANCELLING: [ + BatchJobState.FINALIZED, + BatchJobState.FINALIZING, # For in_progress -> cancelling -> finalizing + ], + BatchJobState.FINALIZED: [], # Terminal state + } + + def __init__(self) -> None: """ This manages jobs in three categorical job pools. 1. _pending_jobs are jobs that are not scheduled yet @@ -197,26 +252,31 @@ def __init__(self, job_entity_manager: Optional[JobEntityManager] = None): Theses are the input to the job scheduler. 3. _done_jobs are inactive jobs. This needs to be updated periodically. """ + super().__init__() + self._pending_jobs: dict[str, BatchJob] = {} self._in_progress_jobs: dict[str, BatchJob] = {} self._done_jobs: dict[str, BatchJob] = {} self._job_scheduler: Optional[JobScheduler] = None - self._job_entity_manager: Optional[JobEntityManager] = job_entity_manager + self._job_entity_manager: Optional[JobEntityManager] = None # Track jobs being created with JobEntityManager self._creating_jobs: Dict[str, asyncio.Future[str]] = {} self._creation_timeouts: Dict[str, asyncio.Task] = {} self._session_metadata: Dict[str, Dict[str, Any]] = {} - # Register job lifecycle handlers if entity manager is available - if self._job_entity_manager: - self._job_entity_manager.on_job_committed(self.job_committed_handler) - self._job_entity_manager.on_job_updated(self.job_updated_handler) - self._job_entity_manager.on_job_deleted(self.job_deleted_handler) - - def set_scheduler(self, scheduler: JobScheduler): + def set_scheduler(self, scheduler: JobScheduler) -> None: self._job_scheduler = scheduler + async def set_job_entity_manager( + self, job_entity_manager: JobEntityManager + ) -> None: + self._job_entity_manager = job_entity_manager + # Register job lifecycle handlers within loop context + self._job_entity_manager.on_job_committed(self.job_committed_handler) + self._job_entity_manager.on_job_updated(self.job_updated_handler) + self._job_entity_manager.on_job_deleted(self.job_deleted_handler) + async def create_job( self, session_id: str, @@ -225,17 +285,21 @@ async def create_job( completion_window: str, meta_data: dict, timeout: float = 30.0, + initial_state: BatchJobState = BatchJobState.CREATED, ) -> str: job_spec = BatchJobSpec.from_strings( input_file_id, api_endpoint, completion_window, meta_data ) - return await self.create_job_with_spec(session_id, job_spec, timeout) + return await self.create_job_with_spec( + session_id, job_spec, timeout, initial_state + ) async def create_job_with_spec( self, session_id: str, job_spec: BatchJobSpec, timeout: float = 30.0, + initial_state: BatchJobState = BatchJobState.CREATED, ) -> str: """ Async job creation that waits for job ID to be available. 
@@ -270,40 +334,37 @@ async def create_job_with_spec( # Will trigger job committed handler # Note: When using job_entity_manager, the job_id will be available after the committed handler # For now, we return None since we don't have immediate access to the generated job_id - self._job_entity_manager.submit_job(session_id, job_spec) - job_id = await asyncio.wait_for(job_future, timeout=timeout) + submitted = asyncio.create_task( + self._job_entity_manager.submit_job(session_id, job_spec) + ) + timeouted = asyncio.create_task( + asyncio.wait_for(job_future, timeout=timeout) + ) + + _, job_id = await asyncio.gather(submitted, timeouted) logger.info( "Job created successfully", session_id=session_id, job_id=job_id ) # type: ignore[call-arg] except Exception: + print(f"timeout {datetime.now()}") raise finally: # Clean up tracking del self._creating_jobs[session_id] + if job_id is None: + raise RuntimeError("Job ID was not set during creation") return job_id # Local job handling. - job = BatchJob( - typeMeta=TypeMeta(apiVersion="", kind="LocalBatchJob"), - metadata=ObjectMeta( - resourceVersion=None, - creationTimestamp=datetime.now(timezone.utc), - deletionTimestamp=None, - ), - spec=job_spec, - status=BatchJobStatus( - jobID=str(uuid.uuid4()), - state=BatchJobState.CREATED, - createdAt=datetime.now(timezone.utc), - ), - ) - self.job_committed_handler(job) + job = BatchJob.new_local(job_spec) + job.status.state = initial_state + await self.job_committed_handler(job) assert job.job_id is not None return job.job_id - def cancel_job(self, job_id: str) -> bool: + async def cancel_job(self, job_id: str) -> bool: """ Cancel a job by job_id. @@ -314,6 +375,8 @@ def cancel_job(self, job_id: str) -> bool: The method considers the situation that while before signaling, the job is in pending or processing, but before job_deleted_handler is called, the job may have completed. + Noted: job not will be deleted from job_manager + Args: job_id: The ID of the job to cancel @@ -325,55 +388,133 @@ def cancel_job(self, job_id: str) -> bool: job_in_progress = False if job_id in self._pending_jobs: job = self._pending_jobs[job_id] - # Delete from pending jobs to avoid the job being scheduled. Status will be updated later + # remove from _pending_jobs to prevent scheduling anyway. del self._pending_jobs[job_id] + logger.debug("Job removed from a category", category="_pending_jobs") # type: ignore[call-arg] elif job_id in self._in_progress_jobs: job = self._in_progress_jobs[job_id] - job_in_progress = True + job_in_progress = job.status.state == BatchJobState.IN_PROGRESS elif job_id in self._done_jobs: # Job is already done (completed, failed, expired, or cancelled) logger.debug("Job is already in final state", job_id=job_id) # type: ignore[call-arg] - return True + return False else: logger.warning("Job not found", job_id=job_id) # type: ignore[call-arg] return False - # Check if job is already in a final state (race condition protection) + # Check if job is finalizing # We allow CANCELLING job be signalled again. 
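A hedged usage sketch of the creation path above, as an API layer might call it. `job_manager` is an already-constructed JobManager wired to an entity manager or scheduler, and the "24h" window string is an assumption about what BatchJobSpec.from_strings accepts.

job_id = await job_manager.create_job(
    session_id="sess-123",
    input_file_id="file-abc",
    api_endpoint="/v1/chat/completions",
    completion_window="24h",
    meta_data={"team": "ml"},
    timeout=30.0,
)
status = await job_manager.get_job_status(job_id)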
- if job.status and job.status.state in [ - BatchJobState.COMPLETED, - BatchJobState.FAILED, - BatchJobState.EXPIRED, - BatchJobState.CANCELED, - ]: + if job.status.state == BatchJobState.FINALIZING: logger.info( # type: ignore[call-arg] - "Job is already in final state", job_id=job_id, state=job.status.state + "Job is finalizing", job_id=job_id, state=job.status.state ) return False + # Start cancel + + job.status.state = ( + BatchJobState.CANCELLING + ) # update local state until being cancelled + job.status.cancelling_at = datetime.now(timezone.utc) + if not job_in_progress: + self._in_progress_jobs[job_id] = job + logger.debug( + "Job added to a category during cancelling", category="_pending_jobs" + ) # type: ignore[call-arg] + + job_cancelled = job.copy() + job_cancelled.status.add_condition( + Condition( + type=ConditionType.CANCELLED, + status=ConditionStatus.TRUE, + lastTransitionTime=datetime.now(timezone.utc), + ) + ) + if job_in_progress: + job_cancelled.status.state = BatchJobState.FINALIZING + else: + job_cancelled.status.state = BatchJobState.FINALIZED + job_cancelled.status.finalized_at = job.status.cancelling_at + if self._job_entity_manager: # Signal the entity manager to cancel the job - # The actual state update will be handled by job_deleted_handler or job_updated_handler when called back - self._job_entity_manager.cancel_job(job_id) + # The actual state update will be handled by job_updated_handler when called back + await self._job_entity_manager.cancel_job(job_cancelled) return True - # For local jobs, directly call job_deleted_handler - job_done = copy.deepcopy(job) - if job_done.status: - job_done.status.state = BatchJobState.CANCELLING - job_done.status.cancelling_at = datetime.now() + # For local jobs, transit directly if job_in_progress: - # [TODO][NEXT] zhangjyr - # Remove all related requests from scheduler and proxy - del self._in_progress_jobs[job_id] + # [TODO][NEXT] Review decision of disabling cancellation of local in progress job. + # Local in progress job can not or need not be cancelled. + return False - if job_done.status: - job_done.status.state = BatchJobState.CANCELED - job_done.status.cancelled_at = datetime.now() - self.job_updated_handler(job, job_done) + await self.job_updated_handler(job, job_cancelled) return True - def job_committed_handler(self, job: BatchJob): + async def delete_job(self, job_id: str) -> bool: + """ + Delete a job by job_id. Only finished job can be deleted. + + Args: + job_id: The ID of the job to cancel + + Returns: + bool: True if deletion was initiated successfully, False otherwise + """ + # Check if job exists in any state + if (job := self._done_jobs.get(job_id)) is None: + # Job is not already done (completed, failed, expired, or cancelled) + logger.error("Job is not in final state on deleting", job_id=job_id) # type: ignore[call-arg] + return False + + if self._job_entity_manager: + # Signal the entity manager to delete the job + # The actual state update will be handled by job_deleted_handler when called back + await self._job_entity_manager.delete_job(job) + return True + + # For local jobs, transit directly + return await self.job_deleted_handler(job) + + def _validate_state_transition( + self, old_job: Optional[BatchJob], new_job: BatchJob + ) -> bool: + """Validate if the state transition is allowed based on the defined rules. 
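A worked example against the VALID_STATE_TRANSITIONS table defined earlier (illustrative assertions only):

from aibrix.batch.job_entity import BatchJobState
from aibrix.batch.job_manager import JobManager

allowed = JobManager.VALID_STATE_TRANSITIONS[BatchJobState.IN_PROGRESS]
assert BatchJobState.FINALIZING in allowed      # normal completion path
assert BatchJobState.CANCELLING in allowed      # cancellation path
assert BatchJobState.VALIDATING not in allowed  # no moving backwards
assert JobManager.VALID_STATE_TRANSITIONS[BatchJobState.FINALIZED] == []  # terminal state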
+ + Args: + old_job: The previous job state (None for new jobs) + new_job: The new job state + + Returns: + True if transition is valid, False otherwise + """ + if old_job is None: + # New job, allow any initial state + return True + + old_state = old_job.status.state + new_state = new_job.status.state + + # Same state is always valid + if old_state == new_state: + return True + + # Check if transition is in valid transitions + valid_next_states = self.VALID_STATE_TRANSITIONS.get(old_state, []) + is_valid = new_state in valid_next_states + + if not is_valid: + logger.warning( + "Invalid state transition for job", + job_id=new_job.status.job_id, + old_state=old_state, + new_state=new_state, + valid_transitions=valid_next_states, + ) # type: ignore[call-arg] + + return is_valid + + async def job_committed_handler(self, job: BatchJob) -> bool: """ This is called by job entity manager when a job is committed. Enhanced to resolve pending job creation futures. @@ -381,54 +522,169 @@ def job_committed_handler(self, job: BatchJob): job_id = job.job_id if not job_id: logger.error("Job ID not found in comitted job") - return - - category = self._categorize_jobs(job) - category[job_id] = job + return False # Check if this job resolves a pending creation - if job.session_id and job.session_id in self._creating_jobs: - future = self._creating_jobs[job.session_id] - if not future.done(): + if job.session_id: + if ( + future := self._creating_jobs.get(job.session_id) + ) is not None and not future.done(): future.set_result(job_id) logger.debug( "Job creation future resolved", session_id=job.session_id, job_id=job_id, ) # type: ignore[call-arg] + else: + # Ignore + logger.warning( + "Job creation timeout or already created", + session_id=job.session_id, + job_id=job_id, + ) # type: ignore[call-arg] + return False + + category, name = self._categorize_jobs(job, first_seen=True) + category[job_id] = job + logger.debug("Job added to a category", category=name) # type: ignore[call-arg] - # Add to job schduler if available - if category is self._pending_jobs and self._job_scheduler: + if category is not self._pending_jobs: + return True + + # Add to job scheduler if available (traditional workflow) + if self._job_scheduler: created_at: datetime = job.status.created_at + logger.info("Add job to scheduler", job_id=job_id) # type: ignore[call-arg] self._job_scheduler.append_job( - job_id, created_at.timestamp() + job.spec.completion_window.expires_at() + job_id, created_at.timestamp() + job.spec.completion_window ) + # For metadata server (no scheduler): prepare job output files when job is committed + elif ( + job.status.output_file_id is None + or job.status.temp_output_file_id is None + or job.status.error_file_id is None + or job.status.temp_error_file_id is None + ) and self._job_entity_manager is not None: + # Try starting job immiediately with job validation. 
+ if not await self.start_execute_job(job_id): + return True + + # Initiate job preparing, see JobDriver for details + logger.info("Starting job preparation for new job", job_id=job_id) # type: ignore[call-arg] + try: + job_driver = JobDriver(self) + prepared_job = await job_driver.prepare_job(job) + + logger.info( + "Job preparation completed, files ready", + job_id=job_id, + output_file_id=prepared_job.status.output_file_id, + temp_output_file_id=prepared_job.status.temp_output_file_id, + error_file_id=prepared_job.status.error_file_id, + temp_error_file_id=prepared_job.status.temp_error_file_id, + ) # type: ignore[call-arg] - def job_updated_handler(self, old_job: BatchJob, new_job: BatchJob): + await self._job_entity_manager.update_job_ready(prepared_job) + + # Leave job_updated_handler to update job location in queues + except Exception as e: + logger.error("Job preparation failed", job_id=job_id, exc_info=True) # type: ignore[call-arg] + await self.mark_job_failed( + job_id, + BatchJobError( + code=BatchJobErrorCode.PREPARE_OUTPUT_ERROR, message=str(e) + ), + ) + # No need to stop job because only update_job_ready will start job. + + return True + + async def job_updated_handler(self, old_job: BatchJob, new_job: BatchJob) -> bool: """ This is called by job entity manager when a job status is updated. Handles state transitions when a job is cancelled or completed. + Validates state transitions according to defined rules. """ - job_id = old_job.job_id - if not job_id: - logger.error("Job ID not found in updated job") - return - if not old_job.status or not new_job.status: - logger.error("Job status not found in updated job", job_id=job_id) # type: ignore[call-arg] - return + try: + job_id = old_job.job_id + if not job_id: + logger.error("Job ID not found in updated job") + return False + + # Categorize jobs + old_category, old_name = self._categorize_jobs(old_job) + new_category, new_name = self._categorize_jobs(new_job) + # Load cache job, possibily with local metainfo. 
+ old_job_in_category = old_category.get(job_id) + if old_job_in_category is None: + logger.warning( + "Job is not in old category, ignore updating", + old_category=old_name, + new_category=new_name, + ) # type: ignore[call-arg] + return False + old_job = old_job_in_category - # Categorize jobs - old_category = self._categorize_jobs(old_job) - new_category = self._categorize_jobs(new_job) - if old_category is new_category: - return + # Validate state transition + if not self._validate_state_transition(old_job, new_job): + logger.warning( + "Invalid state transition for job - rejecting update", + job_id=job_id, + ) # type: ignore[call-arg] + return False + + # Mark post-state-transition flags + finalizing_needed = ( + new_job.status.state == BatchJobState.FINALIZING + and old_job.status.state != BatchJobState.FINALIZING + and self._job_scheduler is None + ) + logger.debug( + "job_updated_handler passed state transition", + old_state=old_job.status.state.value, + new_state=new_job.status.state.value, + finalizing_needed=finalizing_needed, + ) # type: ignore[call-arg] + + # No category change, try update status + if old_category == new_category: + # avoid override local metainfo by update status only + old_job.metadata = new_job.metadata # Update resource version + old_job.status = new_job.status # Update status + new_job = old_job + else: + # Move job from old category to new category + del old_category[job_id] + new_category[job_id] = new_job + logger.debug( + "Job moved to a new category", + old_category=old_name, + new_category=new_name, + ) # type: ignore[call-arg] - # Move job from old category to new category - if job_id in old_category: - del old_category[job_id] - new_category[job_id] = new_job + # For metadata server (no scheduler): finalize job when transitioning to FINALIZING + if finalizing_needed: + try: + logger.info("Starting job finalization", job_id=job_id) # type: ignore[call-arg] + job_driver = JobDriver(self) + await job_driver.finalize_job(new_job) + except Exception as fe: + logger.error( + "Job finalization failed", job_id=job_id, exc_info=True + ) # type: ignore[call-arg] + await self.mark_job_failed( + job_id, + BatchJobError( + code=BatchJobErrorCode.FINALIZING_ERROR, message=str(fe) + ), + ) - def job_deleted_handler(self, job: BatchJob): + return True + except Exception: + logger.error("exception in job_updated_handler", exc_info=True) # type: ignore[call-arg] + raise + + async def job_deleted_handler(self, job: BatchJob) -> bool: """ This is called by job entity manager when a job is deleted. """ @@ -437,22 +693,25 @@ def job_deleted_handler(self, job: BatchJob): # [TODO][NEXT] zhangjyr # Remove all related requests from scheduler and proxy, and call job_updated_handler, followed by job_deleted_handler() again. logger.warning("Job is in progress, cannot be deleted", job_id=job_id) # type: ignore[call-arg] - return + return True if job_id in self._pending_jobs: del self._pending_jobs[job_id] + logger.debug("Job removed from a category", category="_pending_jobs") # type: ignore[call-arg] + return True + if job_id in self._done_jobs: del self._done_jobs[job_id] + logger.debug("Job removed from a category", category="_done_jobs") # type: ignore[call-arg] + + return True - def get_job(self, job_id) -> Optional[BatchJob]: + async def get_job(self, job_id) -> Optional[BatchJob]: """ This retrieves a job's status to users. Job scheduler does not need to check job status. It can directly check the job pool for scheduling, such as pending_jobs. 
""" - if self._job_entity_manager: - return self._job_entity_manager.get_job(job_id) - if job_id in self._pending_jobs: return self._pending_jobs[job_id] elif job_id in self._in_progress_jobs: @@ -462,10 +721,10 @@ def get_job(self, job_id) -> Optional[BatchJob]: return None - def get_job_status(self, job_id: str) -> Optional[BatchJobState]: + async def get_job_status(self, job_id: str) -> Optional[BatchJobStatus]: """Get the current status of a job.""" - job = self.get_job(job_id) - return job.status.state if job else None + job = await self.get_job(job_id) + return job.status if job else None async def list_jobs(self) -> List[BatchJob]: """List all jobs.""" @@ -486,10 +745,18 @@ async def list_jobs(self) -> List[BatchJob]: return all_jobs - async def start_execute_job(self, job_id) -> bool: + async def validate_job(self, meta_data: JobMetaInfo): + """The interface is reserved for tests to hijack job validation""" + await meta_data.validate_job() + + async def start_execute_job(self, job_id: str) -> bool: """ This interface should be called by scheduler. User is not allowed to choose a job to be scheduled. + + DO NOT OVERRIDE THIS IN THE TEST, A JOB SHOULD EITHER: + * in state CREATED and in _pending_job, OR + * not in state CREATED and in _in_progress_jobs. """ if job_id not in self._pending_jobs: logger.warning("Job does not exist - maybe create it first", job_id=job_id) # type: ignore[call-arg] @@ -500,39 +767,35 @@ async def start_execute_job(self, job_id) -> bool: job = self._pending_jobs[job_id] del self._pending_jobs[job_id] + meta_data = JobMetaInfo(job) - meta_data.status.state = BatchJobState.VALIDATING + # In-place status update, will be reflected in the entity_manager if available. + if job.status.state == BatchJobState.CREATED: + # Only update state for first validation. + meta_data.status.state = BatchJobState.VALIDATING self._in_progress_jobs[job_id] = meta_data + logger.debug( + "Job moved to a new category", + old_category="_pending_jobs", + new_category="_in_progress_jobs", + ) # type: ignore[call-arg] try: - await meta_data.validate_job() - meta_data.status.in_progress_at = datetime.now(timezone.utc) - # [TODO][NEXT] Use separate file id - meta_data.status.output_file_id = meta_data.job_id - meta_data.status.state = BatchJobState.IN_PROGRESS + # [TODO][NOW] This should be moved to job_driver. + # We still need to validate job even if it is in progress. + await self.validate_job(meta_data) + # But we do not update state for in-progress job. 
+ if meta_data.status.state == BatchJobState.VALIDATING: + meta_data.status.in_progress_at = datetime.now(timezone.utc) + meta_data.status.state = BatchJobState.IN_PROGRESS except BatchJobError as e: logger.error("Job validation failed", job_id=job_id, error=str(e)) # type: ignore[call-arg] - meta_data.status.state = BatchJobState.FAILED - meta_data.status.failed_at = datetime.now(timezone.utc) - meta_data.status.errors = [e] - del self._in_progress_jobs[job_id] - self._done_jobs[job_id] = meta_data + await self.mark_job_failed(job_id, e) return False return True - def get_job_next_request(self, job_id) -> int: - request_id = -1 - if job_id not in self._in_progress_jobs: - logger.info("Job has not been scheduled yet", job_id=job_id) # type: ignore[call-arg] - return request_id - - job = self._in_progress_jobs[job_id] - assert isinstance(job, JobMetaInfo) - meta_data: JobMetaInfo = job - return meta_data.next_request_id() - - def get_job_endpoint(self, job_id) -> str: + async def get_job_endpoint(self, job_id: str) -> str: if job_id in self._pending_jobs: job = self._pending_jobs[job_id] elif job_id in self._in_progress_jobs: @@ -542,21 +805,35 @@ def get_job_endpoint(self, job_id) -> str: return "" return str(job.spec.endpoint) - def mark_job_progress(self, job_id, executed_requests) -> Optional[BatchJob]: + async def mark_job_progress(self, job_id: str, req_id: int) -> Tuple[BatchJob, int]: """ - This is used to sync job's progress, called by execution proxy. + This is used to sync job's progress, called by job driver. It is guaranteed that each request is executed at least once. + + Raises: + JobUnexpectedStateError: If job is not in progress. """ - if job_id not in self._in_progress_jobs: - logger.info("Job has not started yet", job_id=job_id) # type: ignore[call-arg] - return None + meta_data = await self._meta_from_in_progress_job(job_id) - job = self._in_progress_jobs[job_id] - assert isinstance(job, JobMetaInfo) - meta_data: JobMetaInfo = job + if req_id < 0 or req_id > meta_data.status.request_counts.total: + raise ValueError(f"invalide request_id: {req_id}") - request_len = meta_data.status.request_counts.total + meta_data.complete_one_request(req_id) + return meta_data, meta_data.next_request_id() + + async def mark_jobs_progresses( + self, job_id: str, executed_requests: List[int] + ) -> BatchJob: + """ + This is the batch operation to sync jobs' progresses, called by job driver. + It is guaranteed that each request is executed at least once. + + Raises: + JobUnexpectedStateError: If job is not in progress. + """ + meta_data = await self._meta_from_in_progress_job(job_id) + request_len = meta_data.status.request_counts.total for req_id in executed_requests: if req_id < 0 or req_id > request_len: logger.error( # type: ignore[call-arg] @@ -568,103 +845,199 @@ def mark_job_progress(self, job_id, executed_requests) -> Optional[BatchJob]: continue meta_data.complete_one_request(req_id) - self._in_progress_jobs[job_id] = meta_data return meta_data - def mark_job_done(self, job_id: str) -> Optional[BatchJob]: + async def get_job_next_request(self, job_id: str) -> Tuple[BatchJob, int]: """ - Mark job done. + Get next request id to execute, see JobMetaInfo::next_request_id for details + + Returns: + tuple: (job, next_request_id) or (job, -1) if job is done + + Raises: + JobUnexpectedStateError: If job is not in progress. 
""" - if job_id not in self._in_progress_jobs: - logger.error( - "Unexpected job queue", job_id=job_id, queue="_in_progress_jobs" - ) # type: ignore[call-arg] - return None + meta_data = await self._meta_from_in_progress_job(job_id) + return meta_data, meta_data.next_request_id() - job = self._in_progress_jobs[job_id] - if job.status.state != BatchJobState.FINALIZING: - logger.error( - "Unexpected job status", job_id=job_id, status=job.status.state.value - ) # type: ignore[call-arg] - return job + async def mark_job_progress_and_get_next_request( + self, job_id: str, req_id: int + ) -> Tuple[BatchJob, int]: + """ + This is used to sync job's progress, called by execution proxy. + It is guaranteed that each request is executed at least once. + + Returns: + tuple: (job, next_request_id) or (job, -1) if job is done + + Raises: + JobUnexpectedStateError: If job is not in progress. + """ + meta_data = await self._meta_from_in_progress_job(job_id) + + meta_data.complete_one_request(req_id) + return meta_data, meta_data.next_request_id() + + async def mark_job_total(self, job_id: str, total_requests: int) -> BatchJob: + """ + This is used to set job's total requests when stream reader sees the end of the request. + + Raises: + JobUnexpectedStateError: If job is not in progress. + """ + job, _ = await self.mark_job_progress(job_id, total_requests + 1) + return job + + async def mark_job_done(self, job_id: str) -> BatchJob: + """ + Mark job done. - del self._in_progress_jobs[job_id] - self._done_jobs[job_id] = job + Raises: + JobUnexpectedStateError: If job is not in progress and not finalizing. + """ + try: + meta_data = await self._meta_from_in_progress_job(job_id) + except JobUnexpectedStateError as juse: + logger.warning(str(juse), state=juse.state) # type: ignore[call-arg] + raise + + if meta_data.status.state != BatchJobState.FINALIZING: + logger.error("Job is not in finalizing state", state=meta_data.status.state) # type: ignore[call-arg] + raise JobUnexpectedStateError( + "Job is not in finalizing state", meta_data.status.state + ) + + job = meta_data.copy() job.status.completed_at = datetime.now(timezone.utc) - job.status.state = BatchJobState.COMPLETED - logger.info("Job is completed", job_id=job_id) # type: ignore[call-arg] + job.status.finalized_at = job.status.completed_at + # Do not override existing condition. Fill up locally for data integrity in case apply_job_changes does nothing + if job.status.condition is None: + job.status.add_condition( + Condition( + type=ConditionType.COMPLETED, + status=ConditionStatus.TRUE, + lastTransitionTime=job.status.completed_at, + ) + ) + job.status.state = BatchJobState.FINALIZED + if not await self.apply_job_changes(job, meta_data): + return meta_data + + logger.info("Job is finalized", job_id=job_id) # type: ignore[call-arg] return job - def mark_job_failed(self, job_id: str) -> Optional[BatchJob]: + async def mark_job_failed(self, job_id: str, ex: BatchJobError) -> BatchJob: """ Mark job failed. - """ - if job_id not in self._in_progress_jobs: - logger.error( - "Unexpected job queue", job_id=job_id, queue="_in_progress_jobs" - ) # type: ignore[call-arg] - return None - job = self._in_progress_jobs[job_id] - del self._in_progress_jobs[job_id] + Raises: + JobUnexpectedStateError: If job is not in progress. 
+ """ + meta_data = await self._meta_from_in_progress_job(job_id) + job = meta_data.copy() job.status.failed_at = datetime.now(timezone.utc) - job.status.state = BatchJobState.FAILED - self._done_jobs[job_id] = job + # Fill up locally for data integrity in case apply_job_changes does nothing + job.status.add_condition( + Condition( + type=ConditionType.FAILED, + status=ConditionStatus.TRUE, + lastTransitionTime=job.status.failed_at, + reason=ex.code, + message=ex.message, + ) + ) + job.status.errors = [ex] + if meta_data.status.state == BatchJobState.IN_PROGRESS: + job.status.finalizing_at = datetime.now(timezone.utc) + job.status.state = BatchJobState.FINALIZING + else: + job.status.finalized_at = job.status.failed_at + job.status.state = BatchJobState.FINALIZED + + if not await self.apply_job_changes(job, meta_data): + return meta_data logger.info("Job failed", job_id=job_id) # type: ignore[call-arg] return job - def expire_job(self, job_id): + async def apply_job_changes( + self, job: BatchJob, old_job: Optional[BatchJob] = None + ) -> bool: """ - This is called by scheduler. When a job arrives at its - specified due time, scheduler will mark this expired. - User can not expire a job, but can cancel a job. + Sync job status to persistent storage by calling update_job_status. + + This persists critical job status information including finalized state, + conditions, request counts, and timestamps to Kubernetes annotations + to ensure job state can be recovered after crashes. + + Args: + job_id: Job ID to sync to storage """ + try: + # Call update directly + if old_job is None: + old_job = await self.get_job(job.job_id) + assert old_job is not None + + # Use the entity manager to persist status + if self._job_entity_manager: + if ( + old_job.status.state == BatchJobState.FINALIZING + or old_job.status.errors is None + ): + await self._job_entity_manager.update_job_status(job) + else: + await self._job_entity_manager.cancel_job(job) - if job_id in self._pending_jobs: - job = self._pending_jobs[job_id] - del self._pending_jobs[job_id] - job.status.state = BatchJobState.EXPIRED - self._done_jobs[job_id] = job - elif job_id in self._in_progress_jobs: - # Now a job can not be expired once it gets scheduled, considering - # that expiring a partial executed job wastes resources. - # Later we may apply another policy to force a job to expire - # regardless of its current progress. 
- logger.warning("Job was scheduled and cannot expire", job_id=job_id) # type: ignore[call-arg] - return False + logger.debug( + "Job status synced to job entity manager", + job_id=job.job_id, + state=job.status.state, + condition=job.status.condition, + ) # type: ignore[call-arg] + return True - elif job_id in self._done_jobs: - logger.error("Job is done and this should not happen", job_id=job_id) # type: ignore[call-arg] + logger.debug("Job status synced to job entity manager") + await self.job_updated_handler(old_job, job) + return True + except Exception as e: + logger.error( + "Failed to apply job changes", + job_id=job.job_id, + error=str(e), + ) # type: ignore[call-arg] + # Don't re-raise - this is a background sync operation return False - return True + async def _meta_from_in_progress_job(self, job_id: str) -> JobMetaInfo: + if job_id not in self._in_progress_jobs: + job = await self.get_job(job_id) + raise JobUnexpectedStateError( + "Job has not been scheduled yet or has been scheduled", + job.status.state if job else None, + ) - def sync_job_to_storage(self, jobId): - """ - [TODO] Xin - This is used to serialize everything here to storage to make sure - that job manager can restart it over from storage once it crashes - or intentional quit. - """ - pass + job = self._in_progress_jobs[job_id] + assert isinstance(job, JobMetaInfo) + meta_data: JobMetaInfo = job + return meta_data - def _categorize_jobs(self, job: BatchJob) -> dict[str, BatchJob]: + def _categorize_jobs( + self, job: BatchJob, first_seen: bool = False + ) -> Tuple[dict[str, BatchJob], str]: """ This is used to categorize jobs into pending, in progress, and done. """ if not job.status: - return self._pending_jobs + return self._pending_jobs, "_pending_jobs" if job.status.state == BatchJobState.CREATED: - return self._pending_jobs - elif job.status.state in [ - BatchJobState.COMPLETED, - BatchJobState.FAILED, - BatchJobState.EXPIRED, - BatchJobState.CANCELED, - ]: - return self._done_jobs + return self._pending_jobs, "_pending_jobs" + elif job.status.finished: + return self._done_jobs, "_done_jobs" + elif first_seen and self._job_scheduler: + # We need to pending jobs to be scheduled to make progress + return self._pending_jobs, "_pending_jobs" else: - return self._in_progress_jobs + return self._in_progress_jobs, "_in_progress_jobs" diff --git a/python/aibrix/aibrix/batch/job_progress_manager.py b/python/aibrix/aibrix/batch/job_progress_manager.py new file mode 100644 index 000000000..9005108b3 --- /dev/null +++ b/python/aibrix/aibrix/batch/job_progress_manager.py @@ -0,0 +1,137 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Protocol, Tuple + +from .job_entity import BatchJob, BatchJobError, BatchJobStatus + + +class JobProgressManager(Protocol): + """Protocol for managing job progress and status tracking. 
+ + This protocol defines the interface that JobDriver uses to interact with + job management functionality, providing a clean separation of concerns + between job execution logic and job lifecycle management. + """ + + async def get_job(self, job_id: str) -> Optional[BatchJob]: + """Get job by ID. + + Args: + job_id: Job identifier + + Returns: + BatchJob if found, None otherwise + """ + ... + + async def get_job_status(self, job_id: str) -> Optional[BatchJobStatus]: + """Get the current status of a job.""" + ... + + async def start_execute_job(self, job_id) -> bool: + """ + This interface should be called by scheduler. + User is not allowed to choose a job to be scheduled. + """ + ... + + async def mark_job_total(self, job_id: str, total_requests: int) -> BatchJob: + """Mark the total number of requests for a job. + + Args: + job_id: Job identifier + total_requests: Total number of requests in the job + + Returns: + Updated BatchJob + + Raises: + JobUnexpectedStateError: If job is not in progress + """ + ... + + async def mark_job_done(self, job_id: str) -> BatchJob: + """Mark job as completed. + + Args: + job_id: Job identifier + + Returns: + Updated BatchJob + + Raises: + JobUnexpectedStateError: If job is not in finalizing state + """ + ... + + async def mark_job_failed(self, job_id: str, ex: BatchJobError) -> BatchJob: + """Mark job as failed. + + Args: + job_id: Job identifier + ex: BatchJobError that cause the failure + + Raises: + JobUnexpectedStateError: If job is not in progress. + """ + ... + + async def mark_jobs_progresses( + self, job_id: str, executed_requests: List[int] + ) -> BatchJob: + """Mark multiple requests as completed. + + Args: + job_id: Job identifier + executed_requests: List of request IDs that have been completed + + Returns: + Updated BatchJob + + Raises: + JobUnexpectedStateError: If job is not in progress + """ + ... + + async def get_job_next_request(self, job_id: str) -> Tuple[BatchJob, int]: + """Get the next request ID to execute. + + Args: + job_id: Job identifier + + Returns: + Tuple of (BatchJob, next_request_id) or (BatchJob, -1) if job is done + + Raises: + JobUnexpectedStateError: If job is not in progress + """ + ... + + async def mark_job_progress_and_get_next_request( + self, job_id: str, req_id: int + ) -> Tuple[BatchJob, int]: + """Mark a request as completed and get the next request ID. + + Args: + job_id: Job identifier + req_id: Request ID that was completed + + Returns: + Tuple of (BatchJob, next_request_id) or (BatchJob, -1) if job is done + + Raises: + JobUnexpectedStateError: If job is not in progress + """ + ... diff --git a/python/aibrix/aibrix/batch/request_proxy.py b/python/aibrix/aibrix/batch/request_proxy.py deleted file mode 100644 index 6923bf72f..000000000 --- a/python/aibrix/aibrix/batch/request_proxy.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2024 The Aibrix Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
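The JobProgressManager protocol above is what the JobDriver and scheduler now depend on in place of the RequestProxy coupling removed below; because typing.Protocol is structural, any object exposing these async methods satisfies it without inheriting from the protocol. A minimal, self-contained sketch of that idea (the reduced two-method protocol and the InMemoryProgress stub are illustrative assumptions, not code from this patch):

    import asyncio
    from typing import Optional, Protocol


    class ProgressLike(Protocol):
        """Reduced stand-in for JobProgressManager: two of its methods."""

        async def get_job(self, job_id: str) -> Optional[dict]: ...

        async def start_execute_job(self, job_id: str) -> bool: ...


    class InMemoryProgress:
        """Hypothetical stub; conforms structurally, no inheritance needed."""

        def __init__(self) -> None:
            self._jobs: dict[str, dict] = {"job-1": {"state": "created"}}

        async def get_job(self, job_id: str) -> Optional[dict]:
            return self._jobs.get(job_id)

        async def start_execute_job(self, job_id: str) -> bool:
            job = self._jobs.get(job_id)
            if job is None:
                return False
            job["state"] = "in_progress"
            return True


    async def drive(progress: ProgressLike, job_id: str) -> None:
        # Mirrors how JobDriver only relies on the protocol's method shapes.
        if await progress.start_execute_job(job_id):
            print(await progress.get_job(job_id))


    asyncio.run(drive(InMemoryProgress(), "job-1"))

The same structural typing is why JobScheduler can now take a job_progress_manager argument instead of the concrete JobManager, which keeps the scheduler testable with a small in-memory fake.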
- -import time -from typing import Optional - -import aibrix.batch.storage as storage -from aibrix.batch.job_entity import BatchJob, BatchJobState -from aibrix.batch.job_manager import JobManager -from aibrix.metadata.logger import init_logger - -logger = init_logger(__name__) - - -class RequestProxy: - def __init__(self, manager) -> None: - """ """ - self._job_manager: JobManager = manager - self._inference_client = InferenceEngineClient() - - async def execute_queries(self, job_id): - """ - This is the entrance to inference engine. - This fetches request input from storage and submit request - to inference engine. Lastly the result is stored back to storage. - """ - # Verify job status and get minimum unfinished request id - request_id = self._job_manager.get_job_next_request(job_id) - if request_id == -1: - logger.warning( - "Job has something wrong with metadata in job manager, nothing left to execute", - job_id=job_id, - ) - return - - job = self._job_manager.get_job(job_id) - - if request_id == 0: - logger.debug("Start processing job", job_id=job_id) - else: - logger.debug("Resuming job", job_id=job_id, request_id=request_id) - - # Step 1: Prepare job output files. - await storage.prepare_job_ouput_files(job) - - # Step 2: Execute requests, resumable. - line_no = request_id - async for request in storage.read_job_next_request(job, request_id): - logger.debug( - "Read job request, checking completion status", - job_id=job_id, - line=line_no, - next_unfinished=request_id, - ) - # Skip completed requests - if line_no < request_id: - continue - - logger.debug("Executing job request", job_id=job_id, request_id=request_id) - request_output = self._inference_client.inference_request( - job.spec.endpoint, request - ) - await storage.write_job_output_data(job, request_id, [request_output]) - # Request next id to avoid state becoming FINALIZING by make total > request_id - logger.debug("Job request executed", job_id=job_id, request_id=request_id) - job = self.sync_job_status(job_id, request_id) - - request_id = self._job_manager.get_job_next_request(job_id) - line_no += 1 - - job = self.sync_job_status( - job_id, request_id + 1 - ) # Now that total == request_id - logger.debug( - "Finalizing job", - job_id=job_id, - total=job.status.request_counts.total, - state=job.status.state.value, - ) - assert job is not None - assert job.status.state == BatchJobState.FINALIZING - - # Step 3: Aggregate outputs. - await storage.finalize_job_output_data(job) - - logger.debug("Completed job", job_id=job_id) - self.sync_job_status(job_id) - - def store_output(self, output_id, request_id, result): - """ - Write the request result back to storage. - """ - storage.put_job_results(output_id, request_id, [result]) - - def sync_job_status(self, job_id, reqeust_id=-1) -> Optional[BatchJob]: - """ - Update job's status back to job manager. - """ - if reqeust_id < 0: - return self._job_manager.mark_job_done(job_id) - else: - return self._job_manager.mark_job_progress(job_id, [reqeust_id]) - - -class InferenceEngineClient: - def __init__(self): - """ - Initiate client to inference engine, such as account - and its authentication. 
- """ - pass - - def inference_request(self, endpoint, prompt_list): - time.sleep(1) - return prompt_list diff --git a/python/aibrix/aibrix/batch/scheduler.py b/python/aibrix/aibrix/batch/scheduler.py index c693276dc..28d418b0d 100644 --- a/python/aibrix/aibrix/batch/scheduler.py +++ b/python/aibrix/aibrix/batch/scheduler.py @@ -20,8 +20,12 @@ from enum import Enum from typing import Optional -from aibrix.batch.constant import DEFAULT_JOB_POOL_SIZE, EXPIRE_INTERVAL -from aibrix.metadata.logger import init_logger +import aibrix.batch.constant as constant +from aibrix.batch.job_progress_manager import JobProgressManager +from aibrix.logger import init_logger + +from .job_driver import InferenceEngineClient, JobDriver +from .job_entity import BatchJobError, BatchJobErrorCode # JobManager will be passed as parameter to avoid circular import @@ -129,28 +133,26 @@ def grow_resource(self): class JobScheduler: def __init__( self, - job_manager, - pool_size, - cc_controller=BasicCongestionControl(DEFAULT_JOB_POOL_SIZE), + job_progress_manager: JobProgressManager, + pool_size: int, + cc_controller=BasicCongestionControl(constant.DEFAULT_JOB_POOL_SIZE), policy=SchedulePolicy.FIFO, - ): + ) -> None: """ self._jobs_queue are all the jobs. self._due_jobs_list stores all potential jobs that can be marked as expired jobs. self._inactive_jobs are jobs that are already invalid. """ - self._job_manager = job_manager - self.interval = EXPIRE_INTERVAL - self._jobs_queue = queue.Queue() - self._inactive_jobs = set() - self._due_jobs_list = [] + self._job_progress_manager = job_progress_manager + self.interval = constant.EXPIRE_INTERVAL + self._jobs_queue: queue.Queue[str] = queue.Queue() + self._inactive_jobs: set[str] = set() + self._due_jobs_list: list[tuple[str, float]] = [] self._CC_controller = cc_controller self._current_pool_size = self._CC_controller._job_pool_size # Start the loop process in an async way - self._job_cleanup_loop = asyncio.get_running_loop() - self._job_cleanup_task = asyncio.create_task(self.job_cleanup_loop()) self._policy = policy def configure_job_pool_size(self, new_pool_size): @@ -180,16 +182,16 @@ async def schedule_next_job(self) -> Optional[str]: if self._policy == SchedulePolicy.FIFO: if self._jobs_queue.empty(): logger.debug("Job scheduler is waiting jobs coming") - await asyncio.sleep(1) + await asyncio.sleep(self.interval) if not self._jobs_queue.empty(): job_id = self._jobs_queue.get() - logger.debug("Job scheduler is scheduling job", job_id=job_id) # type: ignore[call-arg] + logger.info("Job scheduler is scheduling job", job_id=job_id) # type: ignore[call-arg] # Every time when popping a job from queue, # we check if this job is in active state and we try starting the job. 
while job_id and ( job_id in self._inactive_jobs - or not await self._job_manager.start_execute_job(job_id) + or not await self._job_progress_manager.start_execute_job(job_id) ): if self._jobs_queue.empty(): job_id = None @@ -214,7 +216,8 @@ async def expire_jobs(self): ): idx += 1 - logger.info("Found expired jobs", count=idx) + if idx > 0: + logger.info("Found expired jobs", count=idx) for i in range(idx): # Update job's status to job manager job_id = self._due_jobs_list[i][0] @@ -227,11 +230,66 @@ async def expire_jobs(self): policy=str(self._policy), ) # type: ignore[call-arg] - async def job_cleanup_loop(self): + async def start(self, inference_client: Optional[InferenceEngineClient]): + self._serve_loop = asyncio.get_running_loop() + logger.info("in start") + self._jobs_running_task = self._serve_loop.create_task( + self.jobs_running_loop(inference_client) + ) + logger.info("running loop set up") + self._jobs_cleanup_task = self._serve_loop.create_task(self.jobs_cleanup_loop()) + logger.info("cleanup loop set up") + + async def jobs_running_loop( + self, inference_client: Optional[InferenceEngineClient] + ): + """ + This loop is going through all active jobs in scheduler. + For now, the executing unit is one request. Later if necessary, + we can support a batch size of request per execution. + """ + logger.info("Starting scheduling...") + job_driver = JobDriver(self._job_progress_manager, inference_client) + while True: + try: + one_job = await self.round_robin_get_job() + except Exception as e: + logger.error( + "Failed to schedule job", + error=str(e), + ) # type: ignore[call-arg] + + if one_job: + try: + await job_driver.execute_job(one_job) + except RuntimeError as re: + logger.error( + "Runtime err", + job_id=one_job, + error=str(re), + ) # type: ignore[call-arg] + raise + except Exception as e: + job = await self._job_progress_manager.mark_job_failed( + one_job, + BatchJobError( + code=BatchJobErrorCode.INFERENCE_FAILED, message=str(e) + ), + ) + logger.error( + "Failed to execute job", + job_id=one_job, + status=job.status.state.value, + error=str(e), + ) # type: ignore[call-arg] + raise + # yield loop + await asyncio.sleep(0) + + async def jobs_cleanup_loop(self): """ This is a long-running process to check if jobs have expired or not. 
""" - round_id = 0 while True: start_time = time.time() # Record start time await self.expire_jobs() # Run the process @@ -239,24 +297,22 @@ async def job_cleanup_loop(self): time_to_next_run = max( 0, self.interval - elapsed_time ) # Calculate remaining time - logger.debug("Job cleanup loop iteration", round_id=round_id) - round_id += 1 await asyncio.sleep(time_to_next_run) # Wait for the remaining time - async def close(self): + async def stop(self): """Properly shutdown the driver and cancel running tasks""" - loop = asyncio.get_running_loop() - if self._job_cleanup_loop and loop is not self._job_cleanup_loop: - try: - asyncio.run_coroutine_threadsafe( - self.close(), self._job_cleanup_loop - ).result(timeout=5) - except Exception: - pass - return - - if self._job_cleanup_task and not self._job_cleanup_task.done(): - self._job_cleanup_task.cancel() + assert getattr(self, "_serve_loop") == asyncio.get_running_loop() + # Cancel running loop + if not self._jobs_running_task.done(): + self._jobs_running_task.cancel() + # wait _jobs_running_task for capturing any exception + try: + await self._jobs_running_task + except asyncio.CancelledError: + pass + # Cancel cleanup loop + if not self._jobs_cleanup_task.done(): + self._jobs_cleanup_task.cancel() async def round_robin_get_job(self): # Step 1 @@ -268,7 +324,7 @@ async def round_robin_get_job(self): job_id = self._CC_controller._running_job_pool[i] # Do not schedule new job in since we need to adjust capacity # based on new pool size representing how much underlying resource. - if self._job_manager.get_job_status(job_id).is_finished(): + if (await self._job_progress_manager.get_job_status(job_id)).finished: self._CC_controller._running_job_pool[i] = None # Step 2, after the jobs' status are updated, diff --git a/python/aibrix/aibrix/batch/storage/__init__.py b/python/aibrix/aibrix/batch/storage/__init__.py index b90470ef0..b9acbe2dd 100644 --- a/python/aibrix/aibrix/batch/storage/__init__.py +++ b/python/aibrix/aibrix/batch/storage/__init__.py @@ -16,7 +16,9 @@ StorageType, download_output_data, finalize_job_output_data, + get_storage_type, initialize_storage, + is_request_done, prepare_job_ouput_files, read_job_input_info, read_job_next_request, @@ -28,9 +30,11 @@ __all__ = [ "StorageType", "initialize_storage", + "get_storage_type", "upload_input_data", "read_job_input_info", "read_job_next_request", + "is_request_done", "prepare_job_ouput_files", "write_job_output_data", "finalize_job_output_data", diff --git a/python/aibrix/aibrix/batch/storage/adapter.py b/python/aibrix/aibrix/batch/storage/adapter.py index 318b97ad8..23112872d 100644 --- a/python/aibrix/aibrix/batch/storage/adapter.py +++ b/python/aibrix/aibrix/batch/storage/adapter.py @@ -15,15 +15,18 @@ import asyncio import json import uuid -from typing import Any, AsyncIterator, Dict, List, Tuple +from typing import Any, AsyncIterator, Dict, List, Optional, Tuple from aibrix.batch.job_entity import BatchJob from aibrix.batch.storage.batch_metastore import ( delete_metadata, get_metadata, - set_metadata, + is_request_done, + list_metastore_keys, + lock_request, + unlock_request, ) -from aibrix.metadata.logger import init_logger +from aibrix.logger import init_logger from aibrix.storage.base import BaseStorage logger = init_logger(__name__) @@ -74,14 +77,16 @@ async def read_job_input_info(self, job: BatchJob) -> Tuple[int, bool]: async def read_job_next_input_data( self, job: BatchJob, start_index: int ) -> AsyncIterator[Dict[str, Any]]: - """Read job input data line by line. 
+ """Read job input data line by line with request locking. Args: job: BatchJob + start_index: Starting line index Returns: - AsyncIterator of lines + AsyncIterator of lines that were successfully locked for processing """ + idx = start_index # Use 'async for' to iterate and 'yield' each item. async for line in self.storage.readline_iter( job.spec.input_file_id, start_index @@ -89,9 +94,57 @@ async def read_job_next_input_data( line = line.strip() if len(line) == 0: continue - yield json.loads(line) - async def prepare_job_ouput_files(self, job: BatchJob) -> None: + # Try to lock this request before processing + lock_key = self._get_request_meta_output_key(job, idx) + try: + # Try to acquire lock with 1 hour expiration + locked = await lock_request(lock_key, expiration_seconds=3600) + except Exception as e: + # Lock operation failed (should not happen with return False requirement) + logger.warning( + "Error on locking request in the job, assuming locking not supported", + job_id=job.job_id, + line_no=idx, + error=e, + ) # type:ignore[call-arg] + locked = True + + if locked: + # Successfully locked, yield the request data + request_data = json.loads(line) + request_data["_request_index"] = idx # Add index for tracking + logger.debug( + "Locked and will processing request in the job", + job_id=job.job_id, + line_no=idx, + requset=request_data, + ) # type:ignore[call-arg] + yield request_data + else: + # Request already locked by another worker, skip it + logger.debug( + "Skipping already locked request in the job", + job_id=job.job_id, + line_no=idx, + ) # type:ignore[call-arg] + + idx += 1 + + async def is_request_done(self, job: BatchJob, request_index: int) -> bool: + """Check if a request is done. + + Args: + job: BatchJob + request_index: Index of the request being processed + + Returns: + True if the request is done, False otherwise + """ + lock_key = self._get_request_meta_output_key(job, request_index) + return await is_request_done(lock_key) + + async def prepare_job_ouput_files(self, job: BatchJob) -> BatchJob: """Get job output file id. Args: @@ -101,7 +154,7 @@ async def prepare_job_ouput_files(self, job: BatchJob) -> None: Job output file id """ if job.status.temp_output_file_id or job.status.temp_error_file_id: - return + return job job_uuid = uuid.UUID(job.job_id) job.status.output_file_id, job.status.error_file_id = ( @@ -120,16 +173,17 @@ async def prepare_job_ouput_files(self, job: BatchJob) -> None: job.status.temp_output_file_id, job.status.temp_error_file_id, ) = await asyncio.gather(*tasks) + return job async def write_job_output_data( - self, job: BatchJob, start_index: int, output_list: List[Dict[str, Any]] + self, job: BatchJob, request_index: int, output_data: Dict[str, Any] ) -> None: - """Write job results to storage. + """Write job result to storage and unlock the request. 
Args: - job_id: Job identifier - start_index: Starting index for the results - output_list: List of result dictionaries + job: BatchJob object + request_index: Index of the request being processed + output_data: Single result dictionary """ assert ( job.status.output_file_id @@ -137,60 +191,132 @@ async def write_job_output_data( and job.status.temp_output_file_id and job.status.temp_error_file_id ) - for i, result_data in enumerate(output_list): - idx = start_index + i - json_str = json.dumps(result_data) + "\n" - is_error = "error" in result_data - etag = await self.storage.upload_part( - job.status.error_file_id if is_error else job.status.output_file_id, - job.status.temp_error_file_id - if is_error - else job.status.temp_output_file_id, - idx, - json_str, - ) - # Store metadata - await set_metadata( - self._get_request_meta_output_key(job, idx), - self._get_request_meta_output_val(is_error, etag), - ) + + json_str = json.dumps(output_data) + "\n" + is_error = "error" in output_data and output_data["error"] is not None + etag = await self.storage.upload_part( + job.status.error_file_id if is_error else job.status.output_file_id, + job.status.temp_error_file_id + if is_error + else job.status.temp_output_file_id, + request_index, + json_str, + ) + + # Unlock the request by setting completion status + unlock_key = self._get_request_meta_output_key(job, request_index) + completion_status = self._get_request_meta_output_val(is_error, etag) + await unlock_request(unlock_key, completion_status) logger.debug( - f"Stored {len(output_list)} results for job {job.job_id} starting at index {start_index}" + f"Stored result for job {job.job_id} request {request_index}, status: {completion_status}" ) async def finalize_job_output_data(self, job: BatchJob) -> None: - assert ( - job.status.output_file_id - and job.status.error_file_id - and job.status.temp_output_file_id - and job.status.temp_error_file_id - ) + if ( + job.status.output_file_id is None + or job.status.error_file_id is None + or job.status.temp_output_file_id is None + or job.status.temp_error_file_id is None + ): + # Do nothing + return + + # 1. List all keys from metastore with the job prefix + prefix = self._get_request_meta_output_key(job, None) + all_keys = await list_metastore_keys(prefix) + + logger.debug( + "Metastore keys found during job finalizing", + job_id=job.job_id, + prefix=prefix, + keys=all_keys, + ) # type: ignore[call-arg] + + # 2. Extract indices from keys and determine maximum index for total count + indices = [] + for key in all_keys: + # Extract index from key format: batch:{job_id}:done/{idx} + try: + idx_str = key[len(prefix) :] # Get the index part + idx = int(idx_str) + indices.append(idx) + except ValueError: + logger.warning( + "Invalid key format found in metastore", + key=key, + job_id=job.job_id, + ) # type: ignore[call-arg] + continue + + # Sort indices to ensure proper ordering + indices.sort() + + # 3. 
Calculate actual counts based on metastore keys + launched = len(indices) + total = indices[-1] + 1 if launched > 0 else 0 + + logger.info( + "Finalizing job output data using metastore keys", + job_id=job.job_id, + launched=launched, + total=total, + ) # type: ignore[call-arg] + # Fetch metadata for all found keys + keys = [self._get_request_meta_output_key(job, idx) for idx in indices] + etag_results = await asyncio.gather(*[get_metadata(key) for key in keys]) + + # Process results and categorize into outputs and errors output: List[Dict[str, str | int]] = [] error: List[Dict[str, str | int]] = [] - keys = [] - for i in range(job.status.request_counts.launched): - keys.append(self._get_request_meta_output_key(job, i)) + completed = 0 + failed = 0 + valid_keys = [] - etag_results = await asyncio.gather(*[get_metadata(key) for key in keys]) - exists = 0 - for i, etag_result in enumerate(etag_results): + for idx, key, etag_result in zip(indices, keys, etag_results): meta_val, exist = etag_result if not exist: continue - # Compact keys - keys[exists] = keys[i] - exists += 1 - etag, is_error = self._parse_request_meta_output_val(meta_val) - val: Dict[str, str | int] = {"etag": etag, "part_number": i} + if etag == "": + continue + + valid_keys.append(key) + val: Dict[str, str | int] = {"etag": etag, "part_number": idx} + if is_error: error.append(val) + failed += 1 else: output.append(val) - keys = keys[:exists] + completed += 1 + + # 4. Update job object with calculated request counts if they differ + if ( + job.status.request_counts.total != total + or job.status.request_counts.launched != launched + or job.status.request_counts.completed != completed + or job.status.request_counts.failed != failed + ): + logger.info( + "Updating job request counts based on metastore data", + job_id=job.job_id, + old_total=job.status.request_counts.total, + new_total=total, + old_launched=job.status.request_counts.launched, + new_launched=launched, + old_completed=job.status.request_counts.completed, + new_completed=completed, + old_failed=job.status.request_counts.failed, + new_failed=failed, + ) # type: ignore[call-arg] + + job.status.request_counts.total = total + job.status.request_counts.launched = launched + job.status.request_counts.completed = completed + job.status.request_counts.failed = failed # Aggregate results await asyncio.gather( @@ -206,8 +332,9 @@ async def finalize_job_output_data(self, job: BatchJob) -> None: ), ) - # Delete metadata - await asyncio.gather(*[delete_metadata(key) for key in keys]) + # Delete metadata for valid keys only + if valid_keys: + await asyncio.gather(*[delete_metadata(key) for key in valid_keys]) async def read_job_output_data(self, file_id: str) -> List[Dict[str, Any]]: """Read job results output from storage. 
@@ -238,12 +365,24 @@ async def delete_job_data(self, file_id: str) -> None: except Exception as e: logger.error(f"Failed to delete data for job {file_id}: {e}") - def _get_request_meta_output_key(self, job: BatchJob, idx: int) -> str: - return f"batch:{job.job_id}:output:{idx}" + def _get_request_meta_output_key(self, job: BatchJob, idx: Optional[int]) -> str: + prefix = f"batch:{job.job_id}:done/" + if idx is None: + return prefix + return f"{prefix}{idx}" def _get_request_meta_output_val(self, is_error: bool, etag: str) -> str: return f"{'error' if is_error else 'output'}:{etag}" def _parse_request_meta_output_val(self, meta_val: str) -> Tuple[str, bool]: - is_error, etag = meta_val.split(":", 1) - return etag, is_error == "error" + """valid output can be: + 1. output:[etag] + 2. error:[etag] + 3. processing + """ + status = meta_val.split(":", 1) + if len(status) == 2: + is_error, etag = status + return etag, is_error == "error" + else: + return "", False diff --git a/python/aibrix/aibrix/batch/storage/batch_metastore.py b/python/aibrix/aibrix/batch/storage/batch_metastore.py index 25872e24d..df59e8715 100644 --- a/python/aibrix/aibrix/batch/storage/batch_metastore.py +++ b/python/aibrix/aibrix/batch/storage/batch_metastore.py @@ -14,13 +14,16 @@ from typing import Optional, Tuple -from aibrix.metadata.logger import init_logger +from aibrix import envs +from aibrix.logger import init_logger from aibrix.storage import BaseStorage, StorageType, create_storage +from aibrix.storage.base import PutObjectOptionsBuilder logger = init_logger(__name__) p_metastore: Optional[BaseStorage] = None NUM_REQUESTS_PER_READ = 1024 +STATUS_RUQUEST_LOCKING = "processing" def initialize_batch_metastore(storage_type=StorageType.AUTO, params={}): @@ -34,24 +37,63 @@ def initialize_batch_metastore(storage_type=StorageType.AUTO, params={}): """ global p_metastore + if storage_type == StorageType.AUTO and envs.STORAGE_REDIS_HOST: + storage_type = StorageType.REDIS + # Create new storage instance and wrap with adapter try: + logger.info( + "Initializing batch metastore", storage_type=storage_type, params=params + ) # type: ignore[call-arg] p_metastore = create_storage(storage_type, base_path=".metastore", **params) - logger.info(f"Initialized batch metastore with type: {storage_type}") except Exception as e: - logger.error(f"Failed to initialize storage: {e}") + logger.error("Failed to initialize metastore", error=str(e)) # type: ignore[call-arg] raise -async def set_metadata(key: str, value: str) -> None: - """Set metadata to metastore. +def get_metastore_type() -> StorageType: + """Get the type of metastore. + + Returns: + Type of type of storage that backs metastore + """ + assert p_metastore is not None + return p_metastore.get_type() + + +async def set_metadata( + key: str, + value: str, + expiration_seconds: Optional[int] = None, + if_not_exists: bool = False, +) -> bool: + """Set metadata to metastore with advanced options. 
Args: key: Metadata key value: Metadata value + expiration_seconds: TTL in seconds for the key (Redis only) + if_not_exists: Only set if key doesn't exist (Redis NX operation) + + Returns: + True if metadata was set, False if conditional operation failed + + Raises: + ValueError: If unsupported options are used with non-Redis storage """ assert p_metastore is not None - await p_metastore.put_object(key, value) + + # Build options if needed + options = None + if expiration_seconds is not None or if_not_exists: + builder = PutObjectOptionsBuilder() + if expiration_seconds is not None: + builder.ttl_seconds(expiration_seconds) + if if_not_exists: + builder.if_not_exists() + options = builder.build() + + return await p_metastore.put_object(key, value, options=options) async def get_metadata(key: str) -> Tuple[str, bool]: @@ -80,3 +122,78 @@ async def delete_metadata(key: str) -> None: """ assert p_metastore is not None await p_metastore.delete_object(key) + + +async def lock_request(key: str, expiration_seconds: int = 3600) -> bool: + """Lock a request for processing. + + Args: + key: Request key to lock + expiration_seconds: TTL for the lock (default 1 hour) + + Returns: + True if lock was acquired, False if already locked + + Raises: + ValueError: If storage doesn't support locking operations + """ + return await set_metadata( + key, + STATUS_RUQUEST_LOCKING, + expiration_seconds=expiration_seconds, + if_not_exists=True, + ) + + +async def is_request_done(key: str) -> bool: + """Check if a request is done. + + Args: + key: Request key to check + + Returns: + True if the request is done, False otherwise + """ + status, got = await get_metadata(key) + return got and status != STATUS_RUQUEST_LOCKING + + +async def unlock_request(key: str, status: str) -> bool: + """Unlock a request by setting completion status. + + Args: + key: Request key to unlock + status: Completion status (e.g., "output:etag" or "error:etag") + + Returns: + True if status was set successfully + """ + return await set_metadata(key, status) + + +async def list_metastore_keys(prefix: str) -> list[str]: + """List all keys from metastore matching the given prefix. + + Args: + prefix: Key prefix to filter + + Returns: + List of keys matching the prefix + """ + assert p_metastore is not None + + keys = [] + continuation_token = None + + while True: + batch_keys, continuation_token = await p_metastore.list_objects( + prefix=prefix, + continuation_token=continuation_token, + limit=1000, # Process in batches of 1000 + ) + keys.extend(batch_keys) + + if continuation_token is None: + break + + return keys diff --git a/python/aibrix/aibrix/batch/storage/batch_storage.py b/python/aibrix/aibrix/batch/storage/batch_storage.py index 36714a5a9..d38d8e3cd 100644 --- a/python/aibrix/aibrix/batch/storage/batch_storage.py +++ b/python/aibrix/aibrix/batch/storage/batch_storage.py @@ -13,16 +13,16 @@ # limitations under the License. 
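Together, lock_request, is_request_done, and unlock_request turn per-request metadata into a lightweight claim mechanism: lock_request is a conditional set (set-if-not-exists with a TTL, i.e. NX/EX when the metastore is Redis-backed), and unlock_request later overwrites the "processing" placeholder with the final "output:<etag>" or "error:<etag>" value. A rough sketch of the claim-process-release flow, assuming the aibrix package is importable, the metastore has already been initialized elsewhere, and process is a caller-supplied async callable; the key layout mirrors the adapter's batch:{job_id}:done/{idx} scheme:

    from aibrix.batch.storage.batch_metastore import (
        is_request_done,
        lock_request,
        unlock_request,
    )


    async def claim_and_run(job_id: str, idx: int, process) -> None:
        key = f"batch:{job_id}:done/{idx}"   # same key shape as the adapter uses
        if await is_request_done(key):       # already completed earlier, skip it
            return
        try:
            locked = await lock_request(key, expiration_seconds=3600)
        except Exception:
            locked = True                    # backend without NX/TTL support: best effort
        if not locked:
            return                           # another worker currently owns this request
        etag = await process(idx)            # run the actual inference request (placeholder)
        await unlock_request(key, f"output:{etag}")  # record completion for finalization

Because the lock carries a TTL while the completion status written by unlock_request does not, a crashed worker's claim eventually expires and can be retried, whereas finished requests stay recorded until finalize_job_output_data deletes their keys.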
import uuid -from typing import Any, AsyncIterator, Dict, List, Tuple +from typing import Any, AsyncIterator, Dict, List, Optional, Tuple from aibrix.batch.job_entity import BatchJob from aibrix.batch.storage.adapter import BatchStorageAdapter -from aibrix.metadata.logger import init_logger +from aibrix.logger import init_logger from aibrix.storage import StorageType, create_storage logger = init_logger(__name__) -p_storage = None +p_storage: Optional[BatchStorageAdapter] = None def initialize_storage(storage_type=StorageType.AUTO, params={}): @@ -38,14 +38,26 @@ def initialize_storage(storage_type=StorageType.AUTO, params={}): # Create new storage instance and wrap with adapter try: + logger.info( + "Initializing batch storage", storage_type=storage_type, params=params + ) # type: ignore[call-arg] storage = create_storage(storage_type, **params) p_storage = BatchStorageAdapter(storage) - logger.info(f"Initialized batch storage with type: {storage_type}") except Exception as e: - logger.error(f"Failed to initialize storage: {e}") + logger.error("Failed to initialize storage", error=str(e)) # type: ignore[call-arg] raise +def get_storage_type() -> StorageType: + """Get the type of storage. + + Returns: + Type of storage. + """ + assert p_storage is not None + return p_storage.storage.get_type() + + async def upload_input_data(inputDataFileName: str) -> str: """Upload job input data file to storage. @@ -88,24 +100,38 @@ async def read_job_next_request( yield data -async def prepare_job_ouput_files(job: BatchJob) -> None: +async def is_request_done(job: BatchJob, request_index: int) -> bool: + """Check if a request is done. + + Args: + job: BatchJob + request_index: Index of the request being processed + + Returns: + True if the request is done, False otherwise + """ + assert p_storage is not None + return await p_storage.is_request_done(job, request_index) + + +async def prepare_job_ouput_files(job: BatchJob) -> BatchJob: """Prepare job output files, including output and error file ids""" assert p_storage is not None - await p_storage.prepare_job_ouput_files(job) + return await p_storage.prepare_job_ouput_files(job) async def write_job_output_data( - job: BatchJob, start_index: int, output_list: List[Dict[str, Any]] + job: BatchJob, request_index: int, output_data: Dict[str, Any] ) -> None: - """Write job results to storage. + """Write job result to storage and unlock the request. Args: - job_id: Job identifier - start_index: Starting index for the results - output_list: List of result dictionaries + job: BatchJob object + request_index: Index of the request being processed + output_data: Single result dictionary """ assert p_storage is not None - await p_storage.write_job_output_data(job, start_index, output_list) + await p_storage.write_job_output_data(job, request_index, output_data) async def finalize_job_output_data(job: BatchJob) -> None: diff --git a/python/aibrix/aibrix/batch/worker.py b/python/aibrix/aibrix/batch/worker.py new file mode 100644 index 000000000..7cbf5ea37 --- /dev/null +++ b/python/aibrix/aibrix/batch/worker.py @@ -0,0 +1,610 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import asyncio +import os +import signal +import subprocess +import sys +import time +from datetime import datetime, timezone +from typing import Optional +from urllib.parse import urlparse + +import httpx + +import aibrix.batch.constant as constant +from aibrix.batch.driver import BatchDriver +from aibrix.batch.job_entity import ( + BatchJob, + BatchJobSpec, + BatchJobState, + BatchJobStatus, + BatchJobTransformer, +) +from aibrix.logger import init_logger + +logger = init_logger(__name__) + + +class LLMHealthChecker: + """Health checker for vLLM service readiness.""" + + def __init__( + self, + health_url: str, + check_interval: int = 1, + timeout: int = 300, + ): + self.health_url = health_url + self.check_interval = check_interval + self.timeout = timeout + + async def wait_for_ready(self) -> bool: + """Wait for vLLM service to become ready.""" + logger.info( + "Waiting for vLLM service to become ready", health_url=self.health_url + ) # type: ignore[call-arg] + + start_time = time.time() + async with httpx.AsyncClient() as client: + while time.time() - start_time < self.timeout: + try: + response = await client.get(self.health_url, timeout=5.0) + if response.status_code == 200: + logger.info("vLLM service is ready") + return True + except (httpx.RequestError, httpx.TimeoutException) as e: + logger.debug("vLLM not ready yet", error=str(e)) # type: ignore[call-arg] + + await asyncio.sleep(self.check_interval) + + logger.error( + "LLM engine did not become ready within timeout", + timeout_seconds=self.timeout, + ) # type: ignore[call-arg] + return False + + +class BatchWorker: + """Batch worker that processes jobs using the sidecar pattern.""" + + def __init__(self) -> None: + self.health_checker: Optional[LLMHealthChecker] = None + self.driver: Optional[BatchDriver] = None + self.llm_engine_base_url: Optional[str] = None + + def load_job_from_env(self) -> BatchJob: + """Generate BatchJob from environment variables set by pod annotations.""" + logger.info("Loading job specification from environment variables...") + + # Get basic job information + job_name = os.getenv("JOB_NAME") + job_ns = os.getenv("JOB_NAMESPACE") + job_id = os.getenv("JOB_UID") + + if not job_id: + raise ValueError("JOB_UID environment variable is required") + if not job_name: + raise ValueError("JOB_NAME environment variable is required") + if not job_ns: + raise ValueError("JOB_NAMESPACE environment variable is required") + + # Get batch job metadata from environment variables + input_file_id = os.getenv("BATCH_INPUT_FILE_ID") + endpoint = os.getenv("BATCH_ENDPOINT") + opts: dict[str, str] = {} + if ( + failed_after_after_n_requests := os.getenv( + "BATCH_OPTS_FAIL_AFTER_N_REQUESTS" + ) + ) is not None: + opts[constant.BATCH_OPTS_FAIL_AFTER_N_REQUESTS] = ( + failed_after_after_n_requests + ) + # Expiration window is set on Job spec: activeDeadlineSeconds + + # Get file IDs + output_file_id = os.getenv("BATCH_OUTPUT_FILE_ID") + temp_output_file_id = os.getenv("BATCH_TEMP_OUTPUT_FILE_ID") + error_file_id = os.getenv("BATCH_ERROR_FILE_ID") + temp_error_file_id = 
os.getenv("BATCH_TEMP_ERROR_FILE_ID") + + if not input_file_id: + raise ValueError("BATCH_INPUT_FILE_ID environment variable is required") + if not endpoint: + raise ValueError("BATCH_ENDPOINT environment variable is required") + + logger.info( + "Confirmed Job", + name=job_name, + namespace=job_ns, + job_id=job_id, + input_file_id=input_file_id, + endpoint=endpoint, + opts=opts, + ) # type: ignore[call-arg] + + try: + # Initialize health checker + health_url = os.getenv("LLM_READY_ENDPOINT", "http://localhost:8000/health") + self.health_checker = LLMHealthChecker(health_url) + logger.info("Set health checker", health_url=health_url) # type: ignore[call-arg] + + # Try to construct base URL from health URL + parsed = urlparse(health_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + self.llm_engine_base_url = base_url + logger.info("Set LLM engine base URL", base_url=base_url) # type: ignore[call-arg] + + # Create BatchJobSpec + spec = BatchJobSpec.from_strings( + input_file_id=input_file_id, + endpoint=endpoint, + metadata=None, + opts=opts, + ) + + # Determine state based on file IDs (as in current transformer logic) + validated = ( + output_file_id is not None + and temp_output_file_id is not None + and error_file_id is not None + and temp_error_file_id is not None + ) + state = BatchJobState.IN_PROGRESS if validated else BatchJobState.CREATED + + # Create BatchJobStatus + status = BatchJobStatus( + jobID=job_id, + state=state, + outputFileID=output_file_id, + tempOutputFileID=temp_output_file_id, + errorFileID=error_file_id, + tempErrorFileID=temp_error_file_id, + createdAt=datetime.now(timezone.utc), + ) + + # Create BatchJob + batch_job = BatchJob.new_from_spec(job_name, job_ns, spec) + batch_job.status = status + + logger.info( + "Successfully generated BatchJob from environment variables", + job_name=batch_job.metadata.name, + job_namespace=batch_job.metadata.namespace, + job_id=batch_job.job_id, + state=state.value, + validated=validated, + ) # type: ignore[call-arg] + + return batch_job + except Exception as e: + raise RuntimeError( + f"Error creating BatchJob from environment variables: {e}" + ) from e + + def load_job_from_k8s(self, llm_engine_container_name: str) -> BatchJob: + """Load and transform the parent Kubernetes Job to BatchJob.""" + logger.info("Loading job specification from Kubernetes API...") + + # Load k8s batch api client + from kubernetes import client, config + + try: + config.load_incluster_config() + except config.ConfigException: + logger.warning("Failed to load in-cluster config, trying local config...") + config.load_kube_config() + batch_api_client = client.BatchV1Api() + + # Get the Job name and namespace from environment variables + # Get basic job information + job_name = os.getenv("JOB_NAME") + namespace = os.getenv("JOB_NAMESPACE") + if not job_name: + raise ValueError("JOB_NAME environment variable is required") + if not namespace: + raise ValueError("JOB_NAMESPACE environment variable is required") + + logger.info("Confirmed Job", name=job_name, namespace=namespace) # type: ignore[call-arg] + + try: + # Fetch the Job object + logger.info("Fetching Job spec from the Kubernetes API...") + k8s_job = batch_api_client.read_namespaced_job( + name=job_name, namespace=namespace + ) + + # Extract LLM engine information and initialize health checker + health_url = os.getenv("LLM_READY_ENDPOINT", "http://localhost:8000/health") + self.health_checker = LLMHealthChecker(health_url) + + # Transform k8s Job to BatchJob using BatchJobTransformer + 
batch_job = BatchJobTransformer.from_k8s_job(k8s_job) + logger.info( + "Successfully transformed k8s Job to BatchJob", job_id=batch_job.job_id + ) # type: ignore[call-arg] + + return batch_job + + except client.ApiException as e: + raise RuntimeError(f"Error fetching Job from Kubernetes API: {e}") from e + except Exception as e: + raise RuntimeError(f"Error transforming Job to BatchJob: {e}") from e + + def extract_llm_engine_info(self, k8s_job, llm_engine_container_name: str): + """Extract LLM engine health check information from k8s job object.""" + logger.info( + "Extracting LLM engine info for container", + container_name=llm_engine_container_name, + ) # type: ignore[call-arg] + + # Extract health check endpoint from livenessProbe or readinessProbe + ready_url = "http://localhost:8000/health" # default + check_interval = 5 # default + try: + # Navigate to pod template containers + pod_template = k8s_job.spec.template + containers = pod_template.spec.containers + + # Find the LLM engine container + llm_container = None + for container in containers: + if container.name == llm_engine_container_name: + llm_container = container + break + + if not llm_container: + logger.warning( + "LLM engine container not found, using defaults", + container_name=llm_engine_container_name, + ) # type: ignore[call-arg] + return ready_url, check_interval + + probe = None + probe_type = None + + # Prefer livenessProbe, fallback to readinessProbe + if llm_container.readiness_probe: + probe = llm_container.readiness_probe + probe_type = "readiness" + elif llm_container.readiness_probe: + probe = llm_container.readiness_probe + probe_type = "readiness" + + if probe and probe.http_get: + # Extract interval from periodSeconds + if probe.period_seconds: + check_interval = probe.period_seconds + + # Extract health endpoint from HTTP get action + http_get = probe.http_get + scheme = http_get.scheme or "http" + host = http_get.host or "localhost" + port = http_get.port or 8000 + path = http_get.path or "/health" + + ready_url = f"{scheme.lower()}://{host}:{port}{path}" + logger.info( + "Found health probe, extracting health check info", + probe_type=probe_type, + endpoint=ready_url, + interval=check_interval, + ) # type: ignore[call-arg] + else: + logger.info( + "No liveness or readiness probe found, using default health check" + ) + + # Try to construct base URL from health URL + from urllib.parse import urlparse + + parsed = urlparse(ready_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + self.llm_engine_base_url = base_url + logger.info("Set LLM engine base URL", base_url=base_url) # type: ignore[call-arg] + + return ready_url, check_interval + + except Exception as e: + logger.warning( + "Error extracting LLM engine info, using defaults", + error=str(e), + endpoint=ready_url, + interval=check_interval, + ) # type: ignore[call-arg] + return ready_url, check_interval + + async def execute_batch_job(self, batch_job: BatchJob) -> str: + """Execute the provided batch job.""" + assert ( + self.driver is not None + ), "Driver must be initialized before executing jobs" + + job_id = batch_job.job_id + if job_id is None: + raise RuntimeError("BatchJob job_id is None") + + logger.info( + "Executing batch job", + job_id=job_id, + input_file_id=batch_job.spec.input_file_id, + endpoint=batch_job.spec.endpoint, + ) # type: ignore[call-arg] + + # Commit job to job manager + await self.driver.job_manager.job_committed_handler(batch_job) + logger.info("Job committed to manager", job_id=job_id) # type: ignore[call-arg] + + 
# Wait until job reaches FINALIZING state + await self.wait_for_finalizing(job_id) + + return job_id + + async def wait_for_in_progress( + self, job: BatchJob, max_wait: int = 300 + ) -> BatchJob: + """Wait for job to reach IN_PROGRESS state by watching Kubernetes job updates. + + Raises: + ReadTimeoutError if job does not reach IN_PROGRESS state within max_wait seconds. + """ + if job.status.state != BatchJobState.CREATED: + logger.info( + "Job already in non-CREATED state", + job_id=job.job_id, + current_state=job.status.state.value, + output_file_id=job.status.output_file_id, + temp_output_file_id=job.status.temp_output_file_id, + error_file_id=job.status.error_file_id, + temp_error_file_id=job.status.temp_error_file_id, + ) # type: ignore[call-arg] + return job + + logger.info( + "Waiting for job to reach IN_PROGRESS state", + job_id=job.job_id, + job_name=job.metadata.name, + namespace=job.metadata.namespace, + max_wait=max_wait, + ) # type: ignore[call-arg] + + # [TODO][NOW] Use metestore to watch in_progress status. + + # Unlikely, should raise ReadTimeoutError + return job + + async def wait_for_finalizing(self, job_id: str, max_wait: int = 600): + """Wait for job to reach FINALIZING state.""" + assert ( + self.driver is not None + ), "Driver must be initialized before waiting for jobs" + + start_time = time.time() + + while time.time() - start_time < max_wait: + job = await self.driver.job_manager.get_job(job_id) + if job and job.status.finished: + logger.info( + "Job reached final state", + job_id=job_id, + state=job.status.state.value, + ) # type: ignore[call-arg] + return + + await asyncio.sleep(1) + + raise TimeoutError( + f"Job {job_id} did not reach final state within {max_wait} seconds" + ) + + async def run(self, args: argparse.Namespace) -> int: + """Main worker execution flow.""" + try: + # Step 1: Load job specification + batchJob: Optional[BatchJob] = None + try: + if args.load_job_from_api: + batch_job = self.load_job_from_k8s(args.llm_engine_container_name) + except Exception: + pass + + # Use env as a fallback + if batchJob is None: + batch_job = self.load_job_from_env() + + logger.info( + "Loaded job specification", + input_file_id=batch_job.spec.input_file_id, + endpoint=batch_job.spec.endpoint, + ) # type: ignore[call-arg] + + # Wait for job to become IN_PROGRESS (metadata server will prepare the job) + batch_job = await self.wait_for_in_progress(batch_job) + if batch_job.status.state != BatchJobState.IN_PROGRESS: + logger.error("Job failed to reach IN_PROGRESS state") + return 1 + + # Step 2: Wait for vLLM service to become ready + assert ( + self.health_checker is not None + ), "Health checker should be initialized" + if not await self.health_checker.wait_for_ready(): + logger.error("vLLM service failed to become ready") + return 1 + + # Step 3: Initialize BatchDriver + self.driver = BatchDriver(llm_engine_endpoint=self.llm_engine_base_url) + await self.driver.start() + logger.info("BatchDriver initialized successfully") + + # Step 4: Execute batch job + job_id = await self.execute_batch_job(batch_job) + logger.info("Batch worker completed successfully", job_id=job_id) # type: ignore[call-arg] + + await self.driver.stop() + return 0 + + except Exception as e: + file, lineno, func_name = get_error_details(e) + logger.error( + "Batch worker failed", + error=str(e), + file=file, + lineno=lineno, + function=func_name, + ) # type: ignore[call-arg] + return 1 + + finally: + # Cleanup driver if initialized + if self.driver: + logger.info("Cleaning up BatchDriver...") 
+ await self.driver.stop() + + +def get_error_details(ex: Exception) -> tuple[str, int | None, str]: + import traceback + + """ + Must be called from within an 'except' block. + + Returns a tuple containing the filename, line number, and function name + where the exception occurred. + """ + # If the exception has a cause, that's the original error. + # Otherwise, use the exception itself. + target_exc = ex.__cause__ if ex.__cause__ else ex + + # Create a structured traceback from the target exception + tb_exc = traceback.TracebackException.from_exception(target_exc) + + # The last frame in the stack is where the error originated + last_frame = tb_exc.stack[-1] + + return (last_frame.filename, last_frame.lineno, last_frame.name) + + +def kill_llm_engine(): + """Kill the llm process identified by WORKER_VICTIM=1 environment variable.""" + logger.info("Looking for llm engine with WORKER_VICTIM=1 environment variable...") + + try: + # Use grep and awk to find PID with WORKER_VICTIM=1 in environment + result = subprocess.run( + [ + "bash", + "-c", + "grep -zla 'WORKER_VICTIM=1' /proc/*/environ 2>/dev/null | awk -F/ '/\/proc\/[0-9]+\/environ/ {print $3}' | sort -n", + ], + capture_output=True, + text=True, + ) + + pids = [pid.strip() for pid in result.stdout.strip().split("\n") if pid.strip()] + + if not pids: + logger.info("No process found with WORKER_VICTIM=1 environment variable") + # # Fallback to pgrep method + # logger.info("Falling back to pgrep method...") + # result = subprocess.run( + # ["pgrep", "-f", "python"], capture_output=True, text=True + # ) + # pids = result.stdout.strip().split() + + # Filter out current process PID + current_pid = os.getpid() + pids = [pid for pid in pids if pid and int(pid) != current_pid] + + if not pids: + logger.info("No server process found to terminate") + return + + # Kill the first server process found + server_pid = int(pids[0]) + logger.info(f"Found server process with PID: {server_pid}. 
Sending SIGTERM...") + os.kill(server_pid, signal.SIGINT) + logger.info("SIGTERM sent to server process") + except subprocess.CalledProcessError as e: + logger.warning(f"Process discovery command failed: {e}") + except ProcessLookupError: + logger.info("Server process already terminated") + except Exception as e: + logger.error(f"Error while terminating server process: {e}") + + +async def worker_main() -> int: + """Main entry point for the batch worker.""" + loop = asyncio.get_running_loop() + + # --- Add Signal Handlers --- + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + loop.add_signal_handler(signal.SIGINT, stop.set_result, None) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--load-job-from-api", + action="store_true", + help="load job spec from api server", + ) + parser.add_argument( + "--llm-engine-container-name", + type=str, + default="llm-engine", + help="container name of the llm engine", + ) + args = parser.parse_args() + + # --- Run your main logic --- + logger.info("Worker starting...") + worker = BatchWorker() + worker_task = asyncio.create_task(worker.run(args)) + + # --- Wait for either the task to complete or a stop signal --- + done, pending = await asyncio.wait( + [stop, worker_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + if stop in done: + logger.info("Shutdown signal received, cancelling tasks...") + # Gracefully cancel pending tasks + for task in pending: + task.cancel() + await asyncio.gather(*pending, return_exceptions=True) + return 1 # Return a non-zero exit code for signal termination + + # If the worker task finished on its own + logger.info("Worker finished normally.") + + kill_llm_engine() + + return worker_task.result() + + +def main(): + try: + code = asyncio.run(worker_main()) + sys.exit(code) + except asyncio.CancelledError: + logger.info("Main task was cancelled.") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/python/aibrix/aibrix/config.py b/python/aibrix/aibrix/config.py index 1e9596774..08371b1b8 100644 --- a/python/aibrix/aibrix/config.py +++ b/python/aibrix/aibrix/config.py @@ -11,6 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class AIBrixSettings(BaseSettings): + # This loads *only* from system environment variables + # Uncomment env_file and env_file_encoding to load from .env file + model_config = SettingsConfigDict( + # env_file=".env", # Load environment variables from a .env file + # env_file_encoding="utf-8", # Encoding for the .env file + extra="ignore", + ) + + # --- Security Settings --- + SECRET_KEY: str = ( + "test-secret-key-for-testing" # No default, reserved for later use. + ) + + # --- Logging Settings --- + LOG_LEVEL: str = "DEBUG" # DEBUG, INFO, WARNING, ERROR, CRITICAL + LOG_PATH: Optional[str] = "/tmp/aibrix/python.log" # If None, logs to stdout only + LOG_FORMAT: str = "%(asctime)s - %(filename)s:%(lineno)d - %(funcName)s - %(levelname)s - %(message)s" DEFAULT_METRIC_COLLECTOR_TIMEOUT = 1 diff --git a/python/aibrix/aibrix/logger.py b/python/aibrix/aibrix/logger.py index 66aba65f3..1c5e2b6b0 100644 --- a/python/aibrix/aibrix/logger.py +++ b/python/aibrix/aibrix/logger.py @@ -13,32 +13,61 @@ # limitations under the License. 
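worker_main above races a Future resolved by SIGTERM/SIGINT against the worker task, so a shutdown signal cancels the worker while a normal completion returns its exit code. A stripped-down, standalone version of that pattern (Unix-only; the work() coroutine is a stand-in for the real batch processing):

    import asyncio
    import signal


    async def work() -> int:
        await asyncio.sleep(3)  # placeholder for the real batch processing
        return 0


    async def main() -> int:
        loop = asyncio.get_running_loop()
        stop: asyncio.Future = loop.create_future()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, stop.set_result, None)

        task = asyncio.create_task(work())
        done, pending = await asyncio.wait(
            {stop, task}, return_when=asyncio.FIRST_COMPLETED
        )

        if stop in done:  # a signal won the race: cancel and drain the worker
            for p in pending:
                p.cancel()
            await asyncio.gather(*pending, return_exceptions=True)
            return 1
        return task.result()  # the worker finished on its own


    print(asyncio.run(main()))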
import logging +import os import sys from logging import Logger from logging.handlers import RotatingFileHandler from pathlib import Path +from typing import Optional +import structlog -def _default_logging_basic_config() -> None: - Path("/tmp/aibrix").mkdir(parents=True, exist_ok=True) - logging.basicConfig( - format="%(asctime)s - %(filename)s:%(lineno)d - %(funcName)s - %(levelname)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S %Z", - handlers=[ - logging.StreamHandler(stream=sys.stdout), +from .config import AIBrixSettings + +active_settings: AIBrixSettings = AIBrixSettings() + + +def logging_basic_config(settings: Optional[AIBrixSettings] = None) -> None: + global active_settings + if settings is not None: + active_settings = settings + + # 1. Configure the standard library logging + handlers: list[logging.Handler] = [logging.StreamHandler(stream=sys.stdout)] + if active_settings.LOG_PATH is not None: + Path(os.path.dirname(active_settings.LOG_PATH)).mkdir( + parents=True, exist_ok=True + ) + handlers.append( RotatingFileHandler( - "/tmp/aibrix/python.log", + active_settings.LOG_PATH, maxBytes=10 * (2**20), backupCount=10, - ), + ) + ) + logging.basicConfig(format=active_settings.LOG_FORMAT, handlers=handlers) + + # 2. Configure structlog processors and renderer + structlog.configure( + processors=[ + structlog.stdlib.add_logger_name, + structlog.stdlib.add_log_level, + structlog.processors.TimeStamper(fmt="%Y-%m-%d %H:%M:%S %Z"), + structlog.processors.StackInfoRenderer(), + structlog.processors.format_exc_info, + structlog.processors.JSONRenderer(), # Renders the log event as JSON ], - level=logging.INFO, + logger_factory=structlog.stdlib.LoggerFactory(), + wrapper_class=structlog.stdlib.BoundLogger, + cache_logger_on_first_use=True, ) def init_logger(name: str) -> Logger: - return logging.getLogger(name) + logger = structlog.get_logger(name) + logger.setLevel(active_settings.LOG_LEVEL) + return logger -_default_logging_basic_config() +logging_basic_config() logger = init_logger(__name__) diff --git a/python/aibrix/aibrix/metadata/api/v1/batch.py b/python/aibrix/aibrix/metadata/api/v1/batch.py index 8e47423c5..a3c097add 100644 --- a/python/aibrix/aibrix/metadata/api/v1/batch.py +++ b/python/aibrix/aibrix/metadata/api/v1/batch.py @@ -13,16 +13,24 @@ # limitations under the License. 
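With AIBrixSettings and the structlog-based logger above, log level, file path, and format come from the process environment and every record carries structured key-value fields rendered as JSON. A small usage sketch, assuming the aibrix package is installed; the LOG_* overrides and field names are illustrative:

    import os

    # pydantic-settings reads these when AIBrixSettings() is constructed
    os.environ["LOG_LEVEL"] = "INFO"
    os.environ["LOG_PATH"] = "/tmp/aibrix/example.log"

    from aibrix.config import AIBrixSettings
    from aibrix.logger import init_logger, logging_basic_config

    logging_basic_config(AIBrixSettings())  # re-apply logging config with the overrides
    logger = init_logger(__name__)
    logger.info("job status synced", job_id="batch-123", state="in_progress")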
import asyncio +import traceback import uuid -from datetime import datetime +from datetime import datetime, timedelta from typing import Dict, List, Optional -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from pydantic import BaseModel, Field -from aibrix.batch.job_entity import BatchJob, BatchJobError, BatchJobSpec -from aibrix.batch.job_manager import JobManager -from aibrix.metadata.logger import init_logger +from aibrix.batch import BatchDriver +from aibrix.batch.job_entity import ( + BatchJob, + BatchJobEndpoint, + BatchJobError, + BatchJobSpec, + BatchJobStatus, + CompletionWindow, +) +from aibrix.logger import init_logger logger = init_logger(__name__) @@ -32,6 +40,35 @@ # OpenAI Batch API request/response models +class BatchSpec(BaseModel): + """Defines the specification of a Batch job input, which is OpenAI batch compatible.""" + + input_file_id: str = Field( + description="The ID of an uploaded file that contains the requests for the batch", + ) + endpoint: BatchJobEndpoint = Field( + description="The API endpoint to be used for all requests in the batch" + ) + completion_window: CompletionWindow = Field( + default=CompletionWindow.TWENTY_FOUR_HOURS, + description="The time window for completion", + ) + metadata: Optional[Dict[str, str]] = Field( + default=None, + description="Set of up to 16 key-value pairs to attach to the batch object", + max_length=16, + ) + + @classmethod + def newBatchJobSpec(cls, spec: "BatchSpec") -> BatchJobSpec: + return BatchJobSpec( + input_file_id=spec.input_file_id, + endpoint=spec.endpoint.value, + completion_window=spec.completion_window.expires_at(), + metadata=spec.metadata, + ) + + class BatchRequestCounts(BaseModel): """Request counts for OpenAI batch API.""" @@ -75,9 +112,7 @@ class BatchResponse(BaseModel): in_progress_at: Optional[int] = Field( default=None, description="Unix timestamp of when the batch started processing" ) - expires_at: Optional[int] = Field( - default=None, description="Unix timestamp of when the batch expires" - ) + expires_at: int = Field(description="Unix timestamp of when the batch expires") finalizing_at: Optional[int] = Field( default=None, description="Unix timestamp of when the batch started finalizing" ) @@ -118,8 +153,8 @@ class BatchListResponse(BaseModel): def _batch_job_to_openai_response(batch_job: BatchJob) -> BatchResponse: """Convert BatchJob to OpenAI batch response format.""" - status = batch_job.status - spec = batch_job.spec + status: BatchJobStatus = batch_job.status + spec: BatchJobSpec = batch_job.spec def dt_to_unix(dt: Optional[datetime]) -> Optional[int]: """Convert datetime to unix timestamp.""" @@ -127,7 +162,7 @@ def dt_to_unix(dt: Optional[datetime]) -> Optional[int]: # Convert request counts request_counts = None - if status.request_counts: + if status.request_counts and status.request_counts.total > 0: request_counts = BatchRequestCounts( total=status.request_counts.total, completed=status.request_counts.completed, @@ -135,21 +170,37 @@ def dt_to_unix(dt: Optional[datetime]) -> Optional[int]: ) created_at_unix = dt_to_unix(status.created_at) - if created_at_unix is None: - created_at_unix = int(datetime.now().timestamp()) + assert created_at_unix is not None + + delta = timedelta(seconds=spec.completion_window) + total_hours = delta.total_seconds() / 3600 + completion_window = f"{int(total_hours)}h" + + state = status.state.value + if status.finished: + condition = status.condition + if condition is 
None: + logger.error( + "Unexpected job finalized without condition", + job_id=batch_job.job_id, + state=status.state.value, + conditions=status.conditions, + ) # type:ignore[call-arg] + raise ValueError("job finalized without condition") + state = condition.value return BatchResponse( id=status.job_id, - endpoint=spec.endpoint.value, + endpoint=spec.endpoint, errors=BatchErrors(data=status.errors) if status.errors else None, input_file_id=spec.input_file_id, - completion_window=spec.completion_window.value, - status=status.state.value, + completion_window=completion_window, + status=state, output_file_id=status.output_file_id, error_file_id=status.error_file_id, created_at=created_at_unix, in_progress_at=dt_to_unix(status.in_progress_at), - expires_at=created_at_unix + int(spec.completion_window.expires_at()), + expires_at=created_at_unix + spec.completion_window, finalizing_at=dt_to_unix(status.finalizing_at), completed_at=dt_to_unix(status.completed_at), failed_at=dt_to_unix(status.failed_at), @@ -161,8 +212,11 @@ def dt_to_unix(dt: Optional[datetime]) -> Optional[int]: ) -@router.post("/") -async def create_batch(request: Request, batch_request: BatchJobSpec) -> BatchResponse: +@router.post("/", include_in_schema=False) +@router.post("") +async def create_batch( + request: Request, batch_request: BatchJobSpec = Depends(BatchSpec.newBatchJobSpec) +) -> BatchResponse: """Create a new batch. Creates a new batch for processing multiple requests. The batch will be @@ -170,7 +224,7 @@ async def create_batch(request: Request, batch_request: BatchJobSpec) -> BatchRe """ try: # Get job controller from app state - job_manager: JobManager = request.app.state.job_controller.job_manager + batch_driver: BatchDriver = request.app.state.batch_driver # Generate session ID for tracking session_id = str(uuid.uuid4()) @@ -184,13 +238,15 @@ async def create_batch(request: Request, batch_request: BatchJobSpec) -> BatchRe ) # type: ignore[call-arg] # Create job using JobManager - job_id = await job_manager.create_job_with_spec( - session_id=session_id, - job_spec=batch_request, + job_id = await batch_driver.run_coroutine( + batch_driver.job_manager.create_job_with_spec( + session_id=session_id, + job_spec=batch_request, + ) ) # Retrieve the created job - job = job_manager.get_job(job_id) + job = await batch_driver.run_coroutine(batch_driver.job_manager.get_job(job_id)) if not job: logger.error("Created job not found", job_id=job_id) # type: ignore[call-arg] raise HTTPException(status_code=500, detail="Created batch not found") @@ -221,12 +277,14 @@ async def get_batch(request: Request, batch_id: str) -> BatchResponse: """ try: # Get job controller from app state - job_manager: JobManager = request.app.state.job_controller.job_manager + batch_driver: BatchDriver = request.app.state.batch_driver logger.debug("Retrieving batch", batch_id=batch_id) # type: ignore[call-arg] # Get job from manager - job = job_manager.get_job(batch_id) + job = await batch_driver.run_coroutine( + batch_driver.job_manager.get_job(batch_id) + ) if not job: logger.warning("Batch not found", batch_id=batch_id) # type: ignore[call-arg] raise HTTPException(status_code=404, detail="Batch not found") @@ -238,6 +296,7 @@ async def get_batch(request: Request, batch_id: str) -> BatchResponse: logger.error( "Unexpected error retrieving batch", batch_id=batch_id, error=str(e) ) # type: ignore[call-arg] + logger.error(f"Stack trace: {traceback.format_exc()}") raise HTTPException(status_code=500, detail="Internal server error") @@ -250,24 +309,30 
@@ async def cancel_batch(request: Request, batch_id: str) -> BatchResponse: """ try: # Get job controller from app state - job_manager: JobManager = request.app.state.job_controller.job_manager + batch_driver: BatchDriver = request.app.state.batch_driver logger.info("Cancelling batch", batch_id=batch_id) # type: ignore[call-arg] # Check if job exists - job = job_manager.get_job(batch_id) + job = await batch_driver.run_coroutine( + batch_driver.job_manager.get_job(batch_id) + ) if not job: logger.warning("Batch not found for cancellation", batch_id=batch_id) # type: ignore[call-arg] raise HTTPException(status_code=404, detail="Batch not found") # Cancel the job - success = job_manager.cancel_job(batch_id) + success = await batch_driver.run_coroutine( + batch_driver.job_manager.cancel_job(batch_id) + ) if not success: logger.warning("Failed to cancel batch", batch_id=batch_id) # type: ignore[call-arg] raise HTTPException(status_code=400, detail="Batch cannot be cancelled") # Get updated job status - updated_job = job_manager.get_job(batch_id) + updated_job = await batch_driver.run_coroutine( + batch_driver.job_manager.get_job(batch_id) + ) if not updated_job: logger.error("Job not found after cancellation", batch_id=batch_id) # type: ignore[call-arg] raise HTTPException(status_code=500, detail="Internal server error") @@ -285,7 +350,8 @@ async def cancel_batch(request: Request, batch_id: str) -> BatchResponse: raise HTTPException(status_code=500, detail="Internal server error") -@router.get("/") +@router.get("/", include_in_schema=False) +@router.get("") async def list_batches( request: Request, after: Optional[str] = Query(None, description="Cursor for pagination"), @@ -298,14 +364,16 @@ async def list_batches( """ try: # Get job controller from app state - job_manager: JobManager = request.app.state.job_controller.job_manager + batch_driver: BatchDriver = request.app.state.batch_driver logger.debug("Listing batches", after=after, limit=limit) # type: ignore[call-arg] # Get all jobs from the manager # Note: This is a simple implementation. 
In production, you'd want # proper pagination and filtering in the JobManager - all_jobs: List[BatchJob] = await job_manager.list_jobs() + all_jobs: List[BatchJob] = await batch_driver.run_coroutine( + batch_driver.job_manager.list_jobs() + ) # Apply cursor-based pagination if after: diff --git a/python/aibrix/aibrix/metadata/api/v1/files.py b/python/aibrix/aibrix/metadata/api/v1/files.py index 9bafdc680..decd20834 100644 --- a/python/aibrix/aibrix/metadata/api/v1/files.py +++ b/python/aibrix/aibrix/metadata/api/v1/files.py @@ -20,7 +20,7 @@ from fastapi import APIRouter, File, Form, HTTPException, Request, Response, UploadFile from pydantic import Field -from aibrix.metadata.logger import init_logger +from aibrix.logger import init_logger from aibrix.metadata.setting.config import settings from aibrix.openapi.protocol import NoExtraBaseModel from aibrix.storage import BaseStorage, Reader, SizeExceededError, generate_filename @@ -112,7 +112,8 @@ def _create_error_response( return {"error": error_data} -@router.post("/") +@router.post("/", include_in_schema=False) +@router.post("") async def create_file( request: Request, file: UploadFile = File(..., description="The file to upload"), @@ -143,10 +144,10 @@ async def create_file( # Generate metadata created_at = int(time.time()) - metadata = { - "filename": file.filename, + metadata: dict[str, str] = { + "filename": file.filename or "", "purpose": purpose.value, - "created_at": created_at, + "created_at": str(created_at), # requires all value a string. } try: await storage.put_object( @@ -181,7 +182,7 @@ async def create_file( except Exception as e: logger.error("Unexpected error uploading file", error=str(e)) # type: ignore[call-arg] error_response = _create_error_response("Internal server error") - raise HTTPException(status_code=500, detail=error_response) + raise # HTTPException(status_code=500, detail=error_response) @router.get("/{file_id}/content") diff --git a/python/aibrix/aibrix/metadata/app.py b/python/aibrix/aibrix/metadata/app.py index 5d217fd6b..95ecb0e74 100644 --- a/python/aibrix/aibrix/metadata/app.py +++ b/python/aibrix/aibrix/metadata/app.py @@ -13,17 +13,20 @@ # limitations under the License. 
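A hedged client-side sketch of exercising the files and batches routes wired up here. The host/port, the "/v1" prefix, the "batch" purpose value, the "/v1/chat/completions" endpoint, the "24h" completion window, the response "id" field, and the "/cancel" sub-route follow OpenAI conventions and are assumptions rather than values confirmed by this patch:

    import httpx

    BASE = "http://localhost:8100/v1"  # assumed host/port for the metadata server

    with httpx.Client() as client:
        # Upload a JSONL request file through the files API.
        with open("input.jsonl", "rb") as f:
            file_resp = client.post(
                f"{BASE}/files",
                files={"file": ("input.jsonl", f)},
                data={"purpose": "batch"},
            )
        input_file_id = file_resp.json()["id"]  # assumed response field

        # Create a batch referencing the uploaded file (BatchSpec fields).
        batch_resp = client.post(
            f"{BASE}/batches",
            json={
                "input_file_id": input_file_id,
                "endpoint": "/v1/chat/completions",
                "completion_window": "24h",
            },
        )
        batch_id = batch_resp.json()["id"]

        # Poll status, then cancel if needed (sub-route assumed).
        client.get(f"{BASE}/batches/{batch_id}")
        client.post(f"{BASE}/batches/{batch_id}/cancel")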
import argparse from contextlib import asynccontextmanager -from typing import Optional +from pathlib import Path +from typing import Any, Dict, Optional import uvicorn -from fastapi import APIRouter, FastAPI +from fastapi import APIRouter, FastAPI, Request from fastapi.responses import JSONResponse +from kubernetes import config from aibrix.batch import BatchDriver from aibrix.batch.job_entity import JobEntityManager +from aibrix.logger import init_logger, logging_basic_config from aibrix.metadata.api.v1 import batch, files -from aibrix.metadata.core.httpx_client import HTTPXClientWrapper -from aibrix.metadata.logger import init_logger +from aibrix.metadata.cache import JobCache +from aibrix.metadata.core import HTTPXClientWrapper, KopfOperatorWrapper from aibrix.metadata.setting import settings from aibrix.storage import create_storage @@ -43,23 +46,57 @@ async def readiness_check(): return JSONResponse(content={"status": "ready"}, status_code=200) +@router.get("/status") +async def status_check(request: Request): + """Get detailed status of all components.""" + status: Dict[str, Any] = { + "httpx_client": { + "available": hasattr(request.app.state, "httpx_client_wrapper"), + "status": "initialized" + if hasattr(request.app.state, "httpx_client_wrapper") + else "not_initialized", + }, + "kopf_operator": { + "available": hasattr(request.app.state, "kopf_operator_wrapper"), + }, + "batch_driver": { + "available": hasattr(request.app.state, "batch_driver"), + }, + } + + # Get detailed kopf operator status if available + if hasattr(request.app.state, "kopf_operator_wrapper"): + kopf_status = request.app.state.kopf_operator_wrapper.get_status() + status["kopf_operator"].update(kopf_status) + + return JSONResponse(content=status, status_code=200) + + @asynccontextmanager async def lifespan(app: FastAPI): # Code executed on startup + logger.info("Initializing FastAPI app...") if hasattr(app.state, "httpx_client_wrapper"): app.state.httpx_client_wrapper.start() + if hasattr(app.state, "kopf_operator_wrapper"): + app.state.kopf_operator_wrapper.start() + if hasattr(app.state, "batch_driver"): + await app.state.batch_driver.start() yield # Code executed on shutdown - if hasattr(app.state, "job_controller"): - await app.state.job_controller.close() + logger.info("Finalizing FastAPI app...") + if hasattr(app.state, "batch_driver"): + await app.state.batch_driver.stop() + if hasattr(app.state, "kopf_operator_wrapper"): + app.state.kopf_operator_wrapper.stop() if hasattr(app.state, "httpx_client_wrapper"): await app.state.httpx_client_wrapper.stop() -def build_app(args: argparse.Namespace): +def build_app(args: argparse.Namespace, params={}): if args.enable_fastapi_docs: - app = FastAPI(lifespan=lifespan, debug=False) + app = FastAPI(lifespan=lifespan, debug=False, redirect_slashes=False) else: app = FastAPI( lifespan=lifespan, @@ -67,18 +104,32 @@ def build_app(args: argparse.Namespace): openapi_url=None, docs_url=None, redoc_url=None, + redirect_slashes=False, ) app.state.httpx_client_wrapper = HTTPXClientWrapper() + # Initialize kopf operator wrapper if K8s jobs are enabled + if args.enable_k8s_job: + app.state.kopf_operator_wrapper = KopfOperatorWrapper( + namespace=getattr(args, "k8s_namespace", "default"), + startup_timeout=getattr(args, "kopf_startup_timeout", 30.0), + shutdown_timeout=getattr(args, "kopf_shutdown_timeout", 10.0), + ) + app.include_router(router) # Initialize batches API if not args.disable_batch_api: job_entity_manager: Optional[JobEntityManager] = None - 
app.state.job_controller = BatchDriver( + if args.enable_k8s_job: + # Get template_path from params if provided + job_entity_manager = JobCache(template_patch_path=args.k8s_job_patch) + app.state.batch_driver = BatchDriver( job_entity_manager, storage_type=settings.STORAGE_TYPE, metastore_type=settings.METASTORE_TYPE, + stand_alone=True, + params=params, ) app.include_router( batch.router, prefix=f"{settings.API_V1_STR}/batches", tags=["batches"] @@ -87,7 +138,7 @@ def build_app(args: argparse.Namespace): # Initialize fiels API if not args.disable_file_api: - app.state.storage = create_storage(settings.STORAGE_TYPE) + app.state.storage = create_storage(settings.STORAGE_TYPE, **params) app.include_router( files.router, prefix=f"{settings.API_V1_STR}/files", tags=["files"] ) # mount files api at /v1/files @@ -123,6 +174,36 @@ def main(): default=False, help="Disable file api", ) + parser.add_argument( + "--enable-k8s-job", + action="store_true", + default=False, + help="Enable native kubernetes jobs as the job executor", + ) + parser.add_argument( + "--k8s-namespace", + type=str, + default="default", + help="Kubernetes namespace to monitor for jobs (default: default)", + ) + parser.add_argument( + "--k8s-job-patch", + type=Path, + default=None, + help="Patch to customize k8s job template", + ) + parser.add_argument( + "--kopf-startup-timeout", + type=float, + default=30.0, + help="Timeout in seconds for kopf operator startup (default: 30.0)", + ) + parser.add_argument( + "--kopf-shutdown-timeout", + type=float, + default=10.0, + help="Timeout in seconds for kopf operator shutdown (default: 10.0)", + ) parser.add_argument( "--e2e-test", action="store_true", @@ -131,7 +212,17 @@ def main(): ) args = parser.parse_args() - logger.info(f"Using {args} to startup {settings.PROJECT_NAME}") + global logger + logging_basic_config(settings) + logger = init_logger(__name__) # Reset logger + + try: + config.load_incluster_config() + except Exception: + # Local debug + config.load_kube_config() + + logger.info(f"Using {args} to startup app", project=settings.PROJECT_NAME) # type: ignore[call-arg] app = build_app(args=args) uvicorn.run(app, host=args.host, port=args.port) diff --git a/python/aibrix/aibrix/metadata/cache/__init__.py b/python/aibrix/aibrix/metadata/cache/__init__.py new file mode 100644 index 000000000..000bd7ac3 --- /dev/null +++ b/python/aibrix/aibrix/metadata/cache/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .job import JobCache + +__all__ = [ + "JobCache", +] diff --git a/python/aibrix/aibrix/metadata/cache/job.py b/python/aibrix/aibrix/metadata/cache/job.py new file mode 100644 index 000000000..f7860a34d --- /dev/null +++ b/python/aibrix/aibrix/metadata/cache/job.py @@ -0,0 +1,1055 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
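A hedged sketch of assembling the app programmatically with the same options main() parses; the Namespace attributes mirror the argparse flags above, the kube config loading mirrors main(), and the serving host/port are assumptions:

    import argparse

    import uvicorn
    from kubernetes import config

    from aibrix.metadata.app import build_app

    # Load cluster credentials the same way main() does.
    try:
        config.load_incluster_config()
    except Exception:
        config.load_kube_config()

    args = argparse.Namespace(
        enable_fastapi_docs=False,
        disable_batch_api=False,
        disable_file_api=False,
        enable_k8s_job=True,
        k8s_namespace="default",
        k8s_job_patch=None,
        kopf_startup_timeout=30.0,
        kopf_shutdown_timeout=10.0,
        e2e_test=False,
    )
    app = build_app(args)
    uvicorn.run(app, host="0.0.0.0", port=8100)  # assumed host/port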
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import uuid
+from pathlib import Path
+from typing import Any, Callable, Coroutine, Dict, List, Optional
+
+import kopf
+import yaml
+from kubernetes import client
+from kubernetes.client.rest import ApiException
+
+import aibrix.batch.storage.batch_metastore as metastore
+import aibrix.batch.storage.batch_storage as storage
+from aibrix.batch.job_entity import (
+    BatchJob,
+    BatchJobSpec,
+    BatchJobState,
+    BatchJobStatus,
+    BatchJobTransformer,
+    ConditionType,
+    JobAnnotationKey,
+    JobEntityManager,
+    k8s_job_to_batch_job,
+)
+from aibrix.logger import init_logger
+from aibrix.storage import StorageType
+
+from .utils import merge_yaml_object
+
+# If you installed kopf[uvloop], kopf will likely set this up.
+# Otherwise, you can explicitly set it:
+# import uvloop
+# asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+
+# Global logger for standalone functions
+logger = init_logger(__name__)
+
+# Global JobCache instance for kopf handlers
+_global_job_cache: Optional["JobCache"] = None
+
+
+def set_global_job_cache(job_cache: "JobCache") -> None:
+    """Set the global job cache instance for kopf handlers."""
+    global _global_job_cache
+    _global_job_cache = job_cache
+
+
+def get_global_job_cache() -> Optional["JobCache"]:
+    """Get the global job cache instance."""
+    return _global_job_cache
+
+
+class JobCache(JobEntityManager):
+    """Kubernetes-based job cache implementing JobEntityManager interface.
+
+    This class uses kopf to watch Kubernetes Job resources and maintains
+    an in-memory cache of BatchJob objects. It implements the JobEntityManager
+    interface to provide standardized job management capabilities.
+    """
+
+    def __init__(self, template_patch_path: Optional[Path] = None) -> None:
+        """Initialize the job cache.
+
+        Args:
+            template_patch_path: Optional path to a YAML patch applied on top of
+                the default k8s job template. If None, uses the default template.
+ """ + # Cache of BatchJob objects keyed by batch ID (K8s UID) + self.active_jobs: Dict[str, BatchJob] = {} + + # Register this instance as the global job cache for kopf handlers + set_global_job_cache(self) + + # Callback handlers for job lifecycle events + self._job_committed_handler: Optional[ + Callable[[BatchJob], Coroutine[Any, Any, bool]] + ] = None + self._job_updated_handler: Optional[ + Callable[[BatchJob, BatchJob], Coroutine[Any, Any, bool]] + ] = None + self._job_deleted_handler: Optional[ + Callable[[BatchJob], Coroutine[Any, Any, bool]] + ] = None + + # Load Kubernetes Job template once at initialization + template_dir = Path(__file__).parent.parent / "setting" + + try: + path = template_dir / "k8s_job_template.yaml" + with open(path, "r") as f: + self.job_template = yaml.safe_load(f) + logger.info( + "Kubernetes Job template loaded successfully", + template_path=str(path), + ) # type: ignore[call-arg] + + # Apply customize job patch + if template_patch_path: + path = template_patch_path + with open(path, "r") as f: + self.job_template = merge_yaml_object( + self.job_template, yaml.safe_load(f), False + ) + logger.info( + "Kubernetes Job template (customize) loaded successfully", + template_path=str(path), + ) # type: ignore[call-arg] + + # Apply s3 config patch + self.storage_patches = {} + path = template_dir / "k8s_job_s3_patch.yaml" + with open(path, "r") as f: + self.storage_patches[StorageType.S3.value] = yaml.safe_load(f) + logger.info( + "Kubernetes Job storage patch (s3) loaded successfully", + template_path=str(path), + storage=StorageType.S3, + ) # type: ignore[call-arg] + + # Apply tos config patch + path = template_dir / "k8s_job_tos_patch.yaml" + with open(path, "r") as f: + self.storage_patches[StorageType.TOS.value] = yaml.safe_load(f) + logger.info( + "Kubernetes Job storage patch (tos) loaded successfully", + template_path=str(path), + storage=StorageType.TOS, + ) # type: ignore[call-arg] + + # Apply redis config patch + path = template_dir / "k8s_job_redis_patch.yaml" + with open(path, "r") as f: + self.storage_patches[StorageType.REDIS.value] = yaml.safe_load(f) + logger.info( + "Kubernetes Job storage patch (redis) loaded successfully", + template_path=str(path), + storage=StorageType.REDIS, + ) # type: ignore[call-arg] + except FileNotFoundError: + logger.error( + "Kubernetes Job template not found", + template_path=str(path), + operation="__init__", + ) # type: ignore[call-arg] + raise RuntimeError(f"Job template not found at {str(path)}") + except yaml.YAMLError as e: + logger.error( + "Failed to parse Kubernetes Job template", + error=str(e), + template_path=str(path), + operation="__init__", + ) # type: ignore[call-arg] + raise RuntimeError(f"Invalid YAML in job template: {e}") + + self.batch_v1_api = client.BatchV1Api() + self.core_v1_api = client.CoreV1Api() + self.rbac_v1_api = client.RbacAuthorizationV1Api() + + # Apply RBAC resources for job execution + self._apply_job_rbac(template_dir) + + def _apply_job_rbac(self, template_dir: Path) -> None: + """Apply RBAC resources for job execution from k8s_job_rbac.yaml.""" + try: + rbac_path = template_dir / "k8s_job_rbac.yaml" + with open(rbac_path, "r") as f: + rbac_docs = list(yaml.safe_load_all(f)) + + for doc in rbac_docs: + if not doc: # Skip empty documents + continue + + kind = doc.get("kind") + metadata = doc.get("metadata", {}) + name = metadata.get("name") + namespace = metadata.get("namespace", "default") + + try: + if kind == "ServiceAccount": + # Try to create, if exists then update + 
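                    # Create-or-update idiom: attempt the create call first; a
                    # 409 Conflict means the resource already exists, so fall
                    # back to a patch. This keeps RBAC setup idempotent across
                    # restarts. The same pattern repeats for Role and
                    # RoleBinding below.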
try: + self.core_v1_api.create_namespaced_service_account( + namespace=namespace, body=doc + ) + logger.info( + f"Created ServiceAccount: {doc['metadata']['name']}" + ) + except ApiException as e: + if e.status == 409: # Already exists + self.core_v1_api.patch_namespaced_service_account( + name=doc["metadata"]["name"], + namespace=namespace, + body=doc, + ) + logger.info( + f"Updated ServiceAccount: {doc['metadata']['name']}" + ) + else: + raise + + elif kind == "Role": + try: + self.rbac_v1_api.create_namespaced_role( + namespace=namespace, body=doc + ) + logger.info(f"Created Role: {doc['metadata']['name']}") + except ApiException as e: + if e.status == 409: # Already exists + self.rbac_v1_api.patch_namespaced_role( + name=doc["metadata"]["name"], + namespace=namespace, + body=doc, + ) + logger.info(f"Updated Role: {doc['metadata']['name']}") + else: + raise + + elif kind == "RoleBinding": + try: + self.rbac_v1_api.create_namespaced_role_binding( + namespace=namespace, body=doc + ) + logger.info( + f"Created RoleBinding: {doc['metadata']['name']}" + ) + except ApiException as e: + if e.status == 409: # Already exists + self.rbac_v1_api.patch_namespaced_role_binding( + name=doc["metadata"]["name"], + namespace=namespace, + body=doc, + ) + logger.info( + f"Updated RoleBinding: {doc['metadata']['name']}" + ) + else: + raise + else: + logger.warning(f"Unsupported RBAC resource kind: {kind}") + + except ApiException as e: + logger.error( + f"Failed to apply {kind} {name}: {e.status} {e.reason}", + error=str(e), + kind=kind, + name=name, + namespace=namespace, + ) # type: ignore[call-arg] + # Don't raise here to allow other resources to be applied + + except FileNotFoundError: + logger.warning( + "RBAC template not found, skipping RBAC setup", + template_path=str(rbac_path), + ) # type: ignore[call-arg] + except yaml.YAMLError as e: + logger.error( + "Failed to parse RBAC template", + error=str(e), + template_path=str(rbac_path), + ) # type: ignore[call-arg] + except Exception as e: + logger.error( + "Unexpected error applying RBAC resources", + error=str(e), + template_path=str(rbac_path), + ) # type: ignore[call-arg] + + def is_scheduler_enabled(self) -> bool: + """Check if JobEntityManager has own scheduler enabled.""" + return True + + # Implementation of JobEntityManager abstract methods + def get_job(self, job_id: str) -> Optional[BatchJob]: + """Get cached job detail by batch id. + + Args: + job_id: Batch id (Kubernetes UID). + + Returns: + BatchJob: Job detail. + + Raises: + KeyError: If job with given job_id is not found. + """ + if job_id not in self.active_jobs: + return None + return self.active_jobs[job_id] + + def list_jobs(self) -> List[BatchJob]: + """List unarchived jobs that cached locally. + + Returns: + List[BatchJob]: List of jobs. + """ + return list(self.active_jobs.values()) + + async def submit_job( + self, + session_id: str, + job_spec: BatchJobSpec, + job_name: Optional[str] = None, + parallelism: int = 1, + prepared_job: Optional[BatchJob] = None, + ) -> None: + """Submit job by creating a Kubernetes Job. + + Args: + job_spec: BatchJobSpec to submit to Kubernetes. + job_name: Optional job name, will generate one if not provided. + prepared_job: Optional BatchJob with file IDs to add to pod annotations. + + Raises: + RuntimeError: If Kubernetes client is not available. + ApiException: If Kubernetes API call fails. 
+ """ + if not self.batch_v1_api: + raise RuntimeError("Kubernetes client not available") + + try: + # Convert BatchJobSpec to Kubernetes Job manifest + k8s_job = self._batch_job_spec_to_k8s_job( + session_id, job_spec, job_name, parallelism, prepared_job + ) + + # Get namespace from k8s_job, use default if not specified + namespace = k8s_job["metadata"].get("namespace") or "default" + + logger.info( # type: ignore[call-arg] + "Submitting job to Kubernetes", + namespace=namespace, + input_file_id=job_spec.input_file_id, + endpoint=job_spec.endpoint, + opts=job_spec.opts, + job_name=k8s_job["metadata"]["name"], + ) # type: ignore[call-arg] + + # Submit job asynchronously + async_result = await asyncio.to_thread( + self.batch_v1_api.create_namespaced_job, + namespace=namespace, + body=k8s_job, + async_req=True, + ) + + # Create a task to check job result asynchronously without blocking + async def check_job_result(): + try: + job_result = await asyncio.to_thread(async_result.get) + logger.info( # type: ignore[call-arg] + "Job successfully submitted to Kubernetes", + namespace=namespace, + job_name=job_result.metadata.name, + job_uid=job_result.metadata.uid, + ) # type: ignore[call-arg] + return job_result + except ApiException as e: + logger.error( # type: ignore[call-arg] + "Kubernetes API error during job submission", + input_file_id=job_spec.input_file_id, + endpoint=job_spec.endpoint, + error=str(e), + status_code=e.status, + reason=e.reason, + namespace=namespace, + operation="submit_job", + ) # type: ignore[call-arg] + except Exception as e: + # This could catch errors from async_result.get() or other unexpected issues + error_type = type(e).__name__ + logger.error( # type: ignore[call-arg] + "Unexpected error during job submission", + input_file_id=job_spec.input_file_id, + endpoint=job_spec.endpoint, + namespace=namespace, + error=str(e), + error_type=error_type, + operation="submit_job", + ) # type: ignore[call-arg] + + # Start the job result checking task but don't wait for it + asyncio.create_task(check_job_result()) + + except Exception as e: + error_type = type(e).__name__ + logger.error( # type: ignore[call-arg] + "Unexpected error during job submission", + input_file_id=job_spec.input_file_id, + endpoint=job_spec.endpoint, + namespace=namespace, + error=str(e), + error_type=error_type, + operation="submit_job", + ) # type: ignore[call-arg] + + async def update_job_ready(self, job: BatchJob): + """Update job by marking it ready info in the persist store. + The job suspend flag will be removed to start the execution. + + Args: + job (BatchJob): Job to update. 
+ """ + if not self.batch_v1_api: + raise RuntimeError("Kubernetes client not available") + + patch_body: Optional[Dict[str, Any]] = None + try: + # Get namespace from k8s_job, use default if not specified + namespace = job.metadata.namespace or "default" + + # Convert BatchJobSpec to Kubernetes Job manifest + patch_body = self._ready_batch_job_to_k8s_job_patch(job) + + logger.info( # type: ignore[call-arg] + "Executing job setting to ready", + job_name=job.metadata.name, + namespace=namespace, + patch=patch_body, + ) # type: ignore[call-arg] + + await asyncio.to_thread( + self.batch_v1_api.patch_namespaced_job, + name=job.metadata.name, + namespace=namespace, + body=patch_body, + async_req=True, + ) + + except ApiException as e: + if e.status == 409: + logger.warning( # type: ignore[call-arg] + "Job status changed", + job=job.metadata.name, + namespace=namespace, + job_id=job.job_id, + ) + raise + else: + logger.error( # type: ignore[call-arg] + "Failed to set job ready", + job_name=job.metadata.name, + namespace=namespace, + patch=patch_body, + error=str(e), + status_code=e.status, + reason=e.reason, + ) # type: ignore[call-arg] + raise + except Exception as e: + logger.error( # type: ignore[call-arg] + "Unexpected error setting job ready", + job_name=job.metadata.name, + namespace=namespace, + patch=patch_body, + error=str(e), + operation="update_job_ready", + ) # type: ignore[call-arg] + raise + + async def update_job_status(self, job: BatchJob): + """Update job status by persisting status information as annotations. + + Args: + job (BatchJob): Job with updated status to persist. + + This method persists critical job status information including: + - Finalized state + - Conditions (completed, failed, cancelled) + - Request counts + - Timestamps (in_progress_at, completed_at, failed_at, cancelled_at, etc.) 
+ """ + if not self.batch_v1_api: + raise RuntimeError("Kubernetes client not available") + + patch_body: Any = None + try: + # Create status annotations from job status + status_annotations = BatchJobTransformer.create_status_annotations( + job.status + ) + + if not status_annotations: + logger.debug("No status annotations to persist", job_id=job.job_id) # type: ignore[call-arg] + return + + # Create patch body to update pod template annotations + patch_body = { + "metadata": { + "resourceVersion": job.metadata.resource_version, + "annotations": status_annotations, + } + } + + namespace = job.metadata.namespace or "default" + + logger.info( # type: ignore[call-arg] + "Executing job status update", + job_name=job.metadata.name, + namespace=namespace, + job_id=job.job_id, + patch=patch_body, + ) + + await asyncio.to_thread( + self.batch_v1_api.patch_namespaced_job, + name=job.metadata.name, + namespace=namespace, + body=patch_body, + async_req=True, + ) + + except ApiException as e: + if e.status == 409: + logger.warning( # type: ignore[call-arg] + "Job status changed", + job=job.metadata.name, + namespace=namespace, + job_id=job.job_id, + ) + raise + else: + logger.error( # type: ignore[call-arg] + "Failed to persist job status to Kubernetes", + job_name=job.metadata.name, + namespace=job.metadata.namespace or "default", + job_id=job.job_id, + error=str(e), + status_code=e.status, + reason=e.reason, + ) + raise + except Exception as e: + logger.error( # type: ignore[call-arg] + "Unexpected error persisting job status", + job_name=job.metadata.name, + namespace=job.metadata.namespace or "default", + job_id=job.job_id, + error=str(e), + patch=str(patch_body), + operation="update_job_status", + ) + raise + + async def cancel_job(self, job: BatchJob) -> None: + """Cancel job by suspending it and persisting cancellation status. + + Args: + job_id: Job ID (batch ID) to cancel. + + Raises: + RuntimeError: If Kubernetes client is not available. + KeyError: If job is not found in cache. + ApiException: If Kubernetes API call fails. + """ + if not self.batch_v1_api: + raise RuntimeError("Kubernetes client not available") + + # Get job from cache to find namespace and name + assert ( + job.status.state == BatchJobState.FINALIZING + or job.status.state == BatchJobState.FINALIZED + or job.status.errors is not None + ) + namespace = job.metadata.namespace or "default" + job_name = job.metadata.name + + try: + # Prepare base annotations + annotations_patch = BatchJobTransformer.create_status_annotations( + job.status + ) + # Set condition after update based on error or not. 
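            # No recorded errors means this is a user-initiated cancellation,
            # so the terminal condition annotation is CANCELLED; any recorded
            # error reuses the same suspend-and-annotate patch but marks the
            # job FAILED instead.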
+ if job.status.errors is None: + annotations_patch[JobAnnotationKey.CONDITION.value] = ( + ConditionType.CANCELLED.value + ) + else: + annotations_patch[JobAnnotationKey.CONDITION.value] = ( + ConditionType.FAILED.value + ) + + # Persist conditions (failed, cancelled) + suspend_patch = { + "metadata": { + "resourceVersion": job.metadata.resource_version, + "annotations": annotations_patch, + }, + "spec": { + "suspend": True # Suspend the Kubernetes Job (instead of deleting) + }, + } + + logger.info( # type: ignore[call-arg] + "Executing job cancellation", + job=job_name, + namespace=namespace, + job_id=job.job_id, + patch=suspend_patch, + ) + + await asyncio.to_thread( + self.batch_v1_api.patch_namespaced_job, + name=job_name, + namespace=namespace, + body=suspend_patch, + async_req=True, + ) + + except ApiException as e: + if e.status == 404: + logger.warning( # type: ignore[call-arg] + "Job not found in Kubernetes for cancellation", + job=job_name, + namespace=namespace, + job_id=job.job_id, + ) + elif e.status == 409: + logger.warning( # type: ignore[call-arg] + "Job status changed", + job=job_name, + namespace=namespace, + job_id=job.job_id, + ) + raise + else: + logger.error( # type: ignore[call-arg] + "Failed to cancel job in Kubernetes", + job=job_name, + namespace=namespace, + job_id=job.job_id, + error=str(e), + status_code=e.status, + reason=e.reason, + ) + raise + except Exception as e: + logger.error( # type: ignore[call-arg] + "Unexpected error cancelling job", + job=job_name, + namespace=namespace, + job_id=job.job_id, + error=str(e), + patch=str(suspend_patch), + operation="cancel_job", + ) + raise + + async def delete_job(self, job: BatchJob) -> None: + """Cancel job by deleting the Kubernetes Job. + + Args: + job_id: Job ID (batch ID) to cancel. + + Raises: + RuntimeError: If Kubernetes client is not available. + KeyError: If job is not found in cache. + ApiException: If Kubernetes API call fails. + """ + if not self.batch_v1_api: + raise RuntimeError("Kubernetes client not available") + + namespace = job.metadata.namespace or "default" + job_name = job.metadata.name + try: + # Delete the Kubernetes Job + await asyncio.to_thread( + self.batch_v1_api.delete_namespaced_job, + name=job_name, + namespace=namespace, + propagation_policy="Foreground", # Delete pods too + async_req=True, + ) + + logger.info( # type: ignore[call-arg] + "Job deletion requested in Kubernetes", + job_id=job.job_id, + job=job_name, + namespace=namespace, + ) + except ApiException as e: + if e.status == 404: + logger.warning( # type: ignore[call-arg] + "Job not found in Kubernetes for deletion", + job=job_name, + namespace=namespace, + ) + else: + logger.error( # type: ignore[call-arg] + "Failed to delete job in Kubernetes", + job_id=job.job_id, + job=job_name, + namespace=namespace, + error=str(e), + status_code=e.status, + reason=e.reason, + ) + raise + except Exception as e: + logger.error( # type: ignore[call-arg] + "Unexpected error deleting job", + job_id=job.job_id, + job=job_name, + namespace=namespace, + error=str(e), + operation="delete_job", + ) + raise + + def _ready_batch_job_to_k8s_job_patch(self, job: BatchJob) -> Dict[str, Any]: + """Convert BatchJob to Kubernetes Job patch manifest. Only annotations will be patched. + + Args: + job_spec: BatchJob to convert. + + Returns: + patch body object. 
+ """ + # Use pre-loaded template (deep copy to avoid modifying the original) + job_status: BatchJobStatus = job.status + assert ( + job_status.in_progress_at is not None + ), "AssertError: Job must be set as in progress before setting as ready" + + patch_annotations = BatchJobTransformer.create_status_annotations(job_status) + patch_body = { + "metadata": { + "resourceVersion": job.metadata.resource_version, + "annotations": patch_annotations, + }, + "spec": { + "template": { + "metadata": { + "annotations": { + JobAnnotationKey.OUTPUT_FILE_ID: job_status.output_file_id, + JobAnnotationKey.TEMP_OUTPUT_FILE_ID: job_status.temp_output_file_id, + JobAnnotationKey.ERROR_FILE_ID: job_status.error_file_id, + JobAnnotationKey.TEMP_ERROR_FILE_ID: job_status.temp_error_file_id, + }, + }, + }, + "suspend": False, + }, + } + return patch_body + + def _batch_job_spec_to_k8s_job( + self, + session_id: str, + job_spec: BatchJobSpec, + job_name: Optional[str] = None, + parallelism: int = 1, + prepared_job: Optional[BatchJob] = None, + ) -> Dict[str, Any]: + """Convert BatchJobSpec to Kubernetes Job manifest using pre-loaded template. + + Args: + job_spec: BatchJobSpec to convert. + job_name: Optional job name, will generate one if not provided. + prepared_job: Optional BatchJob with file IDs to add to pod annotations. + + Returns: + Kubernetes V1Job object. + """ + # Generate unique job name + if job_name is None: + job_name = f"batch-{uuid.uuid4().hex[:8]}" + + # Create pod annotations from job spec + pod_annotations: Dict[str, str] = { + JobAnnotationKey.SESSION_ID.value: session_id, + JobAnnotationKey.INPUT_FILE_ID.value: job_spec.input_file_id, + JobAnnotationKey.ENDPOINT.value: job_spec.endpoint, + } + + # Add batch metadata as pod annotations + if job_spec.metadata: + for key, value in job_spec.metadata.items(): + pod_annotations[f"{JobAnnotationKey.METADATA_PREFIX.value}{key}"] = ( + value + ) + + # Add batch opts as pod annotations + if job_spec.opts: + for key, value in job_spec.opts.items(): + pod_annotations[f"{JobAnnotationKey.OPTS_PREFIX.value}{key}"] = value + + # Add file IDs from prepared job if provided + suspend = True + if prepared_job and prepared_job.status: + if prepared_job.status.output_file_id: + pod_annotations[JobAnnotationKey.OUTPUT_FILE_ID.value] = ( + prepared_job.status.output_file_id + ) + else: + suspend = False + + if prepared_job.status.temp_output_file_id: + pod_annotations[JobAnnotationKey.TEMP_OUTPUT_FILE_ID.value] = ( + prepared_job.status.temp_output_file_id + ) + else: + suspend = False + + if prepared_job.status.error_file_id: + pod_annotations[JobAnnotationKey.ERROR_FILE_ID.value] = ( + prepared_job.status.error_file_id + ) + else: + suspend = False + + if prepared_job.status.temp_error_file_id: + pod_annotations[JobAnnotationKey.TEMP_ERROR_FILE_ID.value] = ( + prepared_job.status.temp_error_file_id + ) + else: + suspend = False + + job_patch = { + "metadata": { + "name": job_name, + # Minimal job-level annotations - most metadata moved to pod + "annotations": { + "batch.job.aibrix.ai/managed-by": "aibrix", + }, + }, + "spec": { + "template": { + "metadata": { + "annotations": pod_annotations, + }, + }, + "activeDeadlineSeconds": job_spec.completion_window, + "suspend": suspend, + "parallelism": parallelism, + "completions": parallelism, + }, + } + # Use pre-loaded template (deep copy to avoid modifying the original) + job_template = merge_yaml_object(self.job_template, job_patch) + + # Merge storage env + if ( + storage_patch := 
self.storage_patches.get(storage.get_storage_type().value) + ) is not None: + job_template = merge_yaml_object(job_template, storage_patch, False) + else: + logger.warning( + "No storage patch found", storage_type=storage.get_storage_type() + ) # type:ignore[call-arg] + + # Merge metastore env + if ( + metastore_patch := self.storage_patches.get( + metastore.get_metastore_type().value + ) + ) is not None: + job_template = merge_yaml_object(job_template, metastore_patch, False) + else: + logger.warning( + "No metastore patch found", storage_type=metastore.get_metastore_type() + ) # type:ignore[call-arg] + + return job_template + + +logger.info("kopf job handlers imported") + + +# Standalone kopf handlers that work with the global JobCache instance +# Use event handler only to avoid advanced kopf features such as state management, +# which introduces customized annotation. +@kopf.on.event("batch", "v1", "jobs") # type: ignore[arg-type] +async def job_event_handler(type: str, body: Any, **kwargs: Any) -> None: + """Handle Kubernetes Job creation events.""" + job_cache = get_global_job_cache() + if not job_cache: + logger.warning("No global job cache available for job creation event") + return + + if type == "ADDED": + await job_created_handler(body, **kwargs) + elif type == "MODIFIED": + job_id = body.get("metadata", {}).get("uid") + if job_cache.active_jobs.get(job_id) is None: + await job_created_handler(body, **kwargs) + else: + await job_updated_handler(body, **kwargs) + elif type == "DELETED": + await job_deleted_handler(body, **kwargs) + + +# @kopf.on.create("batch", "v1", "jobs") # type: ignore[arg-type] +async def job_created_handler(body: Any, **kwargs: Any) -> None: + """Handle Kubernetes Job creation events.""" + job_cache = get_global_job_cache() + if not job_cache: + logger.warning("No global job cache available for job creation event") + return + + try: + # Transform K8s Job to BatchJob + batch_job = k8s_job_to_batch_job(body) + job_id = batch_job.status.job_id if batch_job.status else body.metadata.uid + + logger.info( + "Job created", + job_id=job_id, + name=batch_job.metadata.name, + namespace=batch_job.metadata.namespace, + state=batch_job.status.state.value, + resource_version=batch_job.metadata.resource_version, + ) # type: ignore[call-arg] + + # Invoke callback if registered + try: + if await job_cache.job_committed(batch_job): + # Store in cache + job_cache.active_jobs[job_id] = batch_job + else: + await job_cache.delete_job(batch_job) + except Exception as e: + logger.error( + "Error in job committed handler", + error=str(e), + handler="job_committed", + ) # type: ignore[call-arg] + except ValueError as ve: + # For jobs without proper annotations, store basic info for backward compatibility + job_id = body.metadata.uid + logger.warning( + "Failed to process job creation", + job_id=job_id, + reason=str(ve), + ) # type: ignore[call-arg] + except Exception as e: + logger.error( + "Failed to process job creation", error=str(e), operation="job_created" + ) # type: ignore[call-arg] + + +# @kopf.on.update("batch", "v1", "jobs") # type: ignore[arg-type] +async def job_updated_handler(body: Any, **kwargs: Any) -> None: + """Handle Kubernetes Job update events.""" + job_cache = get_global_job_cache() + if not job_cache: + logger.warning("No global job cache available for job update event") + return + + try: + # Transform new K8s Job to BatchJob + new_batch_job = k8s_job_to_batch_job(body) + job_id = ( + new_batch_job.status.job_id if new_batch_job.status else body.metadata.uid + 
)
+
+        # Get old job from cache
+        old_batch_job = job_cache.active_jobs.get(job_id)
+        if old_batch_job is None:
+            logger.warning("Job updating ignored due to job not found", job_id=job_id)  # type: ignore[call-arg]
+            return
+
+        logger.info(
+            "Job updated",
+            job_id=job_id,
+            name=new_batch_job.metadata.name,
+            namespace=new_batch_job.metadata.namespace,
+            old_state=old_batch_job.status.state.value if old_batch_job else "unknown",
+            new_state=new_batch_job.status.state.value,
+            resource_version=new_batch_job.metadata.resource_version,
+        )  # type: ignore[call-arg]
+
+        # Invoke callback if registered and we have both old and new jobs
+        try:
+            if await job_cache.job_updated(old_batch_job, new_batch_job):
+                # Update cache
+                job_cache.active_jobs[job_id] = new_batch_job
+        except Exception as uhe:
+            logger.error(
+                "Error in job updated handler",
+                error=str(uhe),
+                handler="job_updated",
+            )  # type: ignore[call-arg]
+    except Exception as e:
+        logger.error(
+            "Failed to process job update", error=str(e), operation="job_updated"
+        )  # type: ignore[call-arg]
+
+
+# @kopf.on.field("batch", "v1", "jobs", field="status.conditions")  # type: ignore[arg-type]
+async def job_completion_handler(body: Any, **kwargs: Any) -> None:
+    """
+    This handler triggers ONLY when the 'status.conditions' field of a Job changes.
+    """
+    if not body:  # The conditions field might be None initially
+        return
+
+    await job_updated_handler(body, **kwargs)  # type: ignore[call-arg, misc, arg-type]
+
+
+# Set optional = True to prevent kopf from adding the finalizer.
+# @kopf.on.delete("batch", "v1", "jobs", optional=True)  # type: ignore[arg-type]
+async def job_deleted_handler(body: Any, **kwargs: Any) -> None:
+    """Handle Kubernetes Job deletion events."""
+    job_cache = get_global_job_cache()
+    if not job_cache:
+        logger.warning("No global job cache available for job deletion event")
+        return
+
+    job_id = body.metadata.uid
+    job_name = body.metadata.name
+    namespace = body.metadata.namespace
+
+    # Get job from cache before deletion
+    deleted_job = job_cache.active_jobs.get(job_id)
+    if deleted_job is None:
+        logger.info(
+            "Job deleted event ignored, no job found",
+            job_id=job_id,
+            name=job_name,
+            namespace=namespace,
+        )  # type: ignore[call-arg]
+        return
+
+    logger.info(
+        "Job deleted",
+        job_id=job_id,
+        job=deleted_job.metadata.name,
+        namespace=deleted_job.metadata.namespace,
+        state=deleted_job.status.state.value,
+    )  # type: ignore[call-arg]
+
+    # Invoke callback if registered
+    try:
+        if await job_cache.job_deleted(deleted_job):
+            del job_cache.active_jobs[job_id]
+    except Exception as e:
+        logger.error(
+            "Error in job deleted handler",
+            error=str(e),
+            handler="job_deleted",
+        )  # type: ignore[call-arg]
diff --git a/python/aibrix/aibrix/metadata/cache/utils.py b/python/aibrix/aibrix/metadata/cache/utils.py
new file mode 100644
index 000000000..d5d7108c8
--- /dev/null
+++ b/python/aibrix/aibrix/metadata/cache/utils.py
@@ -0,0 +1,78 @@
+# Copyright 2024 The Aibrix Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
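A small usage sketch of the merge_yaml_object helper defined next; the container and env values are illustrative only, chosen to show the name-keyed list merge that mimics a kustomize strategic merge:

    from aibrix.metadata.cache.utils import merge_yaml_object

    base = {
        "spec": {
            "template": {
                "spec": {
                    "containers": [{"name": "worker", "image": "aibrix/runtime"}]
                }
            }
        }
    }
    overlay = {
        "spec": {
            "template": {
                "spec": {
                    "containers": [
                        {"name": "worker", "env": [{"name": "STORAGE_TYPE", "value": "S3"}]}
                    ]
                }
            }
        }
    }

    merged = merge_yaml_object(base, overlay)
    # The "worker" entries are merged by name rather than replaced wholesale:
    # merged["spec"]["template"]["spec"]["containers"][0] carries both the
    # image from base and the env list from the overlay, while base itself is
    # left untouched (copy_on_write defaults to True).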
+import collections.abc +import copy + + +def merge_yaml_object(base, overlay, copy_on_write=True): + """ + Recursively merges two YAML objects, mimicking kustomize's strategic merge. + Accepts both dictionaries and Kubernetes API objects as input. + """ + base_dict = base.to_dict() if hasattr(base, "to_dict") else base + overlay_dict = overlay.to_dict() if hasattr(overlay, "to_dict") else overlay + + merged = base_dict + if copy_on_write: + merged = copy.deepcopy(base_dict) + + for key, value in overlay_dict.items(): + if ( + key in merged + and isinstance(merged[key], collections.abc.Mapping) + and isinstance(value, collections.abc.Mapping) + ): + merged[key] = merge_yaml_object(merged[key], value, False) + + elif ( + key in merged and isinstance(merged[key], list) and isinstance(value, list) + ): + # To merge a list, we use "name" field as the key + base_list = merged[key] + overlay_list = value + strategy_merge = False + + # Create a map of base items by their 'name' for quick lookups + base_items_by_name = { + item.get("name"): item + for item in base_list + if isinstance(item, collections.abc.Mapping) and "name" in item + } + strategy_merge = len(base_items_by_name) > 0 + + for item in overlay_list: + if isinstance(item, collections.abc.Mapping) and "name" in item: + item_name = item.get("name") + if item_name in base_items_by_name: + # If an item with the same name exists, merge them + base_item = base_items_by_name[item_name] + base_items_by_name[item_name] = merge_yaml_object( + base_item, item, False + ) + else: + # Otherwise, append the new item + base_items_by_name[item_name] = item + base_list.append(item) + else: + # If the overlay item isn't a dict with a name, just append it + base_list.append(item) + + if strategy_merge: + merged[key] = list(base_items_by_name.values()) + else: + merged[key] = base_list + else: + merged[key] = value + + return merged diff --git a/python/aibrix/aibrix/metadata/core/__init__.py b/python/aibrix/aibrix/metadata/core/__init__.py new file mode 100644 index 000000000..395ac8a46 --- /dev/null +++ b/python/aibrix/aibrix/metadata/core/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .asyncio_thread import AsyncLoopThread, T +from .httpx_client import HTTPXClientWrapper +from .kopf_operator import KopfOperatorWrapper + +__all__ = ["AsyncLoopThread", "HTTPXClientWrapper", "KopfOperatorWrapper", "T"] diff --git a/python/aibrix/aibrix/metadata/core/asyncio_thread.py b/python/aibrix/aibrix/metadata/core/asyncio_thread.py new file mode 100644 index 000000000..2e1ed1c56 --- /dev/null +++ b/python/aibrix/aibrix/metadata/core/asyncio_thread.py @@ -0,0 +1,113 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import threading +from concurrent.futures import Future +from typing import Any, Coroutine, Optional, TypeVar + +# Define a TypeVar to represent the generic return type. +T = TypeVar("T") + + +class AsyncLoopThread(threading.Thread): + """ + A class to run and manage an asyncio event loop in a dedicated thread. + """ + + def __init__(self, name: str) -> None: + super().__init__(daemon=True) + self.loop: Optional[asyncio.AbstractEventLoop] = None + # Use an event to signal when the loop in the new thread is ready. + self._loop_started = threading.Event() + self._name = name + + def run(self) -> None: + """ + This method is the entry point for the new thread. It creates, sets, + and runs the event loop forever. + """ + self.loop = asyncio.new_event_loop() + self.loop.name = self._name # type: ignore[attr-defined] + print(f"AsyncLoopThread using: {type(self.loop)}") + asyncio.set_event_loop(self.loop) + + # Signal that the loop is set up and ready. + self._loop_started.set() + + # This will run until loop.stop() is called from another thread. + self.loop.run_forever() + + # Cleanly close the loop when it's stopped. + self.loop.close() + + def start(self) -> None: + """ + Starts the thread and blocks until the event loop inside is ready. + """ + super().start() + self._loop_started.wait() + + def stop(self) -> None: + """ + Gracefully stops the event loop and waits for the thread to exit. + """ + if self.loop: + # This is a thread-safe way to schedule loop.stop() to be called. + self.loop.call_soon_threadsafe(self.loop.stop) + self.join() + + async def run_coroutine(self, coro: Coroutine[Any, Any, T]) -> T: + """ + Submits a coroutine to the event loop and returns an awaitable Future. + This method itself MUST be awaited. (For use from async code) + """ + # 1. Submit the coroutine to the background loop thread-safely. + # This returns a concurrent.futures.Future, which is not awaitable. + concurrent_future = self.submit_coroutine(coro) + + # 2. Wrap the concurrent future in an asyncio future, which IS awaitable + # in the current (caller's) event loop. + asyncio_future = asyncio.wrap_future(concurrent_future) + + # 3. Await the asyncio future and return its result. + return await asyncio_future + + def submit_coroutine(self, coro: Coroutine[Any, Any, T]) -> Future[T]: + """ + Submits a coroutine to the event loop from current thread and returns + a future that can be used to get the result. + + Args: + coro: The coroutine to execute. + """ + if not self.loop: + raise RuntimeError("Loop is not running.") + + return asyncio.run_coroutine_threadsafe(coro, self.loop) + + def create_task(self, coro: Coroutine[Any, Any, T]) -> asyncio.Task[T]: + """ + Submits a task to the event loop from current thread and returns + a task that can be handled later + + Args: + coro: The coroutine to execute. + + Returns: + A asyncio.Task object for the result. 
+ """ + if not self.loop: + raise RuntimeError("Loop is not running.") + + return self.loop.create_task(coro) diff --git a/python/aibrix/aibrix/metadata/core/kopf_operator.py b/python/aibrix/aibrix/metadata/core/kopf_operator.py new file mode 100644 index 000000000..85bcae719 --- /dev/null +++ b/python/aibrix/aibrix/metadata/core/kopf_operator.py @@ -0,0 +1,264 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +from typing import Optional + +import kopf + +from aibrix.logger import init_logger + +logger = init_logger(__name__) + + +class KopfOperatorWrapper: + """ + A wrapper class to run kopf operator in a separate thread, managing its lifecycle + and integrating it with FastAPI application startup and shutdown. + + This class follows the same pattern as HTTPXClientWrapper, providing lifecycle + management methods that can be called from FastAPI lifespan hooks. + """ + + def __init__( + self, + namespace: str = "default", + startup_timeout: float = 30.0, + shutdown_timeout: float = 10.0, + standalone: bool = True, + peering_name: Optional[str] = None, + ) -> None: + """ + Initialize the kopf operator wrapper. + + Args: + namespace: Kubernetes namespace to monitor (default: "default") + startup_timeout: Maximum time to wait for operator startup (seconds) + shutdown_timeout: Maximum time to wait for operator shutdown (seconds) + standalone: Whether to run kopf in standalone mode + peering_name: Used for high availability + """ + self.namespace = namespace + self.startup_timeout = startup_timeout + self.shutdown_timeout = shutdown_timeout + self.standalone = standalone + self.peering_name = peering_name + + # Threading coordination + self._operator_thread: Optional[threading.Thread] = None + self._stop_event: Optional[threading.Event] = None + self._ready_event: Optional[threading.Event] = None + self._is_running = False + + # Error tracking + self._startup_error: Optional[Exception] = None + + def start(self) -> None: + """ + Start the kopf operator in a background thread. + + This method is synchronous and will block until the operator is ready + or times out. Call from the FastAPI startup hook. 
+ + Raises: + RuntimeError: If operator fails to start within timeout + Exception: Any startup error from the operator thread + """ + if self._is_running: + logger.warning("Kopf operator is already running") + return + + logger.info( + "Starting kopf operator", + namespace=self.namespace, + timeout=self.startup_timeout, + ) # type: ignore[call-arg] + + # Create threading coordination objects + self._stop_event = threading.Event() + self._ready_event = threading.Event() + self._startup_error = None + + # Start operator thread + self._operator_thread = threading.Thread( + target=self._run_operator, + name="kopf-operator", + daemon=True, + ) + self._operator_thread.start() + + # Wait for operator to be ready or fail + if not self._ready_event.wait(timeout=self.startup_timeout): + # Startup timeout - clean up and raise error + self._stop_event.set() + if self._operator_thread.is_alive(): + self._operator_thread.join(timeout=5.0) + + error_msg = f"Kopf operator did not start within {self.startup_timeout}s" + if self._startup_error: + error_msg = f"{error_msg}. Error: {self._startup_error}" + + logger.error("Kopf operator startup failed", reason="timeout") # type: ignore[call-arg] + raise RuntimeError(error_msg) + + # Check if startup failed with an error + if self._startup_error: + logger.error("Kopf operator startup failed", error=str(self._startup_error)) # type: ignore[call-arg] + raise self._startup_error + + self._is_running = True + logger.info( + "Kopf operator started successfully", + thread_id=self._operator_thread.ident, + namespace=self.namespace, + ) # type: ignore[call-arg] + + def stop(self) -> None: + """ + Gracefully stop the kopf operator. + + This method is async to match the FastAPI shutdown pattern. + Call from the FastAPI shutdown hook. + """ + if not self._is_running: + logger.debug("Kopf operator is not running") + return + + logger.info("Stopping kopf operator") # type: ignore[call-arg] + + # Signal operator to stop + if self._stop_event: + self._stop_event.set() + + # Wait for thread to finish + if self._operator_thread and self._operator_thread.is_alive(): + logger.debug( + "Waiting for kopf operator thread to finish", + timeout=self.shutdown_timeout, + ) # type: ignore[call-arg] + + # Join thread with timeout + self._operator_thread.join(timeout=self.shutdown_timeout) + + if self._operator_thread.is_alive(): + logger.warning( + "Kopf operator thread did not stop within timeout", + timeout=self.shutdown_timeout, + ) # type: ignore[call-arg] + else: + logger.info("Kopf operator thread stopped successfully") # type: ignore[call-arg] + + # Clean up state + self._is_running = False + self._operator_thread = None + self._stop_event = None + self._ready_event = None + self._startup_error = None + + logger.info("Kopf operator stopped") # type: ignore[call-arg] + + def is_running(self) -> bool: + """ + Check if the kopf operator is currently running. + + Returns: + bool: True if operator is running, False otherwise + """ + return self._is_running and ( + self._operator_thread is not None and self._operator_thread.is_alive() + ) + + def _run_operator(self) -> None: + """ + Target function for the operator thread. + + This runs in a separate thread and handles the kopf.run() call + with proper error handling and ready signaling. 
+ """ + try: + logger.debug("Kopf operator thread started") # type: ignore[call-arg] + + # Import handlers module to register kopf handlers before running operator + # This ensures all @kopf.on.* decorated handlers are registered with the default registry + # logger.debug("Importing kopf handlers for registration") # type: ignore[call-arg] + # try: + # from aibrix.metadata.cache import job # noqa: F401 + + # logger.debug("Successfully imported kopf handlers module") # type: ignore[call-arg] + # except ImportError as import_error: + # logger.warning( + # "Failed to import kopf handlers module - handlers may not be available", + # error=str(import_error), + # ) # type: ignore[call-arg] + + # Run kopf operator with our coordination objects + # Only set namespace if specified, otherwise watch cluster-wide + logger.info( + "Starting kopf operator with namespace restriction", + namespace=self.namespace, + ) # type: ignore[call-arg] + kopf.run( + standalone=self.standalone, + namespace=self.namespace, + ready_flag=self._ready_event, + stop_flag=self._stop_event, + # Additional kopf configuration + peering_name=self.peering_name, # Unique peering name + ) + + logger.debug("Kopf operator thread finished normally") # type: ignore[call-arg] + + except Exception as e: + logger.error( + "Kopf operator thread failed", + error=str(e), + error_type=type(e).__name__, + ) # type: ignore[call-arg] + + # Store error for main thread + self._startup_error = e + + # Signal ready even on error so main thread doesn't hang + if self._ready_event: + self._ready_event.set() + + def get_status(self) -> dict: + """ + Get detailed status information about the operator. + + Returns: + dict: Status information including running state, thread info, etc. + """ + status = { + "is_running": self.is_running(), + "namespace": self.namespace, + "standalone": self.standalone, + "startup_timeout": self.startup_timeout, + "shutdown_timeout": self.shutdown_timeout, + } + + if self._operator_thread: + status.update( + { + "thread_name": self._operator_thread.name, + "thread_id": self._operator_thread.ident, + "thread_alive": self._operator_thread.is_alive(), + "thread_daemon": self._operator_thread.daemon, + } + ) + + if self._startup_error: + status["startup_error"] = str(self._startup_error) + + return status diff --git a/python/aibrix/aibrix/metadata/logger.py b/python/aibrix/aibrix/metadata/logger.py deleted file mode 100644 index 10ba1a86f..000000000 --- a/python/aibrix/aibrix/metadata/logger.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2024 The Aibrix Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import sys -from logging import Logger -from logging.handlers import RotatingFileHandler - -import structlog - -from aibrix.metadata.setting import settings - - -def _default_logging_basic_config() -> None: - # 1. 
Configure the standard library logging - handler: logging.Handler = logging.StreamHandler(stream=sys.stdout) - if settings.LOG_PATH is not None: - handler = RotatingFileHandler( - settings.LOG_PATH, - maxBytes=10 * (2**20), - backupCount=10, - ) - logging.basicConfig(format="%(message)s", handlers=[handler]) - - # 2. Configure structlog processors and renderer - structlog.configure( - processors=[ - structlog.stdlib.add_logger_name, - structlog.stdlib.add_log_level, - structlog.processors.TimeStamper(fmt="%Y-%m-%d %H:%M:%S %Z"), - structlog.processors.StackInfoRenderer(), - structlog.processors.format_exc_info, - structlog.processors.JSONRenderer(), # Renders the log event as JSON - ], - logger_factory=structlog.stdlib.LoggerFactory(), - wrapper_class=structlog.stdlib.BoundLogger, - cache_logger_on_first_use=True, - ) - - -def init_logger(name: str) -> Logger: - logger = structlog.get_logger(name) - logger.setLevel(settings.LOG_LEVEL) - return logger - - -_default_logging_basic_config() -logger = init_logger(__name__) diff --git a/python/aibrix/aibrix/metadata/secret_gen.py b/python/aibrix/aibrix/metadata/secret_gen.py new file mode 100644 index 000000000..04d93f448 --- /dev/null +++ b/python/aibrix/aibrix/metadata/secret_gen.py @@ -0,0 +1,408 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Kubernetes secret generator for S3 and TOS credentials. + +This module provides utilities to generate and apply Kubernetes secrets +for S3 and TOS storage backends using their respective credential sources. +""" + +import base64 +import os +from pathlib import Path +from typing import Dict, Optional + +import boto3 +import yaml +from kubernetes import client + +from aibrix.logger import init_logger + +logger = init_logger(__name__) + + +class SecretGenerator: + """Generator for Kubernetes secrets with storage credentials.""" + + def __init__(self, namespace: str = "default"): + """ + Initialize the secret generator. + + Args: + namespace: Kubernetes namespace to create secrets in + """ + self.namespace = namespace + self.core_v1 = client.CoreV1Api() + self.setting_dir = Path(__file__).parent / "setting" + + def _encode_data(self, data: Dict[str, str]) -> Dict[str, str]: + """ + Base64 encode secret data for Kubernetes. + + Args: + data: Dictionary of key-value pairs to encode + + Returns: + Dictionary with base64 encoded values + """ + return { + key: base64.b64encode(value.encode()).decode() + for key, value in data.items() + if value is not None + } + + def _load_template(self, template_name: str) -> Dict: + """ + Load a secret template from the setting directory. + + Args: + template_name: Name of the template file + + Returns: + Parsed YAML template + """ + template_path = self.setting_dir / template_name + with open(template_path, "r") as f: + return yaml.safe_load(f) + + def _get_s3_credentials(self) -> Dict[str, str]: + """ + Get S3 credentials using boto3. 
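The _encode_data helper produces the base64-encoded values Kubernetes expects under Secret.data. A quick, self-contained illustration of that encoding (values are placeholders, not real credentials):

```python
import base64

data = {"access-key-id": "AKIAEXAMPLE", "region": "us-east-1"}
encoded = {k: base64.b64encode(v.encode()).decode() for k, v in data.items()}
print(encoded["region"])  # dXMtZWFzdC0x
```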
+ + Returns: + Dictionary with S3 credentials + + Raises: + RuntimeError: If credentials cannot be obtained + """ + try: + session = boto3.Session() + credentials = session.get_credentials() + + if not credentials: + raise RuntimeError("No AWS credentials found") + + access_key = credentials.access_key + secret_key = credentials.secret_key + region = session.region_name or "us-east-1" + + if not access_key or not secret_key: + raise RuntimeError("AWS credentials incomplete") + + return { + "access_key": access_key, + "secret_key": secret_key, + "region": region, + } + + except Exception as e: + raise RuntimeError(f"Failed to get S3 credentials: {e}") + + def _get_tos_credentials(self) -> Dict[str, str]: + """ + Get TOS credentials from environment variables. + + Returns: + Dictionary with TOS credentials + + Raises: + RuntimeError: If required environment variables are not set + """ + tos_access_key = os.getenv("TOS_ACCESS_KEY") + tos_secret_key = os.getenv("TOS_SECRET_KEY") + tos_endpoint = os.getenv("TOS_ENDPOINT") + tos_region = os.getenv("TOS_REGION") + + if not all([tos_access_key, tos_secret_key, tos_endpoint, tos_region]): + missing = [ + var + for var, val in [ + ("TOS_ACCESS_KEY", tos_access_key), + ("TOS_SECRET_KEY", tos_secret_key), + ("TOS_ENDPOINT", tos_endpoint), + ("TOS_REGION", tos_region), + ] + if not val + ] + raise RuntimeError( + f"Missing TOS environment variables: {', '.join(missing)}" + ) + + # Type assertions after None check + assert tos_access_key is not None + assert tos_secret_key is not None + assert tos_endpoint is not None + assert tos_region is not None + + return { + "access_key": tos_access_key, + "secret_key": tos_secret_key, + "endpoint": tos_endpoint, + "region": tos_region, + } + + def create_s3_secret( + self, bucket_name: Optional[str] = None, secret_name: Optional[str] = None + ) -> str: + """ + Create a Kubernetes secret with S3 credentials. 
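_get_tos_credentials reads all four values from the environment and raises a RuntimeError naming whichever variables are missing. A sketch of the environment a caller would need to provide beforehand (values and endpoint are placeholders):

```python
import os

for var, value in {
    "TOS_ACCESS_KEY": "<access-key>",
    "TOS_SECRET_KEY": "<secret-key>",
    "TOS_ENDPOINT": "https://tos.example.com",
    "TOS_REGION": "<region>",
}.items():
    os.environ.setdefault(var, value)
```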
+ + Args: + bucket_name: S3 bucket name (optional) + secret_name: Custom secret name (optional, uses template default) + + Returns: + Name of the created secret + + Raises: + RuntimeError: If secret creation fails + """ + try: + # Load template + template = self._load_template("s3_secret_template.yaml") + + # Get S3 credentials + credentials = self._get_s3_credentials() + + # Prepare secret data + secret_data = { + "access-key-id": credentials["access_key"], + "secret-access-key": credentials["secret_key"], + "region": credentials["region"], + } + + if bucket_name: + secret_data["bucket-name"] = bucket_name + + # Update template + if secret_name: + template["metadata"]["name"] = secret_name + template["metadata"]["namespace"] = self.namespace + template["data"] = self._encode_data(secret_data) + + # Create Kubernetes secret object + secret = client.V1Secret( + metadata=client.V1ObjectMeta( + name=template["metadata"]["name"], namespace=self.namespace + ), + data=template["data"], + type=template["type"], + ) + + # Apply to cluster + secret_name = template["metadata"]["name"] + assert secret_name is not None, "Secret name must be set in template" + + # Delete existing secret if it exists + try: + self.core_v1.delete_namespaced_secret( + name=secret_name, namespace=self.namespace + ) + logger.info(f"Deleted existing S3 secret: {secret_name}") + except client.ApiException as e: + if e.status != 404: + raise + + # Create the secret + self.core_v1.create_namespaced_secret(namespace=self.namespace, body=secret) + logger.info( + f"Created S3 secret: {secret_name} in namespace: {self.namespace}" + ) + + return secret_name + + except Exception as e: + raise RuntimeError(f"Failed to create S3 secret: {e}") + + def create_tos_secret( + self, bucket_name: Optional[str] = None, secret_name: Optional[str] = None + ) -> str: + """ + Create a Kubernetes secret with TOS credentials. 
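A usage sketch for the S3 path above. SecretGenerator instantiates CoreV1Api directly, so the caller is assumed to have configured cluster access first (e.g. via kubernetes.config.load_kube_config()); the bucket and namespace names are illustrative, and the default secret name comes from s3_secret_template.yaml:

```python
from kubernetes import config

from aibrix.metadata.secret_gen import SecretGenerator

config.load_kube_config()                    # or load_incluster_config() inside a pod
gen = SecretGenerator(namespace="default")
name = gen.create_s3_secret(bucket_name="my-batch-bucket")
print(name)                                  # "aibrix-s3-credentials" unless overridden
```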
+ + Args: + bucket_name: TOS bucket name (optional) + secret_name: Custom secret name (optional, uses template default) + + Returns: + Name of the created secret + + Raises: + RuntimeError: If secret creation fails + """ + try: + # Load template + template = self._load_template("tos_secret_template.yaml") + + # Get TOS credentials + credentials = self._get_tos_credentials() + + # Prepare secret data + secret_data = { + "access-key": credentials["access_key"], + "secret-key": credentials["secret_key"], + "endpoint": credentials["endpoint"], + "region": credentials["region"], + } + + if bucket_name: + secret_data["bucket-name"] = bucket_name + + # Update template + if secret_name: + template["metadata"]["name"] = secret_name + template["metadata"]["namespace"] = self.namespace + template["data"] = self._encode_data(secret_data) + + # Create Kubernetes secret object + secret = client.V1Secret( + metadata=client.V1ObjectMeta( + name=template["metadata"]["name"], namespace=self.namespace + ), + data=template["data"], + type=template["type"], + ) + + # Apply to cluster + secret_name = template["metadata"]["name"] + assert secret_name is not None, "Secret name must be set in template" + + # Delete existing secret if it exists + try: + self.core_v1.delete_namespaced_secret( + name=secret_name, namespace=self.namespace + ) + logger.info(f"Deleted existing TOS secret: {secret_name}") + except client.ApiException as e: + if e.status != 404: + raise + + # Create the secret + self.core_v1.create_namespaced_secret(namespace=self.namespace, body=secret) + logger.info( + f"Created TOS secret: {secret_name} in namespace: {self.namespace}" + ) + + return secret_name + + except Exception as e: + raise RuntimeError(f"Failed to create TOS secret: {e}") + + def delete_secret(self, secret_name: str) -> bool: + """ + Delete a Kubernetes secret. + + Args: + secret_name: Name of the secret to delete + + Returns: + True if deleted successfully, False if not found + + Raises: + RuntimeError: If deletion fails for reasons other than not found + """ + try: + self.core_v1.delete_namespaced_secret( + name=secret_name, namespace=self.namespace + ) + logger.info( + f"Deleted secret: {secret_name} from namespace: {self.namespace}" + ) + return True + + except client.ApiException as e: + if e.status == 404: + logger.warning( + f"Secret {secret_name} not found in namespace {self.namespace}" + ) + return False + else: + raise RuntimeError(f"Failed to delete secret {secret_name}: {e}") + + def secret_exists(self, secret_name: str) -> bool: + """ + Check if a secret exists in the namespace. + + Args: + secret_name: Name of the secret to check + + Returns: + True if secret exists, False otherwise + """ + try: + self.core_v1.read_namespaced_secret( + name=secret_name, namespace=self.namespace + ) + return True + except client.ApiException as e: + if e.status == 404: + return False + else: + raise RuntimeError(f"Failed to check secret {secret_name}: {e}") + + +def create_s3_secret( + namespace: str = "default", + bucket_name: Optional[str] = None, + secret_name: Optional[str] = None, +) -> str: + """ + Convenience function to create an S3 secret. 
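secret_exists and delete_secret make cleanup idempotent; a short sketch, again assuming cluster access has already been configured:

```python
from aibrix.metadata.secret_gen import SecretGenerator

gen = SecretGenerator(namespace="default")
if gen.secret_exists("aibrix-tos-credentials"):
    gen.delete_secret("aibrix-tos-credentials")  # returns False if it vanished meanwhile
```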
+ + Args: + namespace: Kubernetes namespace + bucket_name: S3 bucket name (optional) + secret_name: Custom secret name (optional) + + Returns: + Name of the created secret + """ + generator = SecretGenerator(namespace) + return generator.create_s3_secret(bucket_name, secret_name) + + +def create_tos_secret( + namespace: str = "default", + bucket_name: Optional[str] = None, + secret_name: Optional[str] = None, +) -> str: + """ + Convenience function to create a TOS secret. + + Args: + namespace: Kubernetes namespace + bucket_name: TOS bucket name (optional) + secret_name: Custom secret name (optional) + + Returns: + Name of the created secret + """ + generator = SecretGenerator(namespace) + return generator.create_tos_secret(bucket_name, secret_name) + + +def delete_secret(secret_name: str, namespace: str = "default") -> bool: + """ + Convenience function to delete a secret. + + Args: + secret_name: Name of the secret to delete + namespace: Kubernetes namespace + + Returns: + True if deleted successfully, False if not found + """ + generator = SecretGenerator(namespace) + return generator.delete_secret(secret_name) diff --git a/python/aibrix/aibrix/metadata/setting/config.py b/python/aibrix/aibrix/metadata/setting/config.py index f8aa4f342..218df3c3a 100644 --- a/python/aibrix/aibrix/metadata/setting/config.py +++ b/python/aibrix/aibrix/metadata/setting/config.py @@ -13,47 +13,28 @@ # limitations under the License. from typing import Optional -from pydantic_settings import BaseSettings, SettingsConfigDict - +from aibrix.config import AIBrixSettings from aibrix.storage.types import StorageType -class Settings(BaseSettings): - # Model configuration for Pydantic (v2) - # This tells Pydantic where to look for environment variables - # (e.g., in a .env file if it exists, or actual env vars) - model_config = SettingsConfigDict( - env_file=".env", # Load environment variables from a .env file - env_file_encoding="utf-8", # Encoding for the .env file - extra="ignore", # Ignore extra environment variables not defined here - ) - +class Settings(AIBrixSettings): # --- Application General Settings --- PROJECT_NAME: str = "AIBrix Extension API Server" PROJECT_VERSION: str = "1.0.0" API_V1_STR: str = "/v1" # Base path for version 1 of your API - # --- Security Settings --- - SECRET_KEY: str = ( - "test-secret-key-for-testing" # No default, reserved for later use. - ) - # --- CORS (Cross-Origin Resource Sharing) Settings --- # List of origins that are allowed to make requests to your API # Example: ["http://localhost:3000", "https://your-frontend-domain.com"] # Use ["*"] for development, but specify exact origins in production. 
BACKEND_CORS_ORIGINS: list[str] = ["*"] - # --- Logging Settings --- - LOG_LEVEL: str = "DEBUG" # DEBUG, INFO, WARNING, ERROR, CRITICAL - LOG_PATH: Optional[str] = None # If None, logs to stdout - # --- External Service URLs (if any) --- EXTERNAL_API_URL: Optional[str] = None # Example: URL for an external microservice # --- File API settings --- - STORAGE_TYPE: StorageType = StorageType.LOCAL - METASTORE_TYPE: StorageType = StorageType.LOCAL + STORAGE_TYPE: StorageType = StorageType.AUTO + METASTORE_TYPE: StorageType = StorageType.AUTO MAX_FILE_SIZE: int = 1024 * 1024 * 1024 # 1G in bytes diff --git a/python/aibrix/aibrix/metadata/setting/k8s_job_rbac.yaml b/python/aibrix/aibrix/metadata/setting/k8s_job_rbac.yaml new file mode 100644 index 000000000..b0c26a1ab --- /dev/null +++ b/python/aibrix/aibrix/metadata/setting/k8s_job_rbac.yaml @@ -0,0 +1,30 @@ +apiVersion: v1 +kind: ServiceAccount +metadata: + name: job-reader-sa + namespace: default +--- +# Service Account for job pod +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: job-reader-role + namespace: default +rules: +- apiGroups: ["batch"] + resources: ["jobs"] + verbs: ["get"] # Get permissions only +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: job-reader-binding + namespace: default +subjects: +- kind: ServiceAccount + name: job-reader-sa + namespace: default +roleRef: + kind: Role + name: job-reader-role + apiGroup: rbac.authorization.k8s.io \ No newline at end of file diff --git a/python/aibrix/aibrix/metadata/setting/k8s_job_redis_patch.yaml b/python/aibrix/aibrix/metadata/setting/k8s_job_redis_patch.yaml new file mode 100644 index 000000000..6931e5c3b --- /dev/null +++ b/python/aibrix/aibrix/metadata/setting/k8s_job_redis_patch.yaml @@ -0,0 +1,15 @@ +# Patch for k8s_job_template.yaml to enable S3 testing with Kubernetes secrets +apiVersion: batch/v1 +kind: Job +spec: + template: + spec: + containers: + - name: batch-worker + env: + - name: REDIS_HOST + value: "aibrix-redis-master.aibrix-system.svc.cluster.local" + - name: REDIS_PORT + value: "6379" + - name: REDIS_DB + value: "0" \ No newline at end of file diff --git a/python/aibrix/aibrix/metadata/setting/k8s_job_s3_patch.yaml b/python/aibrix/aibrix/metadata/setting/k8s_job_s3_patch.yaml new file mode 100644 index 000000000..cbabee1ad --- /dev/null +++ b/python/aibrix/aibrix/metadata/setting/k8s_job_s3_patch.yaml @@ -0,0 +1,29 @@ +# Patch for k8s_job_template.yaml to enable S3 testing with Kubernetes secrets +apiVersion: batch/v1 +kind: Job +spec: + template: + spec: + containers: + - name: batch-worker + env: + - name: STORAGE_AWS_ACCESS_KEY_ID + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: access-key-id + - name: STORAGE_AWS_SECRET_ACCESS_KEY + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: secret-access-key + - name: STORAGE_AWS_REGION + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: region + - name: STORAGE_AWS_BUCKET + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: bucket-name \ No newline at end of file diff --git a/python/aibrix/aibrix/metadata/setting/k8s_job_template.yaml b/python/aibrix/aibrix/metadata/setting/k8s_job_template.yaml new file mode 100644 index 000000000..33ff63261 --- /dev/null +++ b/python/aibrix/aibrix/metadata/setting/k8s_job_template.yaml @@ -0,0 +1,89 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: batch-job-template + namespace: default + labels: + app: aibrix-batch + annotations: + # Template annotations will 
be merged with BatchJobSpec annotations +spec: + suspend: true # !!Important: creates the job in a paused state + template: + metadata: + labels: + app: aibrix-batch + spec: + serviceAccountName: job-reader-sa + automountServiceAccountToken: true + shareProcessNamespace: true # Allow worker to kill llm-engine + restartPolicy: Never + containers: + - name: batch-worker + image: aibrix/runtime:nightly + command: + - aibrix_batch_worker + env: + - name: JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.labels['job-name'] + - name: JOB_NAMESPACE + valueFrom: + fieldRef: + fieldPath: metadata.namespace + - name: JOB_UID + valueFrom: + fieldRef: + fieldPath: metadata.labels['batch.kubernetes.io/controller-uid'] + - name: LLM_READY_ENDPOINT + value: "http://localhost:8000/ready" # keep consistent with llm-engine['readinessProbe'] + # Batch job metadata from pod annotations + - name: BATCH_INPUT_FILE_ID + valueFrom: + fieldRef: + fieldPath: metadata.annotations['batch.job.aibrix.ai/input-file-id'] + - name: BATCH_ENDPOINT + valueFrom: + fieldRef: + fieldPath: metadata.annotations['batch.job.aibrix.ai/endpoint'] + - name: BATCH_OUTPUT_FILE_ID + valueFrom: + fieldRef: + fieldPath: metadata.annotations['batch.job.aibrix.ai/output-file-id'] + - name: BATCH_TEMP_OUTPUT_FILE_ID + valueFrom: + fieldRef: + fieldPath: metadata.annotations['batch.job.aibrix.ai/temp-output-file-id'] + - name: BATCH_ERROR_FILE_ID + valueFrom: + fieldRef: + fieldPath: metadata.annotations['batch.job.aibrix.ai/error-file-id'] + - name: BATCH_TEMP_ERROR_FILE_ID + valueFrom: + fieldRef: + fieldPath: metadata.annotations['batch.job.aibrix.ai/temp-error-file-id'] + - name: BATCH_OPTS_FAIL_AFTER_N_REQUESTS + valueFrom: + fieldRef: + fieldPath: metadata.annotations['batch.job.aibrix.ai/opts.fail_after_n_requests'] + - name: llm-engine + image: aibrix/vllm-mock:nightly + ports: + - containerPort: 8000 + command: ["/bin/sh", "-c"] + args: + - | + # Run llm engine. The '|| true' at the end ensures this line never fails. 
+ WORKER_VICTIM=1 python app.py || true + readinessProbe: + failureThreshold: 3 + httpGet: + path: /ready + port: 8000 + scheme: HTTP + periodSeconds: 5 + successThreshold: 1 + timeoutSeconds: 1 + backoffLimit: 2 + activeDeadlineSeconds: 86400 # 24 hours \ No newline at end of file diff --git a/python/aibrix/aibrix/metadata/setting/k8s_job_tos_patch.yaml b/python/aibrix/aibrix/metadata/setting/k8s_job_tos_patch.yaml new file mode 100644 index 000000000..2beea599b --- /dev/null +++ b/python/aibrix/aibrix/metadata/setting/k8s_job_tos_patch.yaml @@ -0,0 +1,34 @@ +# Patch for k8s_job_template.yaml to enable S3 testing with Kubernetes secrets +apiVersion: batch/v1 +kind: Job +spec: + template: + spec: + containers: + - name: batch-worker + env: + - name: STORAGE_TOS_ACCESS_KEY + valueFrom: + secretKeyRef: + name: aibrix-tos-credentials + key: access-key + - name: STORAGE_TOS_SECRET_KEY + valueFrom: + secretKeyRef: + name: aibrix-tos-credentials + key: secret-key + - name: STORAGE_TOS_ENDPOINT + valueFrom: + secretKeyRef: + name: aibrix-tos-credentials + key: endpoint + - name: STORAGE_TOS_REGION + valueFrom: + secretKeyRef: + name: aibrix-tos-credentials + key: region + - name: STORAGE_TOS_BUCKET + valueFrom: + secretKeyRef: + name: aibrix-tos-credentials + key: bucket-name \ No newline at end of file diff --git a/python/aibrix/aibrix/metadata/setting/s3_secret_template.yaml b/python/aibrix/aibrix/metadata/setting/s3_secret_template.yaml new file mode 100644 index 000000000..9e70cc2f6 --- /dev/null +++ b/python/aibrix/aibrix/metadata/setting/s3_secret_template.yaml @@ -0,0 +1,14 @@ +# Kubernetes Secret template for S3 credentials +# This is a template that can be populated using secret_gen utility +apiVersion: v1 +kind: Secret +metadata: + name: aibrix-s3-credentials + namespace: default +type: Opaque +data: + # Base64 encoded values will be populated by the test + access-key-id: "" + secret-access-key: "" + region: "" + bucket-name: "" \ No newline at end of file diff --git a/python/aibrix/aibrix/metadata/setting/tos_secret_template.yaml b/python/aibrix/aibrix/metadata/setting/tos_secret_template.yaml new file mode 100644 index 000000000..cf96c3ab3 --- /dev/null +++ b/python/aibrix/aibrix/metadata/setting/tos_secret_template.yaml @@ -0,0 +1,15 @@ +# Kubernetes Secret template for TOS credentials +# This is a template that can be populated using secret_gen utility +apiVersion: v1 +kind: Secret +metadata: + name: aibrix-tos-credentials + namespace: default +type: Opaque +data: + # Base64 encoded values will be populated by the test + access-key: "" + secret-key: "" + endpoint: "" + region: "" + bucket-name: "" \ No newline at end of file diff --git a/python/aibrix/aibrix/storage/base.py b/python/aibrix/aibrix/storage/base.py index b8af1bd68..5a656f355 100644 --- a/python/aibrix/aibrix/storage/base.py +++ b/python/aibrix/aibrix/storage/base.py @@ -17,9 +17,10 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from io import BytesIO, StringIO -from typing import Any, AsyncIterator, BinaryIO, Optional, TextIO, Union +from typing import AsyncIterator, BinaryIO, Optional, TextIO, Union from aibrix.storage.reader import Reader +from aibrix.storage.types import StorageType from aibrix.storage.utils import ObjectMetadata @@ -40,6 +41,58 @@ class StorageConfig: readline_buffer_size: int = 8192 # 8KB buffer for readline +@dataclass +class PutObjectOptions: + """Options for put_object operations with advanced features.""" + + # TTL support + ttl_seconds: Optional[int] = None + 
ttl_milliseconds: Optional[int] = None + + # Conditional operations + set_if_not_exists: bool = False # Like Redis NX + set_if_exists: bool = False # Like Redis XX + + def __post_init__(self): + """Validate options.""" + if self.set_if_not_exists and self.set_if_exists: + raise ValueError("Cannot specify both set_if_not_exists and set_if_exists") + + if self.ttl_seconds is not None and self.ttl_milliseconds is not None: + raise ValueError("Cannot specify both ttl_seconds and ttl_milliseconds") + + +class PutObjectOptionsBuilder: + """Helper class to construct PutObjectOptions.""" + + def __init__(self): + self._options = PutObjectOptions() + + def ttl_seconds(self, seconds: int) -> "PutObjectOptionsBuilder": + """Set TTL in seconds.""" + self._options.ttl_seconds = seconds + return self + + def ttl_milliseconds(self, milliseconds: int) -> "PutObjectOptionsBuilder": + """Set TTL in milliseconds.""" + self._options.ttl_milliseconds = milliseconds + return self + + def if_not_exists(self) -> "PutObjectOptionsBuilder": + """Only set if key doesn't exist (like Redis NX).""" + self._options.set_if_not_exists = True + return self + + def if_exists(self) -> "PutObjectOptionsBuilder": + """Only set if key exists (like Redis XX).""" + self._options.set_if_exists = True + return self + + def build(self) -> PutObjectOptions: + """Build the options object.""" + return self._options + + class BaseStorage(ABC): """Base class for all storage implementations. @@ -48,19 +101,54 @@ class BaseStorage(ABC): - Multipart uploads for large files - Range gets for partial file reads - Readline functionality backed by range gets + - Advanced put_object options (TTL, conditional operations) """ def __init__(self, config: Optional[StorageConfig] = None): self.config = config or StorageConfig() + @abstractmethod + def get_type(self) -> StorageType: + """Get the type of storage. + + Returns: + Type of storage + """ + pass + + def is_ttl_supported(self) -> bool: + """Check if TTL (Time To Live) is supported. + + Returns: + True if TTL is supported, False otherwise + """ + return False + + def is_set_if_not_exists_supported(self) -> bool: + """Check if conditional SET IF NOT EXISTS is supported. + + Returns: + True if SET IF NOT EXISTS is supported, False otherwise + """ + return False + + def is_set_if_exists_supported(self) -> bool: + """Check if conditional SET IF EXISTS is supported. + + Returns: + True if SET IF EXISTS is supported, False otherwise + """ + return False + @abstractmethod async def put_object( self, key: str, data: Union[bytes, str, BinaryIO, TextIO, Reader], content_type: Optional[str] = None, - metadata: Optional[dict[str, Any]] = None, - ) -> None: + metadata: Optional[dict[str, str]] = None, + options: Optional[PutObjectOptions] = None, + ) -> bool: """Put an object to storage. Args: @@ -68,6 +156,13 @@ async def put_object( data: Data to write (bytes, string, or file-like object) content_type: MIME type of the content metadata: Additional metadata to store with object + options: Advanced options for put operation + + Returns: + True if object was stored, False if conditional operation failed + + Raises: + ValueError: If unsupported options are specified """ pass @@ -240,6 +335,34 @@ async def _native_abort_multipart_upload( """ pass + def _validate_put_options(self, options: Optional[PutObjectOptions]) -> None: + """Validate put_object options against backend capabilities. 
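One way the new options surface to callers is the builder plus the capability probes, so code can degrade gracefully on backends without TTL support. A hedged sketch (the storage object is any BaseStorage implementation; key and payload are illustrative):

```python
from aibrix.storage.base import BaseStorage, PutObjectOptionsBuilder


async def put_with_ttl(storage: BaseStorage, key: str, payload: bytes) -> bool:
    if storage.is_ttl_supported():
        opts = PutObjectOptionsBuilder().ttl_seconds(300).build()
        return await storage.put_object(key, payload, options=opts)
    # Backends without TTL (local, S3, TOS) would raise ValueError on the options,
    # so fall back to a plain put instead.
    return await storage.put_object(key, payload)
```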
+ + Args: + options: Options to validate + + Raises: + ValueError: If unsupported options are specified + """ + if options is None: + return + + if options.ttl_seconds is not None or options.ttl_milliseconds is not None: + if not self.is_ttl_supported(): + raise ValueError( + f"TTL not supported by {self.get_type().value} storage" + ) + + if options.set_if_not_exists and not self.is_set_if_not_exists_supported(): + raise ValueError( + f"SET IF NOT EXISTS not supported by {self.get_type().value} storage" + ) + + if options.set_if_exists and not self.is_set_if_exists_supported(): + raise ValueError( + f"SET IF EXISTS not supported by {self.get_type().value} storage" + ) + async def multipart_upload( self, key: str, diff --git a/python/aibrix/aibrix/storage/factory.py b/python/aibrix/aibrix/storage/factory.py index 41f2b7e79..e2e1ff8c0 100644 --- a/python/aibrix/aibrix/storage/factory.py +++ b/python/aibrix/aibrix/storage/factory.py @@ -101,8 +101,8 @@ def create_storage( ) elif storage_type == StorageType.REDIS: - host = kwargs.get("host", "localhost") or envs.STORAGE_REDIS_HOST or "localhost" - port = kwargs.get("port", 6379) or envs.STORAGE_REDIS_PORT + host = kwargs.get("host") or envs.STORAGE_REDIS_HOST or "localhost" + port = kwargs.get("port") or envs.STORAGE_REDIS_PORT or 6379 db = kwargs.get("db", 0) or envs.STORAGE_REDIS_DB password = kwargs.get("password") or envs.STORAGE_REDIS_PASSWORD diff --git a/python/aibrix/aibrix/storage/local.py b/python/aibrix/aibrix/storage/local.py index e8770b5d3..9f33c882d 100644 --- a/python/aibrix/aibrix/storage/local.py +++ b/python/aibrix/aibrix/storage/local.py @@ -20,7 +20,12 @@ from pathlib import Path from typing import AsyncIterator, BinaryIO, Optional, TextIO, Union -from aibrix.storage.base import BaseStorage, StorageConfig +from aibrix.storage.base import ( + BaseStorage, + PutObjectOptions, + StorageConfig, + StorageType, +) from aibrix.storage.reader import Reader from aibrix.storage.utils import ObjectMetadata, _sanitize_key, generate_filename @@ -82,14 +87,26 @@ def _infer_content_type(self, key: str) -> Optional[str]: } return content_type_map.get(suffix) + def get_type(self) -> StorageType: + """Get the type of storage. 
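The factory change is subtle but matters: with a default passed to kwargs.get(), the environment fallback was unreachable. A two-line illustration of the difference:

```python
kwargs: dict = {}
old = kwargs.get("host", "localhost") or "host-from-env"  # always "localhost"; env ignored
new = kwargs.get("host") or "host-from-env"               # env value now wins
print(old, new)
```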
+ + Returns: + Type of storage, set to StorageType.LOCAL + """ + return StorageType.LOCAL + async def put_object( self, key: str, data: Union[bytes, str, BinaryIO, TextIO, Reader], content_type: Optional[str] = None, metadata: Optional[dict[str, str]] = None, - ) -> None: + options: Optional[PutObjectOptions] = None, + ) -> bool: """Put an object to local filesystem.""" + # Validate options (local storage doesn't support advanced options) + self._validate_put_options(options) + # Infer content type from file extension if not provided if content_type is None: content_type = self._infer_content_type(key) @@ -131,6 +148,8 @@ def _get_file_metadata(): file_last_modified, ) + return True # Local storage always succeeds + def _write_file(self, path: Path, reader: Reader) -> None: """Write data to file (synchronous helper).""" if reader.is_binary(): diff --git a/python/aibrix/aibrix/storage/redis.py b/python/aibrix/aibrix/storage/redis.py index 038ecdabe..d0bbebe07 100644 --- a/python/aibrix/aibrix/storage/redis.py +++ b/python/aibrix/aibrix/storage/redis.py @@ -17,7 +17,12 @@ import redis.asyncio as redis -from aibrix.storage.base import BaseStorage, StorageConfig +from aibrix.storage.base import ( + BaseStorage, + PutObjectOptions, + StorageConfig, + StorageType, +) from aibrix.storage.reader import Reader from aibrix.storage.utils import ObjectMetadata @@ -55,6 +60,14 @@ def __init__( self.password = password self._redis: Optional[redis.Redis] = None + def get_type(self) -> StorageType: + """Get the type of storage. + + Returns: + Type of storage, set to StorageType.REDIS + """ + return StorageType.REDIS + async def _get_redis(self) -> redis.Redis: """Get Redis connection, creating it if necessary.""" if self._redis is None: @@ -94,8 +107,9 @@ async def put_object( data: Union[bytes, str, BinaryIO, TextIO, Reader], content_type: Optional[str] = None, # Ignored for Redis metadata: Optional[dict[str, str]] = None, # Ignored for Redis - ) -> None: - """Put an object to Redis storage. + options: Optional[PutObjectOptions] = None, + ) -> bool: + """Put an object to Redis storage with advanced options. If key contains "/", creates a Redis list for the parent part. 
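The conditional options map onto Redis SET NX/XX here, and a failed condition surfaces as a False return rather than an exception. A sketch of using that for a simple claim/lease (key layout and TTL are illustrative; `storage` is any Redis-backed BaseStorage):

```python
from aibrix.storage.base import BaseStorage, PutObjectOptionsBuilder


async def try_claim(storage: BaseStorage, job_id: str, owner: str) -> bool:
    opts = PutObjectOptionsBuilder().if_not_exists().ttl_seconds(600).build()
    # False means another owner already holds the key (NX condition not met).
    return await storage.put_object(f"claims/{job_id}", owner, options=opts)
```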
For example, "xxx/yyy" will create a list "xxx" and add "yyy" to it, @@ -106,7 +120,14 @@ async def put_object( data: Data to store content_type: Ignored for Redis metadata: Ignored for Redis + options: Advanced options for put operation + + Returns: + True if object was stored, False if conditional operation failed """ + # Validate options + self._validate_put_options(options) + redis_client = await self._get_redis() # Convert data to bytes @@ -122,10 +143,32 @@ async def put_object( # Parse hierarchical key parent_key, item_key = self._parse_hierarchical_key(key) - # Store the actual data - await redis_client.set(key, data_bytes) + # Prepare Redis SET options from PutObjectOptions + redis_ex = None + redis_px = None + redis_nx = False + redis_xx = False + + if options: + if options.ttl_seconds is not None: + redis_ex = options.ttl_seconds + elif options.ttl_milliseconds is not None: + redis_px = options.ttl_milliseconds + + redis_nx = options.set_if_not_exists + redis_xx = options.set_if_exists - # Store creation timestamp for ordering + # Store the actual data with options + result = await redis_client.set( + key, data_bytes, ex=redis_ex, px=redis_px, nx=redis_nx, xx=redis_xx + ) + + # Check if the SET operation succeeded + if result is None: + # Conditional SET failed (NX or XX condition not met) + return False + + # Store creation timestamp for ordering (only if SET succeeded) timestamp = time.time() await redis_client.zadd("timestamps:all", {key: timestamp}) @@ -136,6 +179,8 @@ async def put_object( # Also track timestamp for hierarchical objects await redis_client.zadd(f"timestamps:{parent_key}", {item_key: timestamp}) + return True + async def get_object( self, key: str, @@ -451,3 +496,28 @@ async def _native_abort_multipart_upload( NotImplementedError: Multipart upload not needed for Redis """ raise NotImplementedError("Multipart upload not needed for Redis storage") + + # Feature Support Methods + def is_ttl_supported(self) -> bool: + """Check if TTL (Time To Live) is supported. + + Returns: + True - Redis supports TTL + """ + return True + + def is_set_if_not_exists_supported(self) -> bool: + """Check if conditional SET IF NOT EXISTS is supported. + + Returns: + True - Redis supports NX option + """ + return True + + def is_set_if_exists_supported(self) -> bool: + """Check if conditional SET IF EXISTS is supported. + + Returns: + True - Redis supports XX option + """ + return True diff --git a/python/aibrix/aibrix/storage/s3.py b/python/aibrix/aibrix/storage/s3.py index 6c9365cc2..028cf6aa1 100644 --- a/python/aibrix/aibrix/storage/s3.py +++ b/python/aibrix/aibrix/storage/s3.py @@ -20,7 +20,12 @@ from botocore.config import Config from botocore.exceptions import ClientError -from aibrix.storage.base import BaseStorage, StorageConfig +from aibrix.storage.base import ( + BaseStorage, + PutObjectOptions, + StorageConfig, + StorageType, +) from aibrix.storage.reader import Reader from aibrix.storage.utils import ObjectMetadata @@ -64,14 +69,26 @@ def __init__( except ClientError as e: raise ValueError(f"Bucket {bucket_name} not accessible: {e}") + def get_type(self) -> StorageType: + """Get the type of storage. 
+ + Returns: + Type of storage, set to StorageType.S3 + """ + return StorageType.S3 + async def put_object( self, key: str, data: Union[bytes, str, BinaryIO, TextIO, Reader], content_type: Optional[str] = None, metadata: Optional[dict[str, str]] = None, - ) -> None: + options: Optional[PutObjectOptions] = None, + ) -> bool: """Put an object to S3.""" + # Validate options (S3 doesn't support advanced options) + self._validate_put_options(options) + # Unify all data types using Reader wrapper reader = self._wrap_s3_data(data) @@ -83,7 +100,7 @@ async def put_object( size = len(reader) if size >= self.config.multipart_threshold: await self.multipart_upload(key, reader, content_type, metadata) - return + return True except (OSError, IOError, ValueError): # Can't determine size, give up multipart upload pass @@ -111,6 +128,8 @@ async def put_object( if isinstance(reader, Reader) and not isinstance(data, Reader): reader.close() + return True # S3 storage always succeeds + async def get_object( self, key: str, diff --git a/python/aibrix/aibrix/storage/tos.py b/python/aibrix/aibrix/storage/tos.py index 39179d5b1..c80dcba0e 100644 --- a/python/aibrix/aibrix/storage/tos.py +++ b/python/aibrix/aibrix/storage/tos.py @@ -19,7 +19,12 @@ import tos from tos.exceptions import TosClientError, TosServerError -from aibrix.storage.base import BaseStorage, StorageConfig +from aibrix.storage.base import ( + BaseStorage, + PutObjectOptions, + StorageConfig, + StorageType, +) from aibrix.storage.reader import Reader from aibrix.storage.utils import ObjectMetadata @@ -52,14 +57,26 @@ def __init__( except (TosClientError, TosServerError) as e: raise ValueError(f"Failed to create TOS client: {e}") + def get_type(self) -> StorageType: + """Get the type of storage. + + Returns: + Type of storage, set to StorageType.TOS + """ + return StorageType.TOS + async def put_object( self, key: str, data: Union[bytes, str, BinaryIO, TextIO, Reader], content_type: Optional[str] = None, metadata: Optional[dict[str, str]] = None, - ) -> None: + options: Optional[PutObjectOptions] = None, + ) -> bool: """Put an object to TOS.""" + # Validate options (TOS doesn't support advanced options) + self._validate_put_options(options) + # Unify all data types using Reader wrapper reader = self._wrap_data(data) @@ -68,7 +85,7 @@ async def put_object( size = reader.get_size() if size >= self.config.multipart_threshold: await self.multipart_upload(key, reader, content_type, metadata) - return + return True except (OSError, IOError, ValueError): # Can't determine size, give up multipart upload pass @@ -102,6 +119,8 @@ def _put_object(): if not isinstance(data, Reader): reader.close() + return True # TOS storage always succeeds + async def get_object( self, key: str, diff --git a/python/aibrix/aibrix/storage/utils.py b/python/aibrix/aibrix/storage/utils.py index 91f1c06fa..4897042de 100644 --- a/python/aibrix/aibrix/storage/utils.py +++ b/python/aibrix/aibrix/storage/utils.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
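Both object-store backends keep the same multipart behaviour: when the payload size is known and at or above the configured threshold, put_object hands off to multipart_upload(). A small sketch of that decision, assuming multipart_threshold is a StorageConfig field (it is referenced that way in the code above, but its definition sits outside this hunk):

```python
from aibrix.storage.base import StorageConfig

config = StorageConfig(multipart_threshold=8 * 1024 * 1024)  # 8 MiB, illustrative value


def chooses_multipart(size: int) -> bool:
    # Mirrors the size check in the S3/TOS put_object implementations above.
    return size >= config.multipart_threshold


print(chooses_multipart(1024), chooses_multipart(32 * 1024 * 1024))  # False True
```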
-import asyncio import os -import threading from dataclasses import dataclass from datetime import datetime from typing import Optional +from aibrix.metadata.core import AsyncLoopThread + extension_map = { "image/jpeg": ".jpg", "application/x-tar": ".tar", @@ -55,51 +55,13 @@ class ObjectMetadata: expires: Optional[datetime] = None -class AsyncLoopThread(threading.Thread): - def __init__(self): - super().__init__(daemon=True) - self.loop = None - self.started_event = threading.Event() - - def run(self): - """The entry point for the new thread.""" - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - # Signal that the loop is created and running - self.started_event.set() - - # This will run until loop.stop() is called - self.loop.run_forever() - - # Cleanly close the loop - self.loop.close() - - def start(self): - """Starts the thread and waits for the loop to be ready.""" - super().start() - # Block until the 'run' method has set up the loop - self.started_event.wait() - - def stop(self): - """Stops the event loop and waits for the thread to exit.""" - if self.loop: - # Schedule loop.stop() to be called from within the loop's thread - self.loop.call_soon_threadsafe(self.loop.stop) - self.join() - - def submit_coroutine(self, coro): - """Submits a coroutine to the event loop from any thread.""" - return asyncio.run_coroutine_threadsafe(coro, self.loop) - - storage_loop_thread: Optional[AsyncLoopThread] = None -def init_storage_loop_thread(): +def init_storage_loop_thread(name: str): global storage_loop_thread if storage_loop_thread is None: - storage_loop_thread = AsyncLoopThread() + storage_loop_thread = AsyncLoopThread(name) storage_loop_thread.start() diff --git a/python/aibrix/poetry.lock b/python/aibrix/poetry.lock index 1e7eaf81f..73a7983e6 100644 --- a/python/aibrix/poetry.lock +++ b/python/aibrix/poetry.lock @@ -144,7 +144,7 @@ version = "4.8.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a"}, {file = "anyio-4.8.0.tar.gz", hash = "sha256:1d9fe889df5212298c0c0723fa20479d1b94883a2df44bd3897aa91083316f7a"}, @@ -180,7 +180,7 @@ version = "24.3.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308"}, {file = "attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff"}, @@ -194,6 +194,21 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\""] +[[package]] +name = "automat" +version = "25.4.16" +description = "Self-service finite-state machines for the programmer on the go." 
+optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "automat-25.4.16-py3-none-any.whl", hash = "sha256:04e9bce696a8d5671ee698005af6e5a9fa15354140a87f4870744604dcdd3ba1"}, + {file = "automat-25.4.16.tar.gz", hash = "sha256:0017591a5477066e90d26b0e696ddc143baafd87b588cfac8100bc6be9634de0"}, +] + +[package.extras] +visualize = ["Twisted (>=16.1.1)", "graphviz (>0.5.1)"] + [[package]] name = "backports-asyncio-runner" version = "1.2.0" @@ -289,8 +304,7 @@ version = "1.17.1" description = "Foreign Function Interface for Python calling C code." optional = false python-versions = ">=3.8" -groups = ["main"] -markers = "platform_python_implementation != \"PyPy\"" +groups = ["main", "dev"] files = [ {file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"}, {file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"}, @@ -360,6 +374,7 @@ files = [ {file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"}, {file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"}, ] +markers = {main = "platform_python_implementation != \"PyPy\"", dev = "os_name == \"nt\" and implementation_name != \"pypy\""} [package.dependencies] pycparser = "*" @@ -494,6 +509,18 @@ files = [ ] markers = {main = "platform_system == \"Windows\"", dev = "sys_platform == \"win32\""} +[[package]] +name = "constantly" +version = "23.10.4" +description = "Symbolic constants in Python" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "constantly-23.10.4-py3-none-any.whl", hash = "sha256:3fd9b4d1c3dc1ec9757f3c52aef7e53ad9323dbe39f51dfd4c43853b68dfa3f9"}, + {file = "constantly-23.10.4.tar.gz", hash = "sha256:aa92b70a33e2ac0bb33cd745eb61776594dc48764b06c35e0efd050b7f1c7cbd"}, +] + [[package]] name = "contourpy" version = "1.3.1" @@ -715,6 +742,18 @@ files = [ {file = "dash_table-5.0.0.tar.gz", hash = "sha256:18624d693d4c8ef2ddec99a6f167593437a7ea0bf153aa20f318c170c5bc7308"}, ] +[[package]] +name = "decorator" +version = "5.2.1" +description = "Decorators for Humans" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a"}, + {file = "decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360"}, +] + [[package]] name = "deprecated" version = "1.2.15" @@ -1063,6 +1102,74 @@ pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] reauth = ["pyu2f (>=0.1.5)"] requests = ["requests (>=2.20.0,<3.0.0.dev0)"] +[[package]] +name = "greenlet" +version = "3.2.3" +description = "Lightweight in-process concurrent programming" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "greenlet-3.2.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:1afd685acd5597349ee6d7a88a8bec83ce13c106ac78c196ee9dde7c04fe87be"}, + {file = "greenlet-3.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:761917cac215c61e9dc7324b2606107b3b292a8349bdebb31503ab4de3f559ac"}, + {file = "greenlet-3.2.3-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:a433dbc54e4a37e4fff90ef34f25a8c00aed99b06856f0119dcf09fbafa16392"}, + {file = 
"greenlet-3.2.3-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:72e77ed69312bab0434d7292316d5afd6896192ac4327d44f3d613ecb85b037c"}, + {file = "greenlet-3.2.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:68671180e3849b963649254a882cd544a3c75bfcd2c527346ad8bb53494444db"}, + {file = "greenlet-3.2.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:49c8cfb18fb419b3d08e011228ef8a25882397f3a859b9fe1436946140b6756b"}, + {file = "greenlet-3.2.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:efc6dc8a792243c31f2f5674b670b3a95d46fa1c6a912b8e310d6f542e7b0712"}, + {file = "greenlet-3.2.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:731e154aba8e757aedd0781d4b240f1225b075b4409f1bb83b05ff410582cf00"}, + {file = "greenlet-3.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:96c20252c2f792defe9a115d3287e14811036d51e78b3aaddbee23b69b216302"}, + {file = "greenlet-3.2.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:784ae58bba89fa1fa5733d170d42486580cab9decda3484779f4759345b29822"}, + {file = "greenlet-3.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0921ac4ea42a5315d3446120ad48f90c3a6b9bb93dd9b3cf4e4d84a66e42de83"}, + {file = "greenlet-3.2.3-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:d2971d93bb99e05f8c2c0c2f4aa9484a18d98c4c3bd3c62b65b7e6ae33dfcfaf"}, + {file = "greenlet-3.2.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:c667c0bf9d406b77a15c924ef3285e1e05250948001220368e039b6aa5b5034b"}, + {file = "greenlet-3.2.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:592c12fb1165be74592f5de0d70f82bc5ba552ac44800d632214b76089945147"}, + {file = "greenlet-3.2.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:29e184536ba333003540790ba29829ac14bb645514fbd7e32af331e8202a62a5"}, + {file = "greenlet-3.2.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:93c0bb79844a367782ec4f429d07589417052e621aa39a5ac1fb99c5aa308edc"}, + {file = "greenlet-3.2.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:751261fc5ad7b6705f5f76726567375bb2104a059454e0226e1eef6c756748ba"}, + {file = "greenlet-3.2.3-cp311-cp311-win_amd64.whl", hash = "sha256:83a8761c75312361aa2b5b903b79da97f13f556164a7dd2d5448655425bd4c34"}, + {file = "greenlet-3.2.3-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:25ad29caed5783d4bd7a85c9251c651696164622494c00802a139c00d639242d"}, + {file = "greenlet-3.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:88cd97bf37fe24a6710ec6a3a7799f3f81d9cd33317dcf565ff9950c83f55e0b"}, + {file = "greenlet-3.2.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:baeedccca94880d2f5666b4fa16fc20ef50ba1ee353ee2d7092b383a243b0b0d"}, + {file = "greenlet-3.2.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:be52af4b6292baecfa0f397f3edb3c6092ce071b499dd6fe292c9ac9f2c8f264"}, + {file = "greenlet-3.2.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0cc73378150b8b78b0c9fe2ce56e166695e67478550769536a6742dca3651688"}, + {file = "greenlet-3.2.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:706d016a03e78df129f68c4c9b4c4f963f7d73534e48a24f5f5a7101ed13dbbb"}, + {file = "greenlet-3.2.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:419e60f80709510c343c57b4bb5a339d8767bf9aef9b8ce43f4f143240f88b7c"}, + {file = 
"greenlet-3.2.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:93d48533fade144203816783373f27a97e4193177ebaaf0fc396db19e5d61163"}, + {file = "greenlet-3.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:7454d37c740bb27bdeddfc3f358f26956a07d5220818ceb467a483197d84f849"}, + {file = "greenlet-3.2.3-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:500b8689aa9dd1ab26872a34084503aeddefcb438e2e7317b89b11eaea1901ad"}, + {file = "greenlet-3.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a07d3472c2a93117af3b0136f246b2833fdc0b542d4a9799ae5f41c28323faef"}, + {file = "greenlet-3.2.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:8704b3768d2f51150626962f4b9a9e4a17d2e37c8a8d9867bbd9fa4eb938d3b3"}, + {file = "greenlet-3.2.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:5035d77a27b7c62db6cf41cf786cfe2242644a7a337a0e155c80960598baab95"}, + {file = "greenlet-3.2.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2d8aa5423cd4a396792f6d4580f88bdc6efcb9205891c9d40d20f6e670992efb"}, + {file = "greenlet-3.2.3-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2c724620a101f8170065d7dded3f962a2aea7a7dae133a009cada42847e04a7b"}, + {file = "greenlet-3.2.3-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:873abe55f134c48e1f2a6f53f7d1419192a3d1a4e873bace00499a4e45ea6af0"}, + {file = "greenlet-3.2.3-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:024571bbce5f2c1cfff08bf3fbaa43bbc7444f580ae13b0099e95d0e6e67ed36"}, + {file = "greenlet-3.2.3-cp313-cp313-win_amd64.whl", hash = "sha256:5195fb1e75e592dd04ce79881c8a22becdfa3e6f500e7feb059b1e6fdd54d3e3"}, + {file = "greenlet-3.2.3-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:3d04332dddb10b4a211b68111dabaee2e1a073663d117dc10247b5b1642bac86"}, + {file = "greenlet-3.2.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8186162dffde068a465deab08fc72c767196895c39db26ab1c17c0b77a6d8b97"}, + {file = "greenlet-3.2.3-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f4bfbaa6096b1b7a200024784217defedf46a07c2eee1a498e94a1b5f8ec5728"}, + {file = "greenlet-3.2.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:ed6cfa9200484d234d8394c70f5492f144b20d4533f69262d530a1a082f6ee9a"}, + {file = "greenlet-3.2.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:02b0df6f63cd15012bed5401b47829cfd2e97052dc89da3cfaf2c779124eb892"}, + {file = "greenlet-3.2.3-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:86c2d68e87107c1792e2e8d5399acec2487a4e993ab76c792408e59394d52141"}, + {file = "greenlet-3.2.3-cp314-cp314-win_amd64.whl", hash = "sha256:8c47aae8fbbfcf82cc13327ae802ba13c9c36753b67e760023fd116bc124a62a"}, + {file = "greenlet-3.2.3-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:42efc522c0bd75ffa11a71e09cd8a399d83fafe36db250a87cf1dacfaa15dc64"}, + {file = "greenlet-3.2.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d760f9bdfe79bff803bad32b4d8ffb2c1d2ce906313fc10a83976ffb73d64ca7"}, + {file = "greenlet-3.2.3-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:8324319cbd7b35b97990090808fdc99c27fe5338f87db50514959f8059999805"}, + {file = "greenlet-3.2.3-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:8c37ef5b3787567d322331d5250e44e42b58c8c713859b8a04c6065f27efbf72"}, + {file = 
"greenlet-3.2.3-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ce539fb52fb774d0802175d37fcff5c723e2c7d249c65916257f0a940cee8904"}, + {file = "greenlet-3.2.3-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:003c930e0e074db83559edc8705f3a2d066d4aa8c2f198aff1e454946efd0f26"}, + {file = "greenlet-3.2.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7e70ea4384b81ef9e84192e8a77fb87573138aa5d4feee541d8014e452b434da"}, + {file = "greenlet-3.2.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:22eb5ba839c4b2156f18f76768233fe44b23a31decd9cc0d4cc8141c211fd1b4"}, + {file = "greenlet-3.2.3-cp39-cp39-win32.whl", hash = "sha256:4532f0d25df67f896d137431b13f4cdce89f7e3d4a96387a41290910df4d3a57"}, + {file = "greenlet-3.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:aaa7aae1e7f75eaa3ae400ad98f8644bb81e1dc6ba47ce8a93d3f17274e08322"}, + {file = "greenlet-3.2.3.tar.gz", hash = "sha256:8b0dd8ae4c0d6f5e54ee55ba935eeb3d735a9b58a8a1e5b5cbab64e01a39f365"}, +] + +[package.extras] +docs = ["Sphinx", "furo"] +test = ["objgraph", "psutil"] + [[package]] name = "gunicorn" version = "23.0.0" @@ -1215,13 +1322,28 @@ testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gr torch = ["safetensors[torch]", "torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] +[[package]] +name = "hyperlink" +version = "21.0.0" +description = "A featureful, immutable, and correct URL for Python." +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] +files = [ + {file = "hyperlink-21.0.0-py2.py3-none-any.whl", hash = "sha256:e6b14c37ecb73e89c77d78cdb4c2cc8f3fb59a885c5b3f819ff4ed80f25af1b4"}, + {file = "hyperlink-21.0.0.tar.gz", hash = "sha256:427af957daa58bc909471c6c40f74c5450fa123dd093fc53efd2e91d2705a56b"}, +] + +[package.dependencies] +idna = ">=2.5" + [[package]] name = "idna" version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -1273,6 +1395,25 @@ scikit-learn = ">=1.0,<2.0" sortedcontainers = ">=2.4.0,<3.0.0" xxhash = ">=2.0.0,<3.0.0" +[[package]] +name = "incremental" +version = "24.7.2" +description = "A small library that versions your Python projects." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "incremental-24.7.2-py3-none-any.whl", hash = "sha256:8cb2c3431530bec48ad70513931a760f446ad6c25e8333ca5d95e24b0ed7b8fe"}, + {file = "incremental-24.7.2.tar.gz", hash = "sha256:fb4f1d47ee60efe87d4f6f0ebb5f70b9760db2b2574c59c8e8912be4ebd464c9"}, +] + +[package.dependencies] +setuptools = ">=61.0" +tomli = {version = "*", markers = "python_version < \"3.11\""} + +[package.extras] +scripts = ["click (>=6.0)"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -1888,6 +2029,21 @@ rsa = ["cryptography (>=3.0.0)"] signals = ["blinker (>=1.4.0)"] signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] +[[package]] +name = "outcome" +version = "1.3.0.post0" +description = "Capture the outcome of Python function calls." 
+optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "outcome-1.3.0.post0-py2.py3-none-any.whl", hash = "sha256:e771c5ce06d1415e356078d3bdd68523f284b4ce5419828922b6871e65eda82b"}, + {file = "outcome-1.3.0.post0.tar.gz", hash = "sha256:9dcf02e65f2971b80047b377468e72a268e15c0af3cf1238e6ff14f7f91143b8"}, +] + +[package.dependencies] +attrs = ">=19.2.0" + [[package]] name = "packaging" version = "24.2" @@ -2260,12 +2416,12 @@ version = "2.22" description = "C parser in Python" optional = false python-versions = ">=3.8" -groups = ["main"] -markers = "platform_python_implementation != \"PyPy\"" +groups = ["main", "dev"] files = [ {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, ] +markers = {main = "platform_python_implementation != \"PyPy\"", dev = "os_name == \"nt\" and implementation_name != \"pypy\""} [[package]] name = "pydantic" @@ -2483,6 +2639,61 @@ pytest = ">=8.2,<9" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] +[[package]] +name = "pytest-tornasync" +version = "0.6.0.post2" +description = "py.test plugin for testing Python 3.5+ Tornado code" +optional = false +python-versions = "*" +groups = ["dev"] +files = [ + {file = "pytest-tornasync-0.6.0.post2.tar.gz", hash = "sha256:d781b6d951a2e7c08843141d3ff583610b4ea86bfa847714c76edefb576bbe5d"}, + {file = "pytest_tornasync-0.6.0.post2-py3-none-any.whl", hash = "sha256:4b165b6ba76b5b228933598f456b71ba233f127991a52889788db0a950ad04ba"}, +] + +[package.dependencies] +pytest = ">=3.0" +tornado = ">=5.0" + +[[package]] +name = "pytest-trio" +version = "0.8.0" +description = "Pytest plugin for trio" +optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "pytest-trio-0.8.0.tar.gz", hash = "sha256:8363db6336a79e6c53375a2123a41ddbeccc4aa93f93788651641789a56fb52e"}, + {file = "pytest_trio-0.8.0-py3-none-any.whl", hash = "sha256:e6a7e7351ae3e8ec3f4564d30ee77d1ec66e1df611226e5618dbb32f9545c841"}, +] + +[package.dependencies] +outcome = ">=1.1.0" +pytest = ">=7.2.0" +trio = ">=0.22.0" + +[[package]] +name = "pytest-twisted" +version = "1.14.3" +description = "A twisted plugin for pytest." 
+optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +groups = ["dev"] +files = [ + {file = "pytest_twisted-1.14.3-py2.py3-none-any.whl", hash = "sha256:f2e3f3f6f12f78df17c028fe16d87af09c76b95a7a85bc378b2d3e73a086e81a"}, + {file = "pytest_twisted-1.14.3.tar.gz", hash = "sha256:37e150cbbc0edba6592d36c53f44fc1196f3a9e93e7bef6a25bb10d9963f7f3e"}, +] + +[package.dependencies] +decorator = "*" +greenlet = "*" +pytest = ">=2.3" + +[package.extras] +dev = ["black", "pre-commit"] +pyqt5 = ["qt5reactor[pyqt5] (>=0.6.2)"] +pyside2 = ["qt5reactor[pyside2] (>=0.6.3)"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -3012,7 +3223,7 @@ version = "75.8.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "setuptools-75.8.0-py3-none-any.whl", hash = "sha256:e3982f444617239225d675215d51f6ba05f845d4eec313da4418fdbb56fb27e3"}, {file = "setuptools-75.8.0.tar.gz", hash = "sha256:c5afc8f407c626b8313a86e10311dd3f661c6cd9c09d4bf8c15c0e11f9f2b0e6"}, @@ -3045,7 +3256,7 @@ version = "1.3.1" description = "Sniff out which async library your code is running under" optional = false python-versions = ">=3.7" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, @@ -3057,7 +3268,7 @@ version = "2.4.0" description = "Sorted Containers -- Sorted List, Sorted Dict, Sorted Set" optional = false python-versions = "*" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0"}, {file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"}, @@ -3256,6 +3467,28 @@ files = [ {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, ] +[[package]] +name = "tornado" +version = "6.5.1" +description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." 
+optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "tornado-6.5.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d50065ba7fd11d3bd41bcad0825227cc9a95154bad83239357094c36708001f7"}, + {file = "tornado-6.5.1-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:9e9ca370f717997cb85606d074b0e5b247282cf5e2e1611568b8821afe0342d6"}, + {file = "tornado-6.5.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b77e9dfa7ed69754a54c89d82ef746398be82f749df69c4d3abe75c4d1ff4888"}, + {file = "tornado-6.5.1-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:253b76040ee3bab8bcf7ba9feb136436a3787208717a1fb9f2c16b744fba7331"}, + {file = "tornado-6.5.1-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:308473f4cc5a76227157cdf904de33ac268af770b2c5f05ca6c1161d82fdd95e"}, + {file = "tornado-6.5.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:caec6314ce8a81cf69bd89909f4b633b9f523834dc1a352021775d45e51d9401"}, + {file = "tornado-6.5.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:13ce6e3396c24e2808774741331638ee6c2f50b114b97a55c5b442df65fd9692"}, + {file = "tornado-6.5.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5cae6145f4cdf5ab24744526cc0f55a17d76f02c98f4cff9daa08ae9a217448a"}, + {file = "tornado-6.5.1-cp39-abi3-win32.whl", hash = "sha256:e0a36e1bc684dca10b1aa75a31df8bdfed656831489bc1e6a6ebed05dc1ec365"}, + {file = "tornado-6.5.1-cp39-abi3-win_amd64.whl", hash = "sha256:908e7d64567cecd4c2b458075589a775063453aeb1d2a1853eedb806922f568b"}, + {file = "tornado-6.5.1-cp39-abi3-win_arm64.whl", hash = "sha256:02420a0eb7bf617257b9935e2b754d1b63897525d8a289c9d65690d580b4dcf7"}, + {file = "tornado-6.5.1.tar.gz", hash = "sha256:84ceece391e8eb9b2b95578db65e920d2a61070260594819589609ba9bc6308c"}, +] + [[package]] name = "tos" version = "2.8.0" @@ -3366,6 +3599,64 @@ torchhub = ["filelock", "huggingface-hub (>=0.24.0,<1.0)", "importlib-metadata", video = ["av (==9.2.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"] +[[package]] +name = "trio" +version = "0.30.0" +description = "A friendly Python library for async concurrency and I/O" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "trio-0.30.0-py3-none-any.whl", hash = "sha256:3bf4f06b8decf8d3cf00af85f40a89824669e2d033bb32469d34840edcfc22a5"}, + {file = "trio-0.30.0.tar.gz", hash = "sha256:0781c857c0c81f8f51e0089929a26b5bb63d57f927728a5586f7e36171f064df"}, +] + +[package.dependencies] +attrs = ">=23.2.0" +cffi = {version = ">=1.14", markers = "os_name == \"nt\" and implementation_name != \"pypy\""} +exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} +idna = "*" +outcome = "*" +sniffio = ">=1.3.0" +sortedcontainers = "*" + +[[package]] +name = "twisted" +version = "25.5.0" +description = "An asynchronous networking framework written in Python" +optional = false +python-versions = ">=3.8.0" +groups = ["dev"] +files = [ + {file = "twisted-25.5.0-py3-none-any.whl", hash = "sha256:8559f654d01a54a8c3efe66d533d43f383531ebf8d81d9f9ab4769d91ca15df7"}, + {file = "twisted-25.5.0.tar.gz", hash = "sha256:1deb272358cb6be1e3e8fc6f9c8b36f78eb0fa7c2233d2dbe11ec6fee04ea316"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +automat = ">=24.8.0" +constantly = ">=15.1" +hyperlink = ">=17.1.1" +incremental = ">=24.7.0" +typing-extensions = ">=4.2.0" +zope-interface = ">=5" + +[package.extras] +all-non-platform = ["appdirs (>=1.4.0)", "appdirs (>=1.4.0)", 
"bcrypt (>=3.1.3)", "bcrypt (>=3.1.3)", "cryptography (>=3.3)", "cryptography (>=3.3)", "cython-test-exception-raiser (>=1.0.2,<2)", "cython-test-exception-raiser (>=1.0.2,<2)", "h2 (>=3.2,<5.0)", "h2 (>=3.2,<5.0)", "httpx[http2] (>=0.27)", "httpx[http2] (>=0.27)", "hypothesis (>=6.56)", "hypothesis (>=6.56)", "idna (>=2.4)", "idna (>=2.4)", "priority (>=1.1.0,<2.0)", "priority (>=1.1.0,<2.0)", "pyhamcrest (>=2)", "pyhamcrest (>=2)", "pyopenssl (>=21.0.0)", "pyopenssl (>=21.0.0)", "pyserial (>=3.0)", "pyserial (>=3.0)", "pywin32 (!=226) ; platform_system == \"Windows\"", "pywin32 (!=226) ; platform_system == \"Windows\"", "service-identity (>=18.1.0)", "service-identity (>=18.1.0)", "wsproto", "wsproto"] +conch = ["appdirs (>=1.4.0)", "bcrypt (>=3.1.3)", "cryptography (>=3.3)"] +dev = ["coverage (>=7.5,<8.0)", "cython-test-exception-raiser (>=1.0.2,<2)", "httpx[http2] (>=0.27)", "hypothesis (>=6.56)", "pydoctor (>=24.11.1,<24.12.0)", "pyflakes (>=2.2,<3.0)", "pyhamcrest (>=2)", "python-subunit (>=1.4,<2.0)", "sphinx (>=6,<7)", "sphinx-rtd-theme (>=1.3,<2.0)", "towncrier (>=23.6,<24.0)", "twistedchecker (>=0.7,<1.0)"] +dev-release = ["pydoctor (>=24.11.1,<24.12.0)", "pydoctor (>=24.11.1,<24.12.0)", "sphinx (>=6,<7)", "sphinx (>=6,<7)", "sphinx-rtd-theme (>=1.3,<2.0)", "sphinx-rtd-theme (>=1.3,<2.0)", "towncrier (>=23.6,<24.0)", "towncrier (>=23.6,<24.0)"] +gtk-platform = ["appdirs (>=1.4.0)", "appdirs (>=1.4.0)", "bcrypt (>=3.1.3)", "bcrypt (>=3.1.3)", "cryptography (>=3.3)", "cryptography (>=3.3)", "cython-test-exception-raiser (>=1.0.2,<2)", "cython-test-exception-raiser (>=1.0.2,<2)", "h2 (>=3.2,<5.0)", "h2 (>=3.2,<5.0)", "httpx[http2] (>=0.27)", "httpx[http2] (>=0.27)", "hypothesis (>=6.56)", "hypothesis (>=6.56)", "idna (>=2.4)", "idna (>=2.4)", "priority (>=1.1.0,<2.0)", "priority (>=1.1.0,<2.0)", "pygobject", "pygobject", "pyhamcrest (>=2)", "pyhamcrest (>=2)", "pyopenssl (>=21.0.0)", "pyopenssl (>=21.0.0)", "pyserial (>=3.0)", "pyserial (>=3.0)", "pywin32 (!=226) ; platform_system == \"Windows\"", "pywin32 (!=226) ; platform_system == \"Windows\"", "service-identity (>=18.1.0)", "service-identity (>=18.1.0)", "wsproto", "wsproto"] +http2 = ["h2 (>=3.2,<5.0)", "priority (>=1.1.0,<2.0)"] +macos-platform = ["appdirs (>=1.4.0)", "appdirs (>=1.4.0)", "bcrypt (>=3.1.3)", "bcrypt (>=3.1.3)", "cryptography (>=3.3)", "cryptography (>=3.3)", "cython-test-exception-raiser (>=1.0.2,<2)", "cython-test-exception-raiser (>=1.0.2,<2)", "h2 (>=3.2,<5.0)", "h2 (>=3.2,<5.0)", "httpx[http2] (>=0.27)", "httpx[http2] (>=0.27)", "hypothesis (>=6.56)", "hypothesis (>=6.56)", "idna (>=2.4)", "idna (>=2.4)", "priority (>=1.1.0,<2.0)", "priority (>=1.1.0,<2.0)", "pyhamcrest (>=2)", "pyhamcrest (>=2)", "pyobjc-core (<11) ; python_version < \"3.9\"", "pyobjc-core (<11) ; python_version < \"3.9\"", "pyobjc-core ; python_version >= \"3.9\"", "pyobjc-core ; python_version >= \"3.9\"", "pyobjc-framework-cfnetwork (<11) ; python_version < \"3.9\"", "pyobjc-framework-cfnetwork (<11) ; python_version < \"3.9\"", "pyobjc-framework-cfnetwork ; python_version >= \"3.9\"", "pyobjc-framework-cfnetwork ; python_version >= \"3.9\"", "pyobjc-framework-cocoa (<11) ; python_version < \"3.9\"", "pyobjc-framework-cocoa (<11) ; python_version < \"3.9\"", "pyobjc-framework-cocoa ; python_version >= \"3.9\"", "pyobjc-framework-cocoa ; python_version >= \"3.9\"", "pyopenssl (>=21.0.0)", "pyopenssl (>=21.0.0)", "pyserial (>=3.0)", "pyserial (>=3.0)", "pywin32 (!=226) ; platform_system == \"Windows\"", "pywin32 (!=226) ; 
platform_system == \"Windows\"", "service-identity (>=18.1.0)", "service-identity (>=18.1.0)", "wsproto", "wsproto"] +mypy = ["appdirs (>=1.4.0)", "bcrypt (>=3.1.3)", "coverage (>=7.5,<8.0)", "cryptography (>=3.3)", "cython-test-exception-raiser (>=1.0.2,<2)", "h2 (>=3.2,<5.0)", "httpx[http2] (>=0.27)", "hypothesis (>=6.56)", "idna (>=2.4)", "mypy (==1.10.1)", "mypy-zope (==1.0.6)", "priority (>=1.1.0,<2.0)", "pydoctor (>=24.11.1,<24.12.0)", "pyflakes (>=2.2,<3.0)", "pyhamcrest (>=2)", "pyopenssl (>=21.0.0)", "pyserial (>=3.0)", "python-subunit (>=1.4,<2.0)", "pywin32 (!=226) ; platform_system == \"Windows\"", "service-identity (>=18.1.0)", "sphinx (>=6,<7)", "sphinx-rtd-theme (>=1.3,<2.0)", "towncrier (>=23.6,<24.0)", "twistedchecker (>=0.7,<1.0)", "types-pyopenssl", "types-setuptools", "wsproto"] +osx-platform = ["appdirs (>=1.4.0)", "appdirs (>=1.4.0)", "bcrypt (>=3.1.3)", "bcrypt (>=3.1.3)", "cryptography (>=3.3)", "cryptography (>=3.3)", "cython-test-exception-raiser (>=1.0.2,<2)", "cython-test-exception-raiser (>=1.0.2,<2)", "h2 (>=3.2,<5.0)", "h2 (>=3.2,<5.0)", "httpx[http2] (>=0.27)", "httpx[http2] (>=0.27)", "hypothesis (>=6.56)", "hypothesis (>=6.56)", "idna (>=2.4)", "idna (>=2.4)", "priority (>=1.1.0,<2.0)", "priority (>=1.1.0,<2.0)", "pyhamcrest (>=2)", "pyhamcrest (>=2)", "pyobjc-core (<11) ; python_version < \"3.9\"", "pyobjc-core (<11) ; python_version < \"3.9\"", "pyobjc-core ; python_version >= \"3.9\"", "pyobjc-core ; python_version >= \"3.9\"", "pyobjc-framework-cfnetwork (<11) ; python_version < \"3.9\"", "pyobjc-framework-cfnetwork (<11) ; python_version < \"3.9\"", "pyobjc-framework-cfnetwork ; python_version >= \"3.9\"", "pyobjc-framework-cfnetwork ; python_version >= \"3.9\"", "pyobjc-framework-cocoa (<11) ; python_version < \"3.9\"", "pyobjc-framework-cocoa (<11) ; python_version < \"3.9\"", "pyobjc-framework-cocoa ; python_version >= \"3.9\"", "pyobjc-framework-cocoa ; python_version >= \"3.9\"", "pyopenssl (>=21.0.0)", "pyopenssl (>=21.0.0)", "pyserial (>=3.0)", "pyserial (>=3.0)", "pywin32 (!=226) ; platform_system == \"Windows\"", "pywin32 (!=226) ; platform_system == \"Windows\"", "service-identity (>=18.1.0)", "service-identity (>=18.1.0)", "wsproto", "wsproto"] +serial = ["pyserial (>=3.0)", "pywin32 (!=226) ; platform_system == \"Windows\""] +test = ["cython-test-exception-raiser (>=1.0.2,<2)", "httpx[http2] (>=0.27)", "hypothesis (>=6.56)", "pyhamcrest (>=2)"] +tls = ["idna (>=2.4)", "pyopenssl (>=21.0.0)", "service-identity (>=18.1.0)"] +websocket = ["wsproto"] +windows-platform = ["appdirs (>=1.4.0)", "appdirs (>=1.4.0)", "bcrypt (>=3.1.3)", "bcrypt (>=3.1.3)", "cryptography (>=3.3)", "cryptography (>=3.3)", "cython-test-exception-raiser (>=1.0.2,<2)", "cython-test-exception-raiser (>=1.0.2,<2)", "h2 (>=3.2,<5.0)", "h2 (>=3.2,<5.0)", "httpx[http2] (>=0.27)", "httpx[http2] (>=0.27)", "hypothesis (>=6.56)", "hypothesis (>=6.56)", "idna (>=2.4)", "idna (>=2.4)", "priority (>=1.1.0,<2.0)", "priority (>=1.1.0,<2.0)", "pyhamcrest (>=2)", "pyhamcrest (>=2)", "pyopenssl (>=21.0.0)", "pyopenssl (>=21.0.0)", "pyserial (>=3.0)", "pyserial (>=3.0)", "pywin32 (!=226)", "pywin32 (!=226)", "pywin32 (!=226) ; platform_system == \"Windows\"", "pywin32 (!=226) ; platform_system == \"Windows\"", "service-identity (>=18.1.0)", "service-identity (>=18.1.0)", "twisted-iocpsupport (>=1.0.2)", "twisted-iocpsupport (>=1.0.2)", "wsproto", "wsproto"] + [[package]] name = "types-cffi" version = "1.16.0.20241221" @@ -3896,7 +4187,62 @@ enabler = ["pytest-enabler (>=2.2)"] test = 
["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] type = ["pytest-mypy"] +[[package]] +name = "zope-interface" +version = "7.2" +description = "Interfaces for Python" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "zope.interface-7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ce290e62229964715f1011c3dbeab7a4a1e4971fd6f31324c4519464473ef9f2"}, + {file = "zope.interface-7.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:05b910a5afe03256b58ab2ba6288960a2892dfeef01336dc4be6f1b9ed02ab0a"}, + {file = "zope.interface-7.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:550f1c6588ecc368c9ce13c44a49b8d6b6f3ca7588873c679bd8fd88a1b557b6"}, + {file = "zope.interface-7.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0ef9e2f865721553c6f22a9ff97da0f0216c074bd02b25cf0d3af60ea4d6931d"}, + {file = "zope.interface-7.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27f926f0dcb058211a3bb3e0e501c69759613b17a553788b2caeb991bed3b61d"}, + {file = "zope.interface-7.2-cp310-cp310-win_amd64.whl", hash = "sha256:144964649eba4c5e4410bb0ee290d338e78f179cdbfd15813de1a664e7649b3b"}, + {file = "zope.interface-7.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1909f52a00c8c3dcab6c4fad5d13de2285a4b3c7be063b239b8dc15ddfb73bd2"}, + {file = "zope.interface-7.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:80ecf2451596f19fd607bb09953f426588fc1e79e93f5968ecf3367550396b22"}, + {file = "zope.interface-7.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:033b3923b63474800b04cba480b70f6e6243a62208071fc148354f3f89cc01b7"}, + {file = "zope.interface-7.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a102424e28c6b47c67923a1f337ede4a4c2bba3965b01cf707978a801fc7442c"}, + {file = "zope.interface-7.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25e6a61dcb184453bb00eafa733169ab6d903e46f5c2ace4ad275386f9ab327a"}, + {file = "zope.interface-7.2-cp311-cp311-win_amd64.whl", hash = "sha256:3f6771d1647b1fc543d37640b45c06b34832a943c80d1db214a37c31161a93f1"}, + {file = "zope.interface-7.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:086ee2f51eaef1e4a52bd7d3111a0404081dadae87f84c0ad4ce2649d4f708b7"}, + {file = "zope.interface-7.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:21328fcc9d5b80768bf051faa35ab98fb979080c18e6f84ab3f27ce703bce465"}, + {file = "zope.interface-7.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6dd02ec01f4468da0f234da9d9c8545c5412fef80bc590cc51d8dd084138a89"}, + {file = "zope.interface-7.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8e7da17f53e25d1a3bde5da4601e026adc9e8071f9f6f936d0fe3fe84ace6d54"}, + {file = "zope.interface-7.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cab15ff4832580aa440dc9790b8a6128abd0b88b7ee4dd56abacbc52f212209d"}, + {file = "zope.interface-7.2-cp312-cp312-win_amd64.whl", hash = "sha256:29caad142a2355ce7cfea48725aa8bcf0067e2b5cc63fcf5cd9f97ad12d6afb5"}, + {file = "zope.interface-7.2-cp313-cp313-macosx_10_9_x86_64.whl", hash = 
"sha256:3e0350b51e88658d5ad126c6a57502b19d5f559f6cb0a628e3dc90442b53dd98"}, + {file = "zope.interface-7.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:15398c000c094b8855d7d74f4fdc9e73aa02d4d0d5c775acdef98cdb1119768d"}, + {file = "zope.interface-7.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:802176a9f99bd8cc276dcd3b8512808716492f6f557c11196d42e26c01a69a4c"}, + {file = "zope.interface-7.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb23f58a446a7f09db85eda09521a498e109f137b85fb278edb2e34841055398"}, + {file = "zope.interface-7.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a71a5b541078d0ebe373a81a3b7e71432c61d12e660f1d67896ca62d9628045b"}, + {file = "zope.interface-7.2-cp313-cp313-win_amd64.whl", hash = "sha256:4893395d5dd2ba655c38ceb13014fd65667740f09fa5bb01caa1e6284e48c0cd"}, + {file = "zope.interface-7.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d3a8ffec2a50d8ec470143ea3d15c0c52d73df882eef92de7537e8ce13475e8a"}, + {file = "zope.interface-7.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:31d06db13a30303c08d61d5fb32154be51dfcbdb8438d2374ae27b4e069aac40"}, + {file = "zope.interface-7.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e204937f67b28d2dca73ca936d3039a144a081fc47a07598d44854ea2a106239"}, + {file = "zope.interface-7.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:224b7b0314f919e751f2bca17d15aad00ddbb1eadf1cb0190fa8175edb7ede62"}, + {file = "zope.interface-7.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baf95683cde5bc7d0e12d8e7588a3eb754d7c4fa714548adcd96bdf90169f021"}, + {file = "zope.interface-7.2-cp38-cp38-win_amd64.whl", hash = "sha256:7dc5016e0133c1a1ec212fc87a4f7e7e562054549a99c73c8896fa3a9e80cbc7"}, + {file = "zope.interface-7.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7bd449c306ba006c65799ea7912adbbfed071089461a19091a228998b82b1fdb"}, + {file = "zope.interface-7.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a19a6cc9c6ce4b1e7e3d319a473cf0ee989cbbe2b39201d7c19e214d2dfb80c7"}, + {file = "zope.interface-7.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:72cd1790b48c16db85d51fbbd12d20949d7339ad84fd971427cf00d990c1f137"}, + {file = "zope.interface-7.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:52e446f9955195440e787596dccd1411f543743c359eeb26e9b2c02b077b0519"}, + {file = "zope.interface-7.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ad9913fd858274db8dd867012ebe544ef18d218f6f7d1e3c3e6d98000f14b75"}, + {file = "zope.interface-7.2-cp39-cp39-win_amd64.whl", hash = "sha256:1090c60116b3da3bfdd0c03406e2f14a1ff53e5771aebe33fec1edc0a350175d"}, + {file = "zope.interface-7.2.tar.gz", hash = "sha256:8b49f1a3d1ee4cdaf5b32d2e738362c7f5e40ac8b46dd7d1a65e82a4872728fe"}, +] + +[package.dependencies] +setuptools = "*" + +[package.extras] +docs = ["Sphinx", "furo", "repoze.sphinx.autointerface"] +test = ["coverage[toml]", "zope.event", "zope.testing"] +testing = ["coverage[toml]", "zope.event", "zope.testing"] + [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.13" -content-hash = "fd8aedaf15b79941f66f282c130b91ed0ebc79d3b05ffc0aadb929f457f5f32b" +content-hash = 
"9fdd49a33f617e3bc6811cbf06f8fc9a20defc432cefdd1f14f14d568cd08742" diff --git a/python/aibrix/pyproject.toml b/python/aibrix/pyproject.toml index 8c153399d..317e84de2 100644 --- a/python/aibrix/pyproject.toml +++ b/python/aibrix/pyproject.toml @@ -46,6 +46,8 @@ aibrix_runtime = 'aibrix.app:main' aibrix_download = 'aibrix.downloader.__main__:main' aibrix_benchmark = "aibrix.gpu_optimizer.optimizer.profiling.benchmark:main" aibrix_gen_profile = 'aibrix.gpu_optimizer.optimizer.profiling.gen_profile:main' +aibrix_batch_worker = 'aibrix.batch.worker:main' +aibrix_api_extension = 'aibrix.metadata.app:main' [tool.poetry.dependencies] python = ">=3.10,<3.13" @@ -84,7 +86,12 @@ tenacity = "^9.0.0" mypy = "1.11.1" ruff = "0.6.1" pytest = "^8.3.2" +anyio = "4.8.0" pytest-asyncio = "^1.1.0" +pytest-tornasync = "0.6.0.post2" +pytest-trio = "0.8.0" +pytest-twisted = "1.14.3" +twisted = "25.5.0" [build-system] requires = ["poetry-core", "poetry-dynamic-versioning"] diff --git a/python/aibrix/scripts/generate_secrets.py b/python/aibrix/scripts/generate_secrets.py new file mode 100644 index 000000000..1e113ce63 --- /dev/null +++ b/python/aibrix/scripts/generate_secrets.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +""" +Command-line utility for generating Kubernetes secrets for S3 and TOS storage. + +This script provides a convenient way to create storage secrets in Kubernetes +clusters using the secret_gen module. +""" + +import argparse +import sys +from pathlib import Path + +# Add the parent directory to the path so we can import aibrix modules +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from aibrix.metadata.secret_gen import SecretGenerator + + +def create_s3_secret_cli(args): + """Create an S3 secret from CLI arguments.""" + try: + generator = SecretGenerator(namespace=args.namespace) + + secret_name = generator.create_s3_secret( + bucket_name=args.bucket, secret_name=args.name + ) + + print(f"✅ Successfully created S3 secret: {secret_name}") + print(f" Namespace: {args.namespace}") + if args.bucket: + print(f" Bucket: {args.bucket}") + + except Exception as e: + print(f"❌ Failed to create S3 secret: {e}") + sys.exit(1) + + +def create_tos_secret_cli(args): + """Create a TOS secret from CLI arguments.""" + try: + generator = SecretGenerator(namespace=args.namespace) + + secret_name = generator.create_tos_secret( + bucket_name=args.bucket, secret_name=args.name + ) + + print(f"✅ Successfully created TOS secret: {secret_name}") + print(f" Namespace: {args.namespace}") + if args.bucket: + print(f" Bucket: {args.bucket}") + + except Exception as e: + print(f"❌ Failed to create TOS secret: {e}") + sys.exit(1) + + +def delete_secret_cli(args): + """Delete a secret from CLI arguments.""" + try: + generator = SecretGenerator(namespace=args.namespace) + + if generator.delete_secret(args.secret_name): + print(f"✅ Successfully deleted secret: {args.secret_name}") + print(f" Namespace: {args.namespace}") + else: + print(f"⚠️ Secret not found: {args.secret_name}") + print(f" Namespace: {args.namespace}") + + except Exception as e: + print(f"❌ Failed to delete secret: {e}") + sys.exit(1) + + +def list_secrets_cli(args): + """List secrets in the namespace.""" + try: + generator = SecretGenerator(namespace=args.namespace) + + # Get all secrets in the namespace + secrets = generator.core_v1.list_namespaced_secret(namespace=args.namespace) + + print(f"📋 Secrets in namespace '{args.namespace}':") + print("-" * 50) + + if not secrets.items: + print(" No secrets found") + else: + for secret in secrets.items: 
+ secret_type = secret.type or "Opaque" + data_keys = list(secret.data.keys()) if secret.data else [] + print(f" {secret.metadata.name} ({secret_type})") + if data_keys: + print(f" Keys: {', '.join(data_keys)}") + + except Exception as e: + print(f"❌ Failed to list secrets: {e}") + sys.exit(1) + + +def main(): + """Main CLI function.""" + parser = argparse.ArgumentParser( + description="Generate Kubernetes secrets for S3 and TOS storage", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Create S3 secret with default name + python generate_secrets.py s3 --bucket my-bucket + + # Create S3 secret with custom name + python generate_secrets.py s3 --bucket my-bucket --name my-s3-creds + + # Create TOS secret (requires TOS_* environment variables) + python generate_secrets.py tos --bucket my-tos-bucket + + # Delete a secret + python generate_secrets.py delete my-secret-name + + # List all secrets in namespace + python generate_secrets.py list + + # Use custom namespace (either position works) + python generate_secrets.py --namespace my-namespace s3 --bucket my-bucket + python generate_secrets.py s3 --bucket my-bucket --namespace my-namespace + """, + ) + + parser.add_argument( + "--namespace", + "-n", + default="default", + help="Kubernetes namespace (default: default)", + ) + + subparsers = parser.add_subparsers(dest="command", help="Available commands") + + # S3 secret command + s3_parser = subparsers.add_parser("s3", help="Create S3 secret") + s3_parser.add_argument("--bucket", "-b", help="S3 bucket name (optional)") + s3_parser.add_argument( + "--name", help="Custom secret name (optional, uses template default)" + ) + s3_parser.add_argument( + "--namespace", + "-n", + default="default", + help="Kubernetes namespace (default: default)", + ) + + # TOS secret command + tos_parser = subparsers.add_parser("tos", help="Create TOS secret") + tos_parser.add_argument("--bucket", "-b", help="TOS bucket name (optional)") + tos_parser.add_argument( + "--name", help="Custom secret name (optional, uses template default)" + ) + tos_parser.add_argument( + "--namespace", + "-n", + default="default", + help="Kubernetes namespace (default: default)", + ) + + # Delete secret command + delete_parser = subparsers.add_parser("delete", help="Delete a secret") + delete_parser.add_argument("secret_name", help="Name of the secret to delete") + delete_parser.add_argument( + "--namespace", + "-n", + default="default", + help="Kubernetes namespace (default: default)", + ) + + # List secrets command + list_parser = subparsers.add_parser("list", help="List secrets in namespace") + list_parser.add_argument( + "--namespace", + "-n", + default="default", + help="Kubernetes namespace (default: default)", + ) + + args = parser.parse_args() + + if not args.command: + parser.print_help() + sys.exit(1) + + # Handle namespace argument - use subparser's namespace if provided, otherwise use main parser's + # This allows --namespace to work in either position + if hasattr(args, "namespace") and args.namespace != "default": + # Subparser namespace takes priority + namespace = args.namespace + else: + # Fall back to main parser namespace (this handles the case where --namespace comes before subcommand) + namespace = getattr(args, "namespace", "default") + + # Update args.namespace to the resolved value + args.namespace = namespace + + # Check Kubernetes access + try: + from kubernetes import config + + config.load_kube_config() + except Exception as e: + print(f"❌ Failed to load Kubernetes 
configuration: {e}") + print("Make sure you have kubectl configured and access to a cluster") + sys.exit(1) + + # Execute the appropriate command + if args.command == "s3": + create_s3_secret_cli(args) + elif args.command == "tos": + create_tos_secret_cli(args) + elif args.command == "delete": + delete_secret_cli(args) + elif args.command == "list": + list_secrets_cli(args) + + +if __name__ == "__main__": + main() diff --git a/python/aibrix/tests/batch/conftest.py b/python/aibrix/tests/batch/conftest.py new file mode 100644 index 000000000..bfb74452d --- /dev/null +++ b/python/aibrix/tests/batch/conftest.py @@ -0,0 +1,377 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import threading +from pathlib import Path +from typing import Any, Dict, Optional + +import boto3 +import kopf +import pytest +import yaml +from kubernetes import client, config + +from aibrix.logger import init_logger +from aibrix.metadata.app import build_app +from aibrix.metadata.cache.job import JobCache +from aibrix.metadata.setting import settings +from aibrix.storage import StorageType + +logger = init_logger(__name__) + +# Use a threading.Event to signal when the operator is ready +OPERATOR_READY = threading.Event() + + +def run_operator_in_thread(stop_flag: threading.Event): + """The target function for the operator thread.""" + # The 'ready_flag' is a special kopf argument that gets set + # when the operator has started and is ready to handle events. + kopf.run( + standalone=True, + ready_flag=OPERATOR_READY, + namespace="default", # Monitor default namespace for tests + stop_flag=stop_flag, + ) + + +@pytest.fixture(scope="session") +def k8s_config(): + """Initialize Kubernetes client.""" + try: + config.load_incluster_config() + except config.ConfigException: + config.load_kube_config() + + +@pytest.fixture +def kopf_operator(scope="function"): + """ + A session-scoped fixture to run the kopf operator in a background thread. + This ensures JobCache handlers are properly triggered during tests. + """ + from aibrix.metadata.core import KopfOperatorWrapper + + operator = KopfOperatorWrapper( + namespace="default", + startup_timeout=30, + shutdown_timeout=10, + ) + try: + # Start the kopf operator in a daemon thread + print("--- Starting kopf operator in background thread ---") + operator.start() + print("--- Kopf operator is ready, yielding to tests ---") + yield # Tests run here + + finally: + print("\n--- Kopf operator test session finished ---") + operator.stop() + + +@pytest.fixture(scope="session") +def test_namespace(): + """Use default namespace for testing.""" + return "default" + + +@pytest.fixture(scope="function") +def job_cache(kopf_operator, ensure_job_rbac): + """ + Function-scoped fixture that provides a JobCache instance. + The kopf_operator fixture ensures the operator is running. + Uses the unittest job template with the correct service account. 
+    """
+    from pathlib import Path
+
+    template_patch_path = (
+        Path(__file__).parent / "testdata" / "k8s_job_patch_unittest.yaml"
+    )
+    return JobCache(template_patch_path=template_patch_path)
+
+
+@pytest.fixture(scope="session")
+def s3_config_available():
+    """Check if S3 configuration is available locally."""
+    try:
+        # Check for AWS credentials
+        session = boto3.Session()
+        credentials = session.get_credentials()
+
+        if not credentials:
+            pytest.skip("No AWS credentials found")
+
+        # Check for required environment variables or default credentials
+        access_key = credentials.access_key
+        secret_key = credentials.secret_key
+
+        if not access_key or not secret_key:
+            pytest.skip("AWS credentials incomplete")
+
+        # Test S3 access
+        s3_client = session.client("s3")
+        s3_client.list_buckets()
+
+        return {
+            "access_key": access_key,
+            "secret_key": secret_key,
+            "region": session.region_name or "us-west-2",
+        }
+
+    except Exception as e:
+        pytest.skip(f"S3 configuration not available: {e}")
+
+
+@pytest.fixture(scope="session")
+def redis_config_available():
+    """Check if Redis configuration is available locally."""
+    # Check for Redis connection settings
+    if os.getenv("REDIS_HOST") is None:
+        pytest.skip("Redis configuration not available")
+
+
+@pytest.fixture(scope="session")
+def test_s3_bucket(s3_config_available):
+    """Get the test S3 bucket, skipping if it is not accessible."""
+    bucket_name = os.getenv("AIBRIX_TEST_S3_BUCKET")
+
+    session = boto3.Session()
+    s3_client = session.client("s3")
+
+    try:
+        # Try to access the bucket
+        s3_client.head_bucket(Bucket=bucket_name)
+        logger.info(f"Using existing S3 bucket: {bucket_name}")
+    except s3_client.exceptions.NoSuchBucket:
+        pytest.skip(
+            f"Test bucket {bucket_name} does not exist. Set AIBRIX_TEST_S3_BUCKET env var or create bucket."
+        )
+    except Exception as e:
+        pytest.skip(f"Cannot access S3 bucket {bucket_name}: {e}")
+
+    return bucket_name
+
+
+@pytest.fixture(scope="session")
+def s3_credentials_secret(
+    k8s_config, test_namespace, s3_config_available, test_s3_bucket
+):
+    """Create K8s secret with S3 credentials from YAML template."""
+    import base64
+
+    # Load secret template from YAML
+    secret_template_path = Path(__file__).parent / "testdata" / "s3_secret.yaml"
+    with open(secret_template_path, "r") as f:
+        secret_template = yaml.safe_load(f)
+
+    core_v1 = client.CoreV1Api()
+    secret_name = secret_template["metadata"]["name"]
+
+    # Populate secret data with actual values (K8s expects base64 encoded values)
+    secret_template["data"] = {
+        "access-key-id": base64.b64encode(
+            s3_config_available["access_key"].encode()
+        ).decode(),
+        "secret-access-key": base64.b64encode(
+            s3_config_available["secret_key"].encode()
+        ).decode(),
+        "region": base64.b64encode(s3_config_available["region"].encode()).decode(),
+        "bucket-name": base64.b64encode(test_s3_bucket.encode()).decode(),
+    }
+
+    # Update namespace
+    secret_template["metadata"]["namespace"] = test_namespace
+
+    # Create K8s Secret object
+    secret = client.V1Secret(
+        metadata=client.V1ObjectMeta(name=secret_name, namespace=test_namespace),
+        data=secret_template["data"],
+        type=secret_template["type"],
+    )
+
+    try:
+        # Delete existing secret if it exists
+        try:
+            core_v1.delete_namespaced_secret(name=secret_name, namespace=test_namespace)
+        except client.ApiException as e:
+            if e.status != 404:
+                raise
+
+        # Create the secret
+        core_v1.create_namespaced_secret(namespace=test_namespace, body=secret)
+        logger.info(f"Created K8s secret: {secret_name}")
+
+        yield secret_name
+
+    finally:
+        # Cleanup: delete the secret
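+        # Best-effort teardown: a 404 from the delete call below just means the
+        # secret is already gone; other K8s API errors are only logged as warnings.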
try: + core_v1.delete_namespaced_secret(name=secret_name, namespace=test_namespace) + logger.info(f"Deleted K8s secret: {secret_name}") + except client.ApiException as e: + if e.status != 404: + logger.warning(f"Failed to cleanup secret {secret_name}: {e}") + + +@pytest.fixture(scope="session") +def job_rbac(k8s_config, test_namespace): + """ + Session-scoped fixture to set up RBAC resources for job testing. + This ensures that tests using create_test_app with enable_k8s_job=True + have the necessary service accounts and permissions. + Returns the service account name for use in job creation. + """ + from kubernetes import utils + + # Load RBAC resources from YAML + rbac_yaml_path = Path(__file__).parent / "testdata" / "job_rbac.yaml" + + # Read YAML content and update namespace if needed + with open(rbac_yaml_path, "r") as f: + yaml_content = f.read() + + # Replace default namespace with test namespace if different + if test_namespace != "default": + yaml_content = yaml_content.replace( + "namespace: default", f"namespace: {test_namespace}" + ) + + # Parse YAML to get resource info for cleanup + rbac_docs = list(yaml.safe_load_all(yaml_content)) + created_resources = [] + service_account_name = None + + # Capture service account name for return + for doc in rbac_docs: + if doc and doc.get("kind") == "ServiceAccount": + service_account_name = doc.get("metadata", {}).get("name") + break + + try: + # Apply YAML using Kubernetes utils + logger.info(f"Applying RBAC resources from {rbac_yaml_path}") + + # Create API client + k8s_client = client.ApiClient() + + # Apply the YAML content + utils.create_from_yaml( + k8s_client, yaml_objects=rbac_docs, namespace=test_namespace + ) + + # Track created resources for cleanup + for doc in rbac_docs: + if doc: + kind = doc.get("kind") + name = doc.get("metadata", {}).get("name") + namespace = doc.get("metadata", {}).get("namespace", test_namespace) + if namespace == "default": + namespace = test_namespace + created_resources.append((kind, name, namespace)) + + logger.info( + f"Successfully applied RBAC resources. Service account: {service_account_name}" + ) + yield service_account_name + + except Exception as e: + logger.error(f"Failed to apply RBAC resources: {e}") + raise + + finally: + # Cleanup: delete created resources + logger.info("Cleaning up RBAC resources...") + core_v1 = client.CoreV1Api() + rbac_v1 = client.RbacAuthorizationV1Api() + + for kind, name, namespace in reversed(created_resources): + try: + if kind == "ServiceAccount": + core_v1.delete_namespaced_service_account( + name=name, namespace=namespace + ) + elif kind == "Role": + rbac_v1.delete_namespaced_role(name=name, namespace=namespace) + elif kind == "RoleBinding": + rbac_v1.delete_namespaced_role_binding( + name=name, namespace=namespace + ) + logger.info(f"Deleted {kind}: {name}") + except client.ApiException as e: + if e.status != 404: + logger.warning(f"Failed to cleanup {kind} {name}: {e}") + + +@pytest.fixture(scope="function") +def ensure_job_rbac(job_rbac): + """ + Function-scoped fixture that ensures RBAC resources are available for tests. + Use this fixture in tests that depend on create_test_app with enable_k8s_job=True. + Returns the service account name for use in job creation. 
+ """ + return job_rbac + + +def create_test_app( + enable_k8s_job: bool = False, + k8s_job_patch: Optional[Path] = None, + storage_type: StorageType = StorageType.LOCAL, + metastore_type: StorageType = StorageType.LOCAL, + params: Optional[Dict[str, Any]] = None, +): + """Create a FastAPI app configured for e2e testing.""" + if params is None: + params = {} + + # Save old settings + oldStorage, oldMetaStore = settings.STORAGE_TYPE, settings.METASTORE_TYPE + # Override settings + settings.STORAGE_TYPE, settings.METASTORE_TYPE = storage_type, metastore_type + # Create app + app = build_app( + argparse.Namespace( + host=None, + port=8100, + enable_fastapi_docs=False, + disable_batch_api=False, + enable_k8s_job=enable_k8s_job, + k8s_job_patch=k8s_job_patch, + e2e_test=True, + ), + params, + ) + # RESTORE settings + settings.STORAGE_TYPE, settings.METASTORE_TYPE = oldStorage, oldMetaStore + return app + + +@pytest.fixture(scope="function") +def test_app( + k8s_config, + test_s3_bucket, + s3_credentials_secret, + redis_config_available, + ensure_job_rbac, +): + # Get the path to the unittest job template + patch_path = Path(__file__).parent / "testdata" / "k8s_job_patch_unittest.yaml" + return create_test_app( + enable_k8s_job=True, + k8s_job_patch=patch_path, + storage_type=StorageType.S3, + metastore_type=StorageType.REDIS, + params={"bucket_name": test_s3_bucket}, + ) diff --git a/python/aibrix/tests/batch/sample_job_input.jsonl b/python/aibrix/tests/batch/sample_job_input.jsonl deleted file mode 100644 index 70a9b3901..000000000 --- a/python/aibrix/tests/batch/sample_job_input.jsonl +++ /dev/null @@ -1,3 +0,0 @@ -{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} -{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} -{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} \ No newline at end of file diff --git a/python/aibrix/tests/batch/test_batch_storage_adapter.py b/python/aibrix/tests/batch/test_batch_storage_adapter.py new file mode 100644 index 000000000..c92e48d66 --- /dev/null +++ b/python/aibrix/tests/batch/test_batch_storage_adapter.py @@ -0,0 +1,491 @@ +"""Unit tests for BatchStorageAdapter.finalize_job_output_data metastore-based behavior""" + +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import pytest + +from aibrix.batch.job_entity import ( + BatchJob, + BatchJobEndpoint, + BatchJobSpec, + BatchJobState, + BatchJobStatus, + CompletionWindow, + ObjectMeta, + RequestCountStats, + TypeMeta, +) +from aibrix.batch.storage.adapter import BatchStorageAdapter +from aibrix.storage.base import BaseStorage + + +def create_test_batch_job( + job_id: str = "test-job-id", + input_file_id: str = "test-input-file", + state: BatchJobState = BatchJobState.CREATED, + launched: int = 0, + completed: int = 0, + failed: int = 0, + total: int = 0, + output_file_id: str = "output_123", + error_file_id: str = "error_123", + temp_output_file_id: str 
= "temp_output_123", + temp_error_file_id: str = "temp_error_123", +) -> BatchJob: + """Factory function to create BatchJob instances for testing.""" + return BatchJob( + typeMeta=TypeMeta(apiVersion="batch/v1", kind="Job"), + metadata=ObjectMeta( + name="test-job", + namespace="default", + uid=str(uuid.uuid4()), + creationTimestamp=datetime.now(timezone.utc), + resourceVersion=None, + deletionTimestamp=None, + ), + spec=BatchJobSpec( + input_file_id=input_file_id, + endpoint=BatchJobEndpoint.CHAT_COMPLETIONS.value, + completion_window=CompletionWindow.TWENTY_FOUR_HOURS.expires_at(), + ), + status=BatchJobStatus( + jobID=job_id, + state=state, + createdAt=datetime.now(timezone.utc), + outputFileID=output_file_id, + errorFileID=error_file_id, + tempOutputFileID=temp_output_file_id, + tempErrorFileID=temp_error_file_id, + requestCounts=RequestCountStats( + total=total, + launched=launched, + completed=completed, + failed=failed, + ), + ), + ) + + +@pytest.fixture +def mock_storage(): + """Create a mock storage instance.""" + storage = AsyncMock(spec=BaseStorage) + return storage + + +@pytest.mark.asyncio +async def test_finalize_job_output_data_corrects_counts_from_metastore(mock_storage): + """Test that finalize_job_output_data correctly calculates counts from metastore keys.""" + adapter = BatchStorageAdapter(mock_storage) + + # Create job with incorrect initial counts + batch_job = create_test_batch_job( + job_id="job-123", + launched=10, # Wrong - should be corrected to 3 based on metastore + completed=5, # Wrong - should be corrected to 2 based on metadata + failed=2, # Wrong - should be corrected to 1 based on metadata + total=15, # Wrong - should be corrected to 5 based on max index + ) + + # Mock metastore keys for indices 0, 2, 4 (non-consecutive to test max calculation) + expected_keys = [ + f"batch:{batch_job.job_id}:done/0", + f"batch:{batch_job.job_id}:done/2", + f"batch:{batch_job.job_id}:done/4", + ] + + with patch( + "aibrix.batch.storage.adapter.list_metastore_keys", new_callable=AsyncMock + ) as mock_list_keys: + with patch( + "aibrix.batch.storage.adapter.get_metadata", new_callable=AsyncMock + ) as mock_get_metadata: + with patch( + "aibrix.batch.storage.adapter.delete_metadata", new_callable=AsyncMock + ) as mock_delete_metadata: + # Setup mocks + mock_list_keys.return_value = expected_keys + + # Mock metadata responses: idx 0 and 2 are outputs, idx 4 is error + mock_get_metadata.side_effect = [ + ("output:etag123", True), # idx 0: success + ("output:etag456", True), # idx 2: success + ("error:etag789", True), # idx 4: error + ] + + # Mock storage operations + mock_storage.complete_multipart_upload.return_value = None + + # Call the method + await adapter.finalize_job_output_data(batch_job) + + # Verify list_metastore_keys was called with correct prefix + mock_list_keys.assert_called_once_with( + f"batch:{batch_job.job_id}:done/" + ) + + # Verify get_metadata was called for each found index + assert mock_get_metadata.call_count == 3 + + # Verify job request counts were updated correctly + assert batch_job.status.request_counts.total == 5 # max index (4) + 1 + assert ( + batch_job.status.request_counts.launched == 3 + ) # number of keys found + assert ( + batch_job.status.request_counts.completed == 2 + ) # 2 outputs (idx 0, 2) + assert batch_job.status.request_counts.failed == 1 # 1 error (idx 4) + + # Verify storage complete_multipart_upload was called correctly + assert mock_storage.complete_multipart_upload.call_count == 2 + + # Check output parts + output_call = 
mock_storage.complete_multipart_upload.call_args_list[0] + output_parts = output_call[0][2] # third argument is the parts list + expected_output_parts = [ + {"etag": "etag123", "part_number": 0}, + {"etag": "etag456", "part_number": 2}, + ] + assert output_parts == expected_output_parts + + # Check error parts + error_call = mock_storage.complete_multipart_upload.call_args_list[1] + error_parts = error_call[0][2] # third argument is the parts list + expected_error_parts = [{"etag": "etag789", "part_number": 4}] + assert error_parts == expected_error_parts + + # Verify cleanup - metadata should be deleted for all valid keys + assert mock_delete_metadata.call_count == 3 + + +@pytest.mark.asyncio +async def test_finalize_job_output_data_handles_missing_metadata(mock_storage): + """Test handling when some keys exist in list but metadata is missing.""" + adapter = BatchStorageAdapter(mock_storage) + batch_job = create_test_batch_job(job_id="job-456") + + expected_keys = [ + f"batch:{batch_job.job_id}:done/0", + f"batch:{batch_job.job_id}:done/1", + f"batch:{batch_job.job_id}:done/2", + ] + + with patch( + "aibrix.batch.storage.adapter.list_metastore_keys", new_callable=AsyncMock + ) as mock_list_keys: + with patch( + "aibrix.batch.storage.adapter.get_metadata", new_callable=AsyncMock + ) as mock_get_metadata: + with patch( + "aibrix.batch.storage.adapter.delete_metadata", new_callable=AsyncMock + ) as mock_delete_metadata: + mock_list_keys.return_value = expected_keys + + # Mock metadata: idx 0 exists, idx 1 missing, idx 2 exists + mock_get_metadata.side_effect = [ + ("output:etag123", True), # idx 0: exists + ("", False), # idx 1: missing metadata + ("error:etag789", True), # idx 2: exists + ] + + mock_storage.complete_multipart_upload.return_value = None + + await adapter.finalize_job_output_data(batch_job) + + # Should calculate based on all keys found, but only process existing metadata + assert batch_job.status.request_counts.total == 3 # max index (2) + 1 + assert ( + batch_job.status.request_counts.launched == 3 + ) # all keys found in list + assert ( + batch_job.status.request_counts.completed == 1 + ) # only idx 0 had output + assert ( + batch_job.status.request_counts.failed == 1 + ) # only idx 2 had error + + # Should only delete metadata for existing keys (idx 0 and 2) + assert mock_delete_metadata.call_count == 2 + + +@pytest.mark.asyncio +async def test_finalize_job_output_data_handles_empty_metastore(mock_storage): + """Test handling when no keys are found in metastore.""" + adapter = BatchStorageAdapter(mock_storage) + batch_job = create_test_batch_job( + job_id="job-789", launched=5, total=10 + ) # Initial wrong counts + + with patch( + "aibrix.batch.storage.adapter.list_metastore_keys", new_callable=AsyncMock + ) as mock_list_keys: + with patch( + "aibrix.batch.storage.adapter.get_metadata", new_callable=AsyncMock + ) as mock_get_metadata: + with patch( + "aibrix.batch.storage.adapter.delete_metadata", new_callable=AsyncMock + ) as mock_delete_metadata: + # No keys found in metastore + mock_list_keys.return_value = [] + + mock_storage.complete_multipart_upload.return_value = None + + await adapter.finalize_job_output_data(batch_job) + + # Verify counts are corrected to zero + assert batch_job.status.request_counts.total == 0 + assert batch_job.status.request_counts.launched == 0 + assert batch_job.status.request_counts.completed == 0 + assert batch_job.status.request_counts.failed == 0 + + # get_metadata should not be called + mock_get_metadata.assert_not_called() + + # 
delete_metadata should not be called + mock_delete_metadata.assert_not_called() + + # Storage operations should still be called with empty lists + assert mock_storage.complete_multipart_upload.call_count == 2 + + +@pytest.mark.asyncio +async def test_finalize_job_output_data_handles_invalid_key_formats(mock_storage): + """Test handling of invalid key formats in metastore.""" + adapter = BatchStorageAdapter(mock_storage) + batch_job = create_test_batch_job(job_id="job-invalid") + + keys_with_invalid = [ + f"batch:{batch_job.job_id}:done/0", # valid + f"batch:{batch_job.job_id}:done/abc", # invalid - non-numeric + f"batch:{batch_job.job_id}:done/2", # valid + f"batch:{batch_job.job_id}:done/", # invalid - empty index + "batch:other-job:done/1", # invalid - different job prefix + "other:format:key/3", # invalid - completely different format + ] + + with patch( + "aibrix.batch.storage.adapter.list_metastore_keys", new_callable=AsyncMock + ) as mock_list_keys: + with patch( + "aibrix.batch.storage.adapter.get_metadata", new_callable=AsyncMock + ) as mock_get_metadata: + with patch( + "aibrix.batch.storage.adapter.delete_metadata", new_callable=AsyncMock + ) as mock_delete_metadata: + mock_list_keys.return_value = keys_with_invalid + + # Only valid keys should get metadata calls + mock_get_metadata.side_effect = [ + ("output:etag123", True), # idx 0 + ("output:etag456", True), # idx 2 + ] + + mock_storage.complete_multipart_upload.return_value = None + + await adapter.finalize_job_output_data(batch_job) + + # Should only process valid indices (0, 2) + assert ( + batch_job.status.request_counts.total == 3 + ) # max valid index (2) + 1 + assert batch_job.status.request_counts.launched == 2 # valid keys found + assert ( + batch_job.status.request_counts.completed == 2 + ) # both were outputs + assert batch_job.status.request_counts.failed == 0 + + # Should only call metadata operations for valid keys + assert mock_get_metadata.call_count == 2 + assert mock_delete_metadata.call_count == 2 + + +@pytest.mark.asyncio +async def test_finalize_job_output_data_no_update_when_counts_match(mock_storage): + """Test that job counts are not updated when they already match metastore data.""" + adapter = BatchStorageAdapter(mock_storage) + + # Create job with correct initial counts + batch_job = create_test_batch_job( + job_id="job-correct", + launched=2, # Correct + completed=1, # Correct + failed=1, # Correct + total=2, # Correct (max index 1 + 1) + ) + + expected_keys = [ + f"batch:{batch_job.job_id}:done/0", + f"batch:{batch_job.job_id}:done/1", + ] + + with patch( + "aibrix.batch.storage.adapter.list_metastore_keys", new_callable=AsyncMock + ) as mock_list_keys: + with patch( + "aibrix.batch.storage.adapter.get_metadata", new_callable=AsyncMock + ) as mock_get_metadata: + with patch( + "aibrix.batch.storage.adapter.delete_metadata", new_callable=AsyncMock + ): + mock_list_keys.return_value = expected_keys + mock_get_metadata.side_effect = [ + ("output:etag123", True), # idx 0: success + ("error:etag456", True), # idx 1: error + ] + mock_storage.complete_multipart_upload.return_value = None + + # Store initial counts to verify they don't change + initial_counts = { + "total": batch_job.status.request_counts.total, + "launched": batch_job.status.request_counts.launched, + "completed": batch_job.status.request_counts.completed, + "failed": batch_job.status.request_counts.failed, + } + + await adapter.finalize_job_output_data(batch_job) + + # Verify counts remain the same (since they were already correct) + 
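+                # finalize_job_output_data re-derives all four counters from the
+                # metastore keys, so they should equal the values seeded above.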
assert batch_job.status.request_counts.total == initial_counts["total"] + assert ( + batch_job.status.request_counts.launched + == initial_counts["launched"] + ) + assert ( + batch_job.status.request_counts.completed + == initial_counts["completed"] + ) + assert ( + batch_job.status.request_counts.failed == initial_counts["failed"] + ) + + +@pytest.mark.asyncio +async def test_finalize_job_output_data_sequential_indices_calculation(mock_storage): + """Test proper calculation when indices are sequential starting from 0.""" + adapter = BatchStorageAdapter(mock_storage) + batch_job = create_test_batch_job(job_id="job-sequential") + + # Sequential indices 0,1,2,3,4 + expected_keys = [ + f"batch:{batch_job.job_id}:done/0", + f"batch:{batch_job.job_id}:done/1", + f"batch:{batch_job.job_id}:done/2", + f"batch:{batch_job.job_id}:done/3", + f"batch:{batch_job.job_id}:done/4", + ] + + with patch( + "aibrix.batch.storage.adapter.list_metastore_keys", new_callable=AsyncMock + ) as mock_list_keys: + with patch( + "aibrix.batch.storage.adapter.get_metadata", new_callable=AsyncMock + ) as mock_get_metadata: + with patch( + "aibrix.batch.storage.adapter.delete_metadata", new_callable=AsyncMock + ): + mock_list_keys.return_value = expected_keys + mock_get_metadata.side_effect = [ + ("output:etag0", True), + ("output:etag1", True), + ("error:etag2", True), + ("output:etag3", True), + ("error:etag4", True), + ] + mock_storage.complete_multipart_upload.return_value = None + + await adapter.finalize_job_output_data(batch_job) + + # For sequential indices 0-4, total should be 5 (max index 4 + 1) + assert batch_job.status.request_counts.total == 5 + assert batch_job.status.request_counts.launched == 5 # All 5 keys found + assert ( + batch_job.status.request_counts.completed == 3 + ) # indices 0,1,3 are outputs + assert ( + batch_job.status.request_counts.failed == 2 + ) # indices 2,4 are errors + + +@pytest.mark.asyncio +async def test_finalize_job_output_data_single_request(mock_storage): + """Test edge case with single request (index 0 only).""" + adapter = BatchStorageAdapter(mock_storage) + batch_job = create_test_batch_job( + job_id="job-single", launched=10, total=20 + ) # Wrong initial counts + + expected_keys = [f"batch:{batch_job.job_id}:done/0"] + + with patch( + "aibrix.batch.storage.adapter.list_metastore_keys", new_callable=AsyncMock + ) as mock_list_keys: + with patch( + "aibrix.batch.storage.adapter.get_metadata", new_callable=AsyncMock + ) as mock_get_metadata: + with patch( + "aibrix.batch.storage.adapter.delete_metadata", new_callable=AsyncMock + ): + mock_list_keys.return_value = expected_keys + mock_get_metadata.side_effect = [("output:etag0", True)] + mock_storage.complete_multipart_upload.return_value = None + + await adapter.finalize_job_output_data(batch_job) + + # For single request at index 0, total should be 1 + assert batch_job.status.request_counts.total == 1 + assert batch_job.status.request_counts.launched == 1 + assert batch_job.status.request_counts.completed == 1 + assert batch_job.status.request_counts.failed == 0 + + +@pytest.mark.asyncio +async def test_finalize_job_output_data_preserves_part_numbers(mock_storage): + """Test that part numbers in multipart upload correspond to original indices.""" + adapter = BatchStorageAdapter(mock_storage) + batch_job = create_test_batch_job(job_id="job-parts") + + # Non-sequential indices to test part number preservation + expected_keys = [ + f"batch:{batch_job.job_id}:done/5", # Large index + f"batch:{batch_job.job_id}:done/10", # Even 
larger index + f"batch:{batch_job.job_id}:done/15", # Largest index + ] + + with patch( + "aibrix.batch.storage.adapter.list_metastore_keys", new_callable=AsyncMock + ) as mock_list_keys: + with patch( + "aibrix.batch.storage.adapter.get_metadata", new_callable=AsyncMock + ) as mock_get_metadata: + with patch( + "aibrix.batch.storage.adapter.delete_metadata", new_callable=AsyncMock + ): + mock_list_keys.return_value = expected_keys + mock_get_metadata.side_effect = [ + ("output:etag5", True), # idx 5 + ("output:etag10", True), # idx 10 + ("error:etag15", True), # idx 15 + ] + mock_storage.complete_multipart_upload.return_value = None + + await adapter.finalize_job_output_data(batch_job) + + # Verify part numbers match original indices + output_call = mock_storage.complete_multipart_upload.call_args_list[0] + output_parts = output_call[0][2] + expected_output_parts = [ + {"etag": "etag5", "part_number": 5}, + {"etag": "etag10", "part_number": 10}, + ] + assert output_parts == expected_output_parts + + error_call = mock_storage.complete_multipart_upload.call_args_list[1] + error_parts = error_call[0][2] + expected_error_parts = [{"etag": "etag15", "part_number": 15}] + assert error_parts == expected_error_parts + + # Total should be max index + 1 + assert batch_job.status.request_counts.total == 16 # 15 + 1 diff --git a/python/aibrix/tests/batch/test_driver.py b/python/aibrix/tests/batch/test_driver.py index b81a22211..c29481dbe 100644 --- a/python/aibrix/tests/batch/test_driver.py +++ b/python/aibrix/tests/batch/test_driver.py @@ -20,14 +20,16 @@ import pytest -from aibrix.batch.constant import EXPIRE_INTERVAL +import aibrix.batch.constant as constant from aibrix.batch.driver import BatchDriver -from aibrix.batch.job_entity import BatchJobState +from aibrix.batch.job_entity import BatchJobErrorCode, BatchJobState, BatchJobStatus from aibrix.storage import StorageType +constant.EXPIRE_INTERVAL = 0.1 + def generate_input_data(num_requests, local_file): - input_name = Path(os.path.dirname(__file__)) / "sample_job_input.jsonl" + input_name = Path(os.path.dirname(__file__)) / "testdata" / "sample_job_input.jsonl" data = None with open(input_name, "r") as file: for line in file.readlines(): @@ -48,6 +50,7 @@ async def test_batch_driver_job_creation(): driver = BatchDriver( storage_type=StorageType.LOCAL, metastore_type=StorageType.LOCAL ) + await driver.start() # Test that driver is properly initialized assert driver is not None @@ -78,7 +81,7 @@ async def test_batch_driver_job_creation(): print(f"Created job_id: {job_id}") # Test status retrieval - job = driver.job_manager.get_job(job_id) + job = await driver.job_manager.get_job(job_id) assert job is not None print(f"Job status: {job.status.state}") assert job.status.state == BatchJobState.CREATED @@ -87,7 +90,7 @@ async def test_batch_driver_job_creation(): await driver.clear_job(job_id) finally: # Shutdown driver - await driver.close() + await driver.stop() # Clean up temporary file Path(temp_path).unlink(missing_ok=True) @@ -103,6 +106,7 @@ async def test_batch_driver_integration(): driver = BatchDriver( storage_type=StorageType.LOCAL, metastore_type=StorageType.LOCAL ) + await driver.start() # Create temporary input file with tempfile.NamedTemporaryFile( @@ -128,14 +132,103 @@ async def test_batch_driver_integration(): assert job_id is not None print(f"Created job_id: {job_id}") - job = driver.job_manager.get_job(job_id) + job = await driver.job_manager.get_job(job_id) assert job is not None print(f"Initial status: {job.status.state}") 
assert job.status.state == BatchJobState.CREATED # 3. Wait for job to be scheduled and start processing - await asyncio.sleep(5 * EXPIRE_INTERVAL) - job = driver.job_manager.get_job(job_id) + await asyncio.sleep(3 * constant.EXPIRE_INTERVAL) + job = await driver.job_manager.get_job(job_id) + assert job is not None + print(f"Status after scheduling: {job.status.state}") + assert job.status.state == BatchJobState.IN_PROGRESS + assert job.status.output_file_id is not None + assert job.status.error_file_id is not None + + # 4. Wait for job to complete + while True: + await asyncio.sleep(1 * constant.EXPIRE_INTERVAL) + job = await driver.job_manager.get_job(job_id) + assert job is not None + print(f"Progressing: {job.status.state}") + if job.status.finished: + break + + print(f"Final status: {job.status.state}") + assert job.status.state == BatchJobState.FINALIZED + assert job.status.completed + assert job.status.output_file_id is not None + assert job.status.error_file_id is not None + + # 5. Retrieve results and verify they exist + results = await driver.retrieve_job_result(job.status.output_file_id) + assert results is not None + assert len(results) == 10 # Should match num_requests + print(f"Retrieved {len(results)} results") + + # 6. Verify results content + for i, req_result in enumerate(results): + print(f"Result {i}: {req_result}") + assert req_result is not None + + # 7. Clean up the job + await driver.clear_job(job_id) + print(f"Job {job_id} cleaned up") + + finally: + # Shutdown driver + await driver.stop() + + # Clean up temporary file + Path(temp_path).unlink(missing_ok=True) + + +@pytest.mark.asyncio +async def test_batch_driver_resuming(): + """ + Integration test for the batch driver workflow. + Tests job creation, scheduling, execution, and result retrieval. + """ + # Initialize driver without job_entity_manager (use local job management) + driver = BatchDriver( + storage_type=StorageType.LOCAL, metastore_type=StorageType.LOCAL + ) + await driver.start() + + # Create temporary input file + with tempfile.NamedTemporaryFile( + mode="w", suffix=".jsonl", delete=False + ) as temp_file: + temp_path = temp_file.name + generate_input_data(10, temp_path) + + try: + # 1. Upload batch data and verify it's stored locally + upload_id = await driver.upload_job_data(temp_path) + assert upload_id is not None + print(f"Upload ID: {upload_id}") + + # 2. Create job and verify initial state + job_id = await driver.job_manager.create_job( + session_id="test-session-integration", + input_file_id=str(upload_id), + api_endpoint="/v1/chat/completions", + completion_window="24h", + meta_data={}, + initial_state=BatchJobState.IN_PROGRESS, + ) + assert job_id is not None + print(f"Created job_id: {job_id}") + + job = await driver.job_manager.get_job(job_id) + assert job is not None + print(f"Initial status: {job.status.state}") + assert job.status.state == BatchJobState.IN_PROGRESS + + # 3. Wait for job to be scheduled and start processing + await asyncio.sleep(3 * constant.EXPIRE_INTERVAL) + job = await driver.job_manager.get_job(job_id) assert job is not None print(f"Status after scheduling: {job.status.state}") assert job.status.state == BatchJobState.IN_PROGRESS @@ -143,16 +236,17 @@ async def test_batch_driver_integration(): assert job.status.error_file_id is not None # 4. 
Wait for job to complete - for i in range(10): - await asyncio.sleep(1 * EXPIRE_INTERVAL) - job = driver.job_manager.get_job(job_id) + while True: + await asyncio.sleep(1 * constant.EXPIRE_INTERVAL) + job = await driver.job_manager.get_job(job_id) assert job is not None print(f"Progressing: {job.status.state}") - if job.status.state.is_finished(): + if job.status.finished: break print(f"Final status: {job.status.state}") - assert job.status.state == BatchJobState.COMPLETED + assert job.status.state == BatchJobState.FINALIZED + assert job.status.completed assert job.status.output_file_id is not None assert job.status.error_file_id is not None @@ -173,7 +267,145 @@ async def test_batch_driver_integration(): finally: # Shutdown driver - await driver.close() + await driver.stop() + + # Clean up temporary file + Path(temp_path).unlink(missing_ok=True) + + +@pytest.mark.asyncio +async def test_batch_driver_validation_failed() -> None: + """ + Integration test for the batch driver workflow. + Tests job creation, scheduling, and validation failed. + """ + # Initialize driver without job_entity_manager (use local job management) + driver = BatchDriver( + storage_type=StorageType.LOCAL, metastore_type=StorageType.LOCAL + ) + await driver.start() + + try: + # 1. Create job with non-exist upload_id + job_id = await driver.job_manager.create_job( + session_id="test-session-integration", + input_file_id="non-exist-upload-id", + api_endpoint="/v1/chat/completions", + completion_window="24h", + meta_data={}, + ) + assert job_id is not None + print(f"Created job_id: {job_id}") + + job = await driver.job_manager.get_job(job_id) + assert job is not None + print(f"Initial status: {job.status.state}") + assert job.status.state == BatchJobState.CREATED + + # 2. Wait for job to be scheduled and validation failed + await asyncio.sleep(5 * constant.EXPIRE_INTERVAL) + job = await driver.job_manager.get_job(job_id) + assert job is not None + + job_status: BatchJobStatus = job.status + print(f"Status after scheduling: {job_status.state}") + assert job_status.state == BatchJobState.FINALIZED + assert job_status.failed + assert job_status.output_file_id is None + assert job_status.error_file_id is None + assert job_status.errors is not None + assert len(job_status.errors) > 0 + assert job_status.errors[0].code == BatchJobErrorCode.INVALID_INPUT_FILE + + # 7. Clean up the job + await driver.clear_job(job_id) + print(f"Job {job_id} cleaned up") + + finally: + # Shutdown driver + await driver.stop() + + +@pytest.mark.asyncio +async def test_batch_driver_stop_raises_exception_with_fail_after_n_requests(): + """Test that BatchDriver.stop() raises RuntimeError when jobs with fail_after_n_requests exist.""" + + driver = BatchDriver( + storage_type=StorageType.LOCAL, + metastore_type=StorageType.LOCAL, + stand_alone=False, + ) + + # Create a temporary file for job input + temp_file_descriptor, temp_path = tempfile.mkstemp(suffix=".jsonl") + try: + # Generate test data + generate_input_data(3, temp_path) + + # Start the driver + await driver.start() + + # 1. Upload input data + input_file_id = await driver.upload_job_data(temp_path) + print(f"Input file uploaded: {input_file_id}") + + # 2. 
Create a job with fail_after_n_requests opts + from aibrix.batch.job_entity import BatchJobSpec + + job_spec = BatchJobSpec.from_strings( + input_file_id=input_file_id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={"test": "metadata"}, + opts={ + constant.BATCH_OPTS_FAIL_AFTER_N_REQUESTS: "2" + }, # This should trigger the stop() exception + ) + + job_id = await driver.job_manager.create_job_with_spec( + session_id="test-session", job_spec=job_spec + ) + print(f"Job created with fail_after_n_requests: {job_id}") + # 3. Verify the job was created successfully + job = await driver.job_manager.get_job(job_id) + assert job is not None + assert job.spec.opts is not None + assert constant.BATCH_OPTS_FAIL_AFTER_N_REQUESTS in job.spec.opts + + # 4. Wait for job to complete + waited = 0 + while True: + await asyncio.sleep(1 * constant.EXPIRE_INTERVAL) + waited += 1 + job = await driver.job_manager.get_job(job_id) + assert job is not None + print(f"Progressing: {job.status.state}") + if job.status.finished: + break + if waited > 10: + assert False, "job timeout" + + print(f"Final status: {job.status.state}") + assert job.status.state == BatchJobState.FINALIZED + assert job.status.failed + assert job.status.output_file_id is not None + assert job.status.error_file_id is not None + + # wait for exception reach driver. + await asyncio.sleep(3.0) + + # 5. Attempt to stop the driver - this should raise RuntimeError + with pytest.raises(RuntimeError, match="Artificial failure.*"): + await driver.stop() + + print( + "✅ BatchDriver.stop() correctly raised RuntimeError for job with fail_after_n_requests" + ) + + # 6. Clean up the job to allow proper shutdown + await driver.clear_job(job_id) + + finally: # Clean up temporary file Path(temp_path).unlink(missing_ok=True) diff --git a/python/aibrix/tests/batch/test_e2e_abnormal_job_behavior.py b/python/aibrix/tests/batch/test_e2e_abnormal_job_behavior.py new file mode 100644 index 000000000..2a7448631 --- /dev/null +++ b/python/aibrix/tests/batch/test_e2e_abnormal_job_behavior.py @@ -0,0 +1,932 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +End-to-end tests for abnormal job behavior scenarios. 
+ +These tests cover various failure modes and edge cases in the batch job lifecycle: +- Validation failures +- Processing failures during different stages +- Job cancellation at various points +- Job expiration scenarios +""" + +import asyncio +import json +from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +import aibrix.batch.constant as constant +from aibrix.batch.job_entity import ( + BatchJobError, + BatchJobErrorCode, + BatchJobSpec, + BatchJobState, +) +from aibrix.batch.job_manager import JobManager, JobMetaInfo + +T = TypeVar("T") + + +def generate_batch_input_data(num_requests: int = 3) -> str: + """Generate test batch input data and return the content as string.""" + base_request = { + "custom_id": "request-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "gpt-3.5-turbo-0125", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello world!"}, + ], + "max_tokens": 1000, + }, + } + + lines = [] + for i in range(num_requests): + request = base_request.copy() + request["custom_id"] = f"request-{i+1}" + lines.append(json.dumps(request)) + + return "\n".join(lines) + + +async def wait_for_status( + client: TestClient, + batch_id: str, + expected_status: str, + extra_expected_fields: Optional[Union[str, List[str]]] = None, + max_polls: int = 20, + poll_interval: float = 0.5, +) -> Dict[str, Any]: + """Wait for batch job to reach expected status.""" + for attempt in range(max_polls): + response = client.get(f"/v1/batches/{batch_id}") + assert response.status_code == 200, f"Status check failed: {response.text}" + + result = response.json() + current_status = result["status"] + + if current_status == expected_status: + if extra_expected_fields is None: + return result + elif ( + isinstance(extra_expected_fields, str) + and result[extra_expected_fields] is not None + ): + return result + else: + satisfied = True + # Extra fields must all present. 
+                for field in extra_expected_fields:
+                    if result[field] is None:
+                        satisfied = False
+                if satisfied:
+                    return result
+        elif current_status in ["failed", "cancelled", "expired", "completed"]:
+            # Terminal states
+            if current_status != expected_status:
+                return result  # Return actual final state
+
+        await asyncio.sleep(poll_interval)
+
+    # Return last known status if timeout
+    return result
+
+
+def validate_batch_response(
+    response: Dict[str, Any],
+    *,
+    # Required fields (default checked for not None)
+    expected_id: Optional[str] = None,
+    expected_object: str = "batch",
+    expected_endpoint: Optional[str] = None,
+    expected_input_file_id: Optional[str] = None,
+    expected_completion_window: str = "24h",
+    expected_status: Optional[str] = None,
+    # Optional fields (default checked for None)
+    expected_errors: Optional[Union[str, List[str]]] = None,  # List of error codes
+    expected_output_file_id: Optional[bool] = None,
+    expected_error_file_id: Optional[bool] = None,
+    expected_in_progress_at: Optional[bool] = None,
+    expected_finalizing_at: Optional[bool] = None,
+    expected_completed_at: Optional[bool] = None,
+    expected_failed_at: Optional[bool] = None,
+    expected_expired_at: Optional[bool] = None,
+    expected_cancelling_at: Optional[bool] = None,
+    expected_cancelled_at: Optional[bool] = None,
+    expected_request_counts: Optional[Union[bool, Dict[str, int]]] = None,
+    expected_metadata: Optional[Dict[str, str]] = None,
+) -> None:
+    """Validate batch response fields according to BatchResponse schema.
+
+    This function validates all fields as defined in api/v1/batch.py::BatchResponse.
+
+    Args:
+        response: The batch response dict to validate
+        expected_*: Expected values or presence checks for each field
+            - For required fields: actual expected value or None to skip check
+            - For optional fields: True = not None, False or None = must be None, other value = exact match
+    """
+    # Verify all BatchResponse fields are present
+    required_fields = [
+        "id",
+        "object",
+        "endpoint",
+        "input_file_id",
+        "completion_window",
+        "status",
+        "created_at",
+        "expires_at",
+    ]
+    optional_fields: List[Tuple[str, Any, type]] = [
+        # Pass bool to skip equal test.
+ ("errors", True if expected_errors else False, Dict), + ("output_file_id", expected_output_file_id, str), + ("error_file_id", expected_error_file_id, str), + ("in_progress_at", expected_in_progress_at, int), + ("finalizing_at", expected_finalizing_at, int), + ("completed_at", expected_completed_at, int), + ("failed_at", expected_failed_at, int), + ("expired_at", expected_expired_at, int), + ("cancelling_at", expected_cancelling_at, int), + ("cancelled_at", expected_cancelled_at, int), + ("request_counts", expected_request_counts, Dict), + ("metadata", expected_metadata, Dict), + ] + + # Check that all expected fields exist + all_fields = required_fields + [field[0] for field in optional_fields] + for field in all_fields: + assert field in response, f"Field '{field}' missing from response" + + # Validate required fields + if expected_id is not None: + assert ( + response["id"] == expected_id + ), f"Expected id '{expected_id}', got '{response['id']}'" + else: + assert response["id"] is not None, "Required field 'id' should not be None" + + assert ( + response["object"] == expected_object + ), f"Expected object '{expected_object}', got '{response['object']}'" + + if expected_endpoint is not None: + assert ( + response["endpoint"] == expected_endpoint + ), f"Expected endpoint '{expected_endpoint}', got '{response['endpoint']}'" + else: + assert ( + response["endpoint"] is not None + ), "Required field 'endpoint' should not be None" + + if expected_input_file_id is not None: + assert ( + response["input_file_id"] == expected_input_file_id + ), f"Expected input_file_id '{expected_input_file_id}', got '{response['input_file_id']}'" + else: + assert ( + response["input_file_id"] is not None + ), "Required field 'input_file_id' should not be None" + + assert ( + response["completion_window"] == expected_completion_window + ), f"Expected completion_window '{expected_completion_window}', got '{response['completion_window']}'" + + if expected_status is not None: + assert ( + response["status"] == expected_status + ), f"Expected status '{expected_status}', got '{response['status']}'" + else: + assert ( + response["status"] is not None + ), "Required field 'status' should not be None" + + # created_at is required and should not be None + assert ( + response["created_at"] is not None + ), "Required field 'created_at' should not be None" + assert isinstance( + response["created_at"], int + ), "Expected 'created_at' to be unix timestamp (int)" + + # created_at is required and should not be None + assert ( + response["expires_at"] is not None + ), "Required field 'expires_at' should not be None" + assert isinstance( + response["expires_at"], int + ), "Expected 'expires_at' to be unix timestamp (int)" + if expected_completion_window == "24h": + assert ( + response["expires_at"] == response["created_at"] + 86400 + ), "Expected 'expires_at' to be 'created_at' + 86400" + + # Validate optional fields + def check_optional_field( + field_name: str, expected_value: Optional[T], expected_type: type + ): + if expected_value is None or expected_value is False: + assert response[field_name] is None, f"Expected '{field_name}' to be None" + elif expected_value is True and expected_type is not bool: + assert ( + response[field_name] is not None + ), f"Required field '{field_name}' should not be None" + assert isinstance( + response[field_name], expected_type + ), f"Expected '{field_name}' to be type ({expected_type})" + else: + assert ( + response[field_name] == expected_value + ), f"Expected {field_name} 
'{expected_value}', got '{response[field_name]}'" + + for field_name, expected_value, expected_type in optional_fields: + check_optional_field(field_name, expected_value, expected_type) + + # Check non-timestamp optional fields + if expected_errors is not None: + if isinstance(expected_errors, str): + expected_errors = [expected_errors] + assert isinstance(response["errors"], dict), "Expected 'errors' to be dict" + assert "data" in response["errors"], "Expected 'errors.data' field" + assert isinstance( + response["errors"]["data"], list + ), "Expected 'errors.data' to be list" + assert len(response["errors"]["data"]) > 0 + errors = {} + for error in response["errors"]["data"]: + assert "code" in error, f"No 'code' in error:{error}" + assert "message" in error, f"No 'message' in error:{error}" + errors[error["code"]] = error["message"] + for err_code in expected_errors: + assert err_code in errors + + +class FailingJobManager(JobManager): + """JobManager that can be configured to fail at specific stages.""" + + def __init__( + self, + fail_validation: bool = False, + fail_during_processing: bool = False, + fail_during_finalizing: bool = False, + stall_validation: Optional[float] = None, + stall_cancelling: Optional[float] = None, + fail_after_n_requests: Optional[int] = None, + expiration: Optional[int] = None, + ): + super().__init__() + self.fail_validation = fail_validation + self.fail_during_processing = fail_during_processing + self.fail_during_finalizing = fail_during_finalizing + self.stall_validation = stall_validation + self.stall_cancelling = stall_cancelling + self.fail_after_n_requests = fail_after_n_requests + self.expiration = expiration + self._processed_requests = 0 + + async def validate_job(self, meta_data: JobMetaInfo): + """Override to simulate validation failures during job execution start.""" + if self.stall_validation is not None: + # Prolong validation duration to allow cancellation during validation + await asyncio.sleep(self.stall_validation) + + if self.fail_validation: + # Mark job as failed with authentication error + raise BatchJobError( + code=BatchJobErrorCode.AUTHENTICATION_ERROR, + message="Simulated authentication failure", + param="authentication", + ) + + return await super().validate_job(meta_data) + + async def cancel_job(self, job_id: str) -> bool: + if self.stall_cancelling is not None: + await asyncio.sleep(self.stall_cancelling) + + return await super().cancel_job(job_id) + + async def create_job_with_spec( + self, + session_id: str, + job_spec: BatchJobSpec, + timeout: float = 30.0, + initial_state: BatchJobState = BatchJobState.CREATED, + ) -> str: + """Override create_job to inject fail_after_n_requests in opts.""" + # Create job spec with opts if needed + if self.fail_after_n_requests is not None: + opts = job_spec.opts or {} + opts[constant.BATCH_OPTS_FAIL_AFTER_N_REQUESTS] = str( + self.fail_after_n_requests + ) + job_spec.opts = opts + if self.expiration is not None: + job_spec.completion_window = self.expiration + + return await super().create_job_with_spec( + session_id, job_spec, timeout, initial_state + ) + + +@pytest.mark.asyncio +async def test_job_validation_failure(test_app): + """Test case 1: Create job, failure during validation.""" + print("Test 1: Job validation failure scenario") + + with TestClient(test_app) as client: + # Step 1: Skip uploading file + + # Step 2: Inject the FailingJobManager to fail during validation + original_manager = test_app.state.batch_driver.job_manager + failing_manager = 
FailingJobManager(fail_validation=True) + await test_app.state.batch_driver.run_coroutine( + failing_manager.set_job_entity_manager(original_manager._job_entity_manager) + ) + test_app.state.batch_driver._job_manager = failing_manager + + try: + # Step 3: Create batch job + batch_request = { + "input_file_id": "invalid_input_file_id", + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + } + + batch_response = client.post("/v1/batches", json=batch_request) + assert batch_response.status_code == 200 + batch_id = batch_response.json()["id"] + + # Step 4: Wait for validation to fail + final_status = await wait_for_status(client, batch_id, "failed") + + # Step 5: Verify failed status using comprehensive validation + validate_batch_response( + final_status, + expected_status="failed", + expected_failed_at=True, + expected_errors="authentication_error", + ) + + print("✅ Validation failure test completed successfully") + + await test_app.state.batch_driver.clear_job(batch_id) + finally: + # Restore original manager + test_app.state.batch_driver._job_manager = original_manager + + +@pytest.mark.asyncio +async def test_job_processing_failure(test_app): + """Test case 2: Create job, failure during in progress using k8s job worker with fail_after metadata.""" + print("Test 2: Job processing failure scenario using worker fail_after metadata") + + with TestClient(test_app) as client: + # Step 1: Upload input file + input_data = generate_batch_input_data( + 10 + ) # Generate more tasks for failure after 3 backoffs. + files = {"file": ("test_input.jsonl", input_data, "application/jsonl")} + data = {"purpose": "batch"} + + upload_response = client.post("/v1/files", files=files, data=data) + assert upload_response.status_code == 200 + input_file_id = upload_response.json()["id"] + + # Step 2: Inject FailingJobManager to add fail_after metadata during job creation + original_manager = test_app.state.batch_driver.job_manager + failing_manager = FailingJobManager( + fail_after_n_requests=1 + ) # Fail after processing 1 request + await test_app.state.batch_driver.run_coroutine( + failing_manager.set_job_entity_manager(original_manager._job_entity_manager) + ) + test_app.state.batch_driver._job_manager = failing_manager + + try: + # Step 3: Create batch job with failing manager that injects fail_after metadata + batch_request = { + "input_file_id": input_file_id, + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + } + + batch_response = client.post("/v1/batches", json=batch_request) + assert batch_response.status_code == 200 + batch_id = batch_response.json()["id"] + + # Step 4: Wait for job to start processing + await wait_for_status(client, batch_id, "in_progress", max_polls=10) + # Step 5: Wait for finalization to complete + final_status = await wait_for_status( + client, batch_id, "failed", max_polls=60, poll_interval=1.0 + ) # Wait longer for job retries + + # Step 7: Verify failed status using comprehensive validation + validate_batch_response( + final_status, + expected_status="failed", + expected_endpoint="/v1/chat/completions", + expected_in_progress_at=True, # Should have started processing + expected_failed_at=True, # Should have failure timestamp + expected_finalizing_at=True, # Should have reached finalizing stage + expected_output_file_id=True, + expected_error_file_id=True, + expected_request_counts={ + "total": 3, + "completed": 3, + "failed": 0, + }, + ) + + print( + "✅ Processing failure test with worker fail_after completed successfully" + ) + + await 
test_app.state.batch_driver.clear_job(batch_id) + finally: + # Restore original manager + test_app.state.batch_driver._job_manager = original_manager + + +@pytest.mark.asyncio +async def test_job_finalizing_failure(test_app): + """Test case 3: Create job, failure during finalizing.""" + print("Test 3: Job finalizing failure scenario") + + with TestClient(test_app) as client: + # Step 1: Upload input file + input_data = generate_batch_input_data(2) + files = {"file": ("test_input.jsonl", input_data, "application/jsonl")} + data = {"purpose": "batch"} + + upload_response = client.post("/v1/files", files=files, data=data) + assert upload_response.status_code == 200 + input_file_id = upload_response.json()["id"] + + # Step 2: Inject the exception to the finalize_job_output_data to fail during finalizing + finalizing_patcher = patch( + "aibrix.batch.storage.adapter.BatchStorageAdapter.finalize_job_output_data" + ) + mock_finalize = finalizing_patcher.start() + mock_finalize.side_effect = Exception("Simulated finalization failure") + + try: + # Step 3: Create batch job + batch_request = { + "input_file_id": input_file_id, + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + } + + batch_response = client.post("/v1/batches", json=batch_request) + assert batch_response.status_code == 200 + batch_id = batch_response.json()["id"] + + await wait_for_status(client, batch_id, "in_progress") + + await asyncio.sleep(3) + + # Step 4: Wait for job to reach final status + final_status = await wait_for_status(client, batch_id, "failed") + + # Step 6: Verify failed status due to finalization error using comprehensive validation + validate_batch_response( + final_status, + expected_status="failed", + expected_in_progress_at=True, # Should have started processing + expected_failed_at=True, # Should have failure timestamp + expected_finalizing_at=True, # Should have reached finalizing stage + expected_errors=BatchJobErrorCode.FINALIZING_ERROR, + expected_output_file_id=True, # May or may not have output file + expected_error_file_id=True, # May or may not have error file + expected_request_counts=False, # May or may not have counts + ) + + print("✅ Finalizing failure test completed successfully") + + await test_app.state.batch_driver.clear_job(batch_id) + finally: + finalizing_patcher.stop() + + +@pytest.mark.asyncio +async def test_job_cancellation_in_validation(test_app): + """Test case 4: Create job, cancel during validation.""" + print("Test 4: Job cancellation during validation scenario") + + with TestClient(test_app) as client: + # Step 1: Skip uploading file + + # Step 2: Inject the FailingJobManager to fail during validation + original_manager = test_app.state.batch_driver.job_manager + failing_manager = FailingJobManager(stall_validation=3.0) + await test_app.state.batch_driver.run_coroutine( + failing_manager.set_job_entity_manager(original_manager._job_entity_manager) + ) + test_app.state.batch_driver._job_manager = failing_manager + + try: + # Step 3: Create batch job + batch_request = { + "input_file_id": "invalid_input_file_id", + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + } + + batch_response = client.post("/v1/batches", json=batch_request) + assert batch_response.status_code == 200 + batch_id = batch_response.json()["id"] + + # Step 4: Cancel job during processing + cancel_response = client.post(f"/v1/batches/{batch_id}/cancel") + assert cancel_response.status_code == 200 + + # Step 5: Wait for validation to fail + final_status = await 
wait_for_status(client, batch_id, "cancelled") + + # Step 6: Verify failed status using comprehensive validation + validate_batch_response( + final_status, + expected_status="cancelled", + expected_cancelling_at=True, + expected_cancelled_at=True, + ) + + print("✅ Validation failure test completed successfully") + + await test_app.state.batch_driver.clear_job(batch_id) + finally: + # Restore original manager + test_app.state.batch_driver._job_manager = original_manager + + +@pytest.mark.asyncio +async def test_job_cancellation_in_progress_before_preparation(test_app): + """Test case 5: Create job, cancel during in progress, finalizing, validate finalized result.""" + print("Test 5: Job cancellation during processing scenario") + + with TestClient(test_app) as client: + # Step 1: Upload input file + input_data = generate_batch_input_data( + 10 + ) # More requests for longer processing + files = {"file": ("test_input.jsonl", input_data, "application/jsonl")} + data = {"purpose": "batch"} + + upload_response = client.post("/v1/files", files=files, data=data) + assert upload_response.status_code == 200 + input_file_id = upload_response.json()["id"] + + # Step 2: Create batch job + batch_request = { + "input_file_id": input_file_id, + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + } + + batch_response = client.post("/v1/batches", json=batch_request) + assert batch_response.status_code == 200 + batch_id = batch_response.json()["id"] + + try: + # Step 3: Wait for job to start processing, ASAP + await wait_for_status( + client, batch_id, "in_progress", max_polls=10, poll_interval=0.1 + ) + + # Step 4: Cancel job during processing + cancel_response = client.post(f"/v1/batches/{batch_id}/cancel") + assert cancel_response.status_code == 200 + + # Step 5: Wait for cancellation and finalization + final_status = await wait_for_status( + client, batch_id, "cancelled", max_polls=20 + ) + + # Step 6: Verify cancelled status using comprehensive validation + validate_batch_response( + final_status, + expected_status="cancelled", + expected_endpoint="/v1/chat/completions", + expected_in_progress_at=True, # Should have started processing + expected_cancelled_at=True, # Should have cancellation timestamp + expected_cancelling_at=True, # Should have cancelling start timestamp + expected_finalizing_at=True, # Should have finalizing timestamp + ) + + print("✅ Processing cancellation test completed successfully") + + await test_app.state.batch_driver.clear_job(batch_id) + finally: + pass + + +@pytest.mark.asyncio +async def test_job_cancellation_in_progress(test_app): + """Test case 5: Create job, cancel during in progress, finalizing, validate finalized result.""" + print("Test 5: Job cancellation during processing scenario") + + with TestClient(test_app) as client: + # Step 1: Upload input file + input_data = generate_batch_input_data( + 10 + ) # More requests for longer processing + files = {"file": ("test_input.jsonl", input_data, "application/jsonl")} + data = {"purpose": "batch"} + + upload_response = client.post("/v1/files", files=files, data=data) + assert upload_response.status_code == 200 + input_file_id = upload_response.json()["id"] + + # Step 2: Create batch job + batch_request = { + "input_file_id": input_file_id, + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + } + + batch_response = client.post("/v1/batches", json=batch_request) + assert batch_response.status_code == 200 + batch_id = batch_response.json()["id"] + + try: + # Step 3: Wait for job to start 
processing, ASAP + await wait_for_status( + client, + batch_id, + "in_progress", + "output_file_id", + max_polls=10, + poll_interval=1, + ) + + await asyncio.sleep(3.0) # Wait a second for job to make some progress. + + # Step 4: Cancel job during processing + cancel_response = client.post(f"/v1/batches/{batch_id}/cancel") + assert cancel_response.status_code == 200 + + # Step 5: Wait for cancellation and finalization + final_status = await wait_for_status( + client, batch_id, "cancelled", max_polls=20 + ) + + # Step 6: Verify cancelled status using comprehensive validation + validate_batch_response( + final_status, + expected_status="cancelled", + expected_endpoint="/v1/chat/completions", + expected_in_progress_at=True, # Should have started processing + expected_cancelled_at=True, # Should have cancellation timestamp + expected_cancelling_at=True, # Should have cancelling start timestamp + expected_finalizing_at=True, # Should have finalizing timestamp + expected_output_file_id=True, + expected_error_file_id=True, + expected_request_counts=True, # May or may not have counts + ) + + print("✅ Processing cancellation test completed successfully") + finally: + await test_app.state.batch_driver.clear_job(batch_id) + + +@pytest.mark.asyncio +async def test_job_cancellation_in_finalizing(test_app): + """Test case 6: Create job, cancel during finalizing, report completed.""" + print("Test 6: Job cancellation during finalizing (reports completed) scenario") + + with TestClient(test_app) as client: + # Step 1: Upload input file + input_data = generate_batch_input_data( + 10 + ) # Larger file for taking some time to finalizing + files = {"file": ("test_input.jsonl", input_data, "application/jsonl")} + data = {"purpose": "batch"} + + upload_response = client.post("/v1/files", files=files, data=data) + assert upload_response.status_code == 200 + input_file_id = upload_response.json()["id"] + + # Step 2: Inject the FailingJobManager to stall cancellation execution + original_manager = test_app.state.batch_driver.job_manager + failing_manager = FailingJobManager(stall_cancelling=2.0) + await test_app.state.batch_driver.run_coroutine( + failing_manager.set_job_entity_manager(original_manager._job_entity_manager) + ) + test_app.state.batch_driver._job_manager = failing_manager + + # Step 3: Create batch job + batch_request = { + "input_file_id": input_file_id, + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + } + + batch_response = client.post("/v1/batches", json=batch_request) + assert batch_response.status_code == 200 + batch_id = batch_response.json()["id"] + + try: + # Step 4: Wait for processing to reach finalizing + await wait_for_status(client, batch_id, "finalizing", max_polls=20) + + # Step 5: Try to cancel during or after processing + client.post(f"/v1/batches/{batch_id}/cancel") + # Note: Cancel may succeed or be ignored if already finalizing + + # Step 6: Wait for final status + final_status = await wait_for_status( + client, batch_id, ["completed", "cancelled"], max_polls=10 + ) + + # Step 6: Verify final status using comprehensive validation + validate_batch_response( + final_status, + expected_status="completed", + expected_endpoint="/v1/chat/completions", + expected_in_progress_at=True, # Should have started processing + expected_completed_at=True, # Should have completion timestamp + expected_finalizing_at=True, # Should have reached finalizing + expected_output_file_id=True, # Should have output file + expected_error_file_id=True, # Should have error file + 
expected_request_counts=True, # Should have request counts + ) + print(" Job completed by ignoring cancellation") + finally: + # Restore original manager + test_app.state.batch_driver._job_manager = original_manager + await test_app.state.batch_driver.clear_job(batch_id) + + +@pytest.mark.asyncio +async def test_job_expiration_during_validation(test_app): + """Test case 7: Create job, set expire to 1min and prevent validation, expired and report.""" + print("Test 7: Job expiration during validation scenario") + + pytest.skip("No batch scheduling enabled for k8s job") + + with TestClient(test_app) as client: + # Step 1: Upload input file + input_data = generate_batch_input_data(2) + files = {"file": ("test_input.jsonl", input_data, "application/jsonl")} + data = {"purpose": "batch"} + + upload_response = client.post("/v1/files", files=files, data=data) + assert upload_response.status_code == 200 + input_file_id = upload_response.json()["id"] + + # Step 2: Mock job manager to prevent validation + original_manager = test_app.state.batch_driver.job_manager + failing_manager = FailingJobManager(prevent_validation=True) + await test_app.state.batch_driver.run_coroutine( + failing_manager.set_job_entity_manager(original_manager._job_entity_manager) + ) + test_app.state.batch_driver.job_manager = failing_manager + + try: + # Step 3: Create batch job with very short completion window + batch_request = { + "input_file_id": input_file_id, + "endpoint": "/v1/chat/completions", + "completion_window": "60", # 60 seconds for test + } + + # Patch the completion window to be very short for testing + with patch( + "aibrix.batch.constant.EXPIRE_INTERVAL", 0.1 + ): # Check every 100ms + batch_response = client.post("/v1/batches", json=batch_request) + assert batch_response.status_code == 200 + batch_id = batch_response.json()["id"] + + # Step 4: Wait longer than completion window for expiration + await asyncio.sleep(2) # Wait for expiration to trigger + + # Step 5: Check that job expired + final_status = await wait_for_status( + client, batch_id, "expired", max_polls=10 + ) + + # Step 6: Verify expired status using comprehensive validation + validate_batch_response( + final_status, + expected_status="expired", + expected_endpoint="/v1/chat/completions", + expected_in_progress_at=False, # Should be None (expired before processing) + expected_expired_at=True, # Should have expiration timestamp + expected_failed_at=False, # Should be None (expired, not failed) + expected_completed_at=False, # Should be None (expired, not completed) + expected_cancelled_at=False, # Should be None (not cancelled) + expected_cancelling_at=False, # Should be None (not cancelled) + expected_finalizing_at=False, # Should be None (expired before finalizing) + expected_errors=False, # Should be None (expired during validation) + expected_output_file_id=False, # Should be None (expired before processing) + expected_error_file_id=False, # Should be None (expired before processing) + expected_request_counts=False, # Should be None (expired before processing) + ) + + print("✅ Validation expiration test completed successfully") + + finally: + # Restore original manager + test_app.state.batch_driver.job_manager = original_manager + await test_app.state.batch_driver.clear_job(batch_id) + + +@pytest.mark.asyncio +async def test_job_expiration_during_processing(test_app): + """Test case 8: Create job, set expire to 1min, expired during in progress, finalizing, validate result.""" + print("Test 8: Job expiration during processing scenario") + + 
with TestClient(test_app) as client: + # Step 1: Upload input file + input_data = generate_batch_input_data(3) + files = {"file": ("test_input.jsonl", input_data, "application/jsonl")} + data = {"purpose": "batch"} + + upload_response = client.post("/v1/files", files=files, data=data) + assert upload_response.status_code == 200 + input_file_id = upload_response.json()["id"] + + # Step 2: Inject FailingJobManager to add fail_after metadata during job creation + original_manager = test_app.state.batch_driver.job_manager + failing_manager = FailingJobManager(expiration=2) # Expired after 2 seconds + await test_app.state.batch_driver.run_coroutine( + failing_manager.set_job_entity_manager(original_manager._job_entity_manager) + ) + test_app.state.batch_driver._job_manager = failing_manager + + try: + # Step 3: Create batch job with failing manager that injects fail_after metadata + batch_request = { + "input_file_id": input_file_id, + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + } + + batch_response = client.post("/v1/batches", json=batch_request) + assert batch_response.status_code == 200 + batch_id = batch_response.json()["id"] + + # Step 4: Wait for job to start processing + await wait_for_status(client, batch_id, "in_progress", max_polls=10) + # Step 5: Wait for finalization to complete + final_status = await wait_for_status( + client, batch_id, "failed", max_polls=60, poll_interval=1.0 + ) # Wait longer for job retries + + # Step 7: Verify failed status using comprehensive validation + validate_batch_response( + final_status, + expected_status="expired", + expected_endpoint="/v1/chat/completions", + expected_completion_window="0h", # overrided to 2s + expected_in_progress_at=True, # Should have started processing + expected_expired_at=True, # Should have failure timestamp + expected_finalizing_at=True, # Should have reached finalizing stage + expected_output_file_id=True, # May or may not have output file + expected_error_file_id=True, # May or may not have error file + ) + + print( + "✅ Processing failure test with worker fail_after completed successfully" + ) + + await test_app.state.batch_driver.clear_job(batch_id) + finally: + # Restore original manager + test_app.state.batch_driver._job_manager = original_manager + + +if __name__ == "__main__": + # Allow running individual tests + import sys + + if len(sys.argv) > 1: + test_name = sys.argv[1] + if hasattr(sys.modules[__name__], test_name): + asyncio.run(getattr(sys.modules[__name__], test_name)()) + else: + print("Available tests:") + for name in dir(sys.modules[__name__]): + if name.startswith("test_"): + print(f" {name}") diff --git a/python/aibrix/tests/batch/test_e2e_openai_batch_api.py b/python/aibrix/tests/batch/test_e2e_openai_batch_api.py index a65f77e4a..96b953e91 100644 --- a/python/aibrix/tests/batch/test_e2e_openai_batch_api.py +++ b/python/aibrix/tests/batch/test_e2e_openai_batch_api.py @@ -12,28 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import argparse import asyncio import json import pytest from fastapi.testclient import TestClient -from aibrix.metadata.app import build_app - - -def create_test_app(): - """Create a FastAPI app configured for e2e testing.""" - return build_app( - argparse.Namespace( - host=None, - port=8100, - enable_fastapi_docs=False, - disable_batch_api=False, - enable_k8s_job=False, - e2e_test=True, - ) - ) +from tests.batch.conftest import create_test_app def generate_batch_input_data(num_requests: int = 3) -> str: @@ -61,9 +46,7 @@ def generate_batch_input_data(num_requests: int = 3) -> str: return "\n".join(lines) -def verify_batch_output_content( - output_content: str, expected_requests: int = 3 -) -> bool: +def verify_batch_output_content(output_content: str, expected_requests: int) -> bool: """Verify that batch output content has the expected structure.""" lines = output_content.strip().split("\n") @@ -73,24 +56,41 @@ def verify_batch_output_content( for i, line in enumerate(lines): try: - response = json.loads(line) + output = json.loads(line) # Check required fields in OpenAI batch response format - # [TODO][NEXT] check required_fields = ["id", "custom_id", "response"] - required_fields = ["custom_id"] # For now, just check custom_id + required_fields = ["id", "custom_id", "response"] for field in required_fields: - if field not in response: + if field not in output: print(f"Missing required field '{field}' in response {i+1}") return False # Verify custom_id matches expected pattern expected_custom_id = f"request-{i+1}" - if response["custom_id"] != expected_custom_id: + if output["custom_id"] != expected_custom_id: print( - f"Expected custom_id '{expected_custom_id}', got '{response['custom_id']}'" + f"Expected custom_id '{expected_custom_id}', got '{output['custom_id']}'" ) return False + response = output["response"] + required_fields = ["status_code", "request_id", "body"] + for field in required_fields: + if field not in response: + print( + f"Missing required field 'response.{field}' in response {i+1}" + ) + return False + + body = response["body"] + required_fields = ["model"] # For now, just check model + for field in required_fields: + if field not in body: + print( + f"Missing required field 'response.body.{field}' in response {i+1}" + ) + return False + except json.JSONDecodeError as e: print(f"Invalid JSON in output line {i+1}: {e}") return False @@ -175,6 +175,13 @@ async def test_openai_batch_api_e2e(): assert ( output_file_id is not None ), "Expected output_file_id for completed batch" + + request_counts = status_result.get("request_counts") + assert request_counts is not None + assert request_counts["total"] == 3 + assert request_counts["completed"] == 3 + assert request_counts["failed"] == 0 + break elif current_status == "failed": pytest.fail( @@ -237,7 +244,190 @@ async def test_openai_batch_api_e2e(): print( "\n🎉 E2E test completed successfully! All OpenAI Batch API endpoints working correctly." ) - await app.state.job_controller.clear_job(batch_id) + await app.state.batch_driver.clear_job(batch_id) + + +@pytest.mark.asyncio +async def test_openai_batch_api_metadata_server_workflow(test_app): + """ + End-to-end test for OpenAI Batch API with metadata server workflow: + 1. Upload sample input file via Files API + 2. Create batch job via Batch API (using metadata server mode) + 3. Verify metadata server prepares job output files + 4. Simulate worker execution by checking IN_PROGRESS state + 5. Poll job status until completion + 6. 
Download and verify output via Files API + """ + with TestClient(test_app) as client: + # Step 1: Upload sample input file via Files API + print("Step 1: Uploading batch input file...") + + input_data = generate_batch_input_data(10) + files = { + "file": ("metadata_batch_input.jsonl", input_data, "application/jsonl") + } + data = {"purpose": "batch"} + + upload_response = client.post("/v1/files", files=files, data=data) + assert ( + upload_response.status_code == 200 + ), f"File upload failed: {upload_response.text}" + + upload_result = upload_response.json() + input_file_id = upload_result["id"] + print(f"✅ File uploaded successfully with ID: {input_file_id}") + + # Step 2: Create batch job via Batch API (metadata server mode) + print("Step 2: Creating batch job with metadata server workflow...") + + batch_request = { + "input_file_id": input_file_id, + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + } + + batch_response = client.post("/v1/batches", json=batch_request) + assert ( + batch_response.status_code == 200 + ), f"Batch creation failed: {batch_response.text}" + + batch_result = batch_response.json() + assert "id" in batch_result + assert batch_result["object"] == "batch" + assert batch_result["input_file_id"] == input_file_id + assert batch_result["endpoint"] == "/v1/chat/completions" + assert batch_result["completion_window"] == "24h" + assert isinstance(batch_result["created_at"], int) + + batch_id = batch_result["id"] + print(f"✅ Batch created successfully with ID: {batch_id}") + + # Step 3: Verify metadata server prepared job output files + print("Step 3: Checking if job output files are prepared...") + + # Poll for a short time to see if job moves through states + preparation_polls = 5 + for attempt in range(preparation_polls): + status_response = client.get(f"/v1/batches/{batch_id}") + assert ( + status_response.status_code == 200 + ), f"Status check failed: {status_response.text}" + + status_result = status_response.json() + current_status = status_result["status"] + print(f" Preparation check {attempt + 1}: Status = {current_status}") + + # Check if we have output file IDs which indicates preparation is done + output_file_id = status_result.get("output_file_id") + error_file_id = status_result.get("error_file_id") + + if output_file_id and error_file_id: + print("✅ Job output files prepared by metadata server!") + break + elif current_status in ["failed", "cancelled", "expired"]: + pytest.fail(f"Job failed during preparation: {current_status}") + + await asyncio.sleep(0.5) + + # Step 4: Simulate worker workflow - wait for IN_PROGRESS + print("Step 4: Waiting for job to reach IN_PROGRESS state...") + + in_progress_polls = 3 + for attempt in range(in_progress_polls): + status_response = client.get(f"/v1/batches/{batch_id}") + status_result = status_response.json() + print(status_result) + current_status = status_result["status"] + + print(f" IN_PROGRESS check {attempt + 1}: Status = {current_status}") + + if current_status == "in_progress": + print("✅ Job reached IN_PROGRESS state - worker can start execution!") + + # Verify status_result + assert isinstance(status_result["in_progress_at"], int) + + break + elif current_status in ["failed", "cancelled", "expired"]: + pytest.fail(f"Job failed before reaching IN_PROGRESS: {current_status}") + + await asyncio.sleep(1) + + # Step 5: Poll job status until completion (metadata server should finalize) + print("Step 5: Polling job status until completion...") + + max_polls = 20 # Extended for metadata server 
workflow + poll_interval = 1 + + for attempt in range(max_polls): + status_response = client.get(f"/v1/batches/{batch_id}") + status_result = status_response.json() + print(status_result) + current_status = status_result["status"] + + print(f" Completion check {attempt + 1}: Status = {current_status}") + + if current_status == "completed": + print( + "✅ Batch job completed successfully with metadata server workflow!" + ) + + # Verify status_result + output_file_id = status_result["output_file_id"] + assert ( + output_file_id is not None + ), "Expected output_file_id for completed batch" + + request_counts = status_result.get("request_counts") + assert request_counts is not None + assert request_counts["total"] == 10 + assert request_counts["completed"] == 10 + assert request_counts["failed"] == 0 + + assert isinstance(status_result["finalizing_at"], int) + assert isinstance(status_result["completed_at"], int) + + break + elif current_status == "failed": + pytest.fail( + f"Batch job failed: {status_result.get('errors', 'Unknown error')}" + ) + elif current_status in ["cancelled", "expired"]: + pytest.fail(f"Batch job was {current_status}") + + await asyncio.sleep(poll_interval) + else: + pytest.fail( + f"Batch job did not complete within {max_polls * poll_interval} seconds" + ) + + # Step 6: Download and verify output via Files API + print("Step 6: Downloading and verifying output...") + + output_response = client.get(f"/v1/files/{output_file_id}/content") + assert ( + output_response.status_code == 200 + ), f"Output download failed: {output_response.text}" + + output_content = output_response.content.decode("utf-8") + assert output_content, "Output file is empty" + + # Verify output content structure + is_valid = verify_batch_output_content(output_content, 10) + assert ( + is_valid + ), f"Output content verification failed. Content:\n{output_content}" + + print("✅ Output downloaded and verified successfully!") + print(f"Output content preview:\n{output_content[:200]}...") + + print( + "\n🎉 Metadata server workflow E2E test completed successfully! " + "Job preparation, worker coordination, and finalization working correctly." + ) + await test_app.state.batch_driver.clear_job(batch_id) + + assert False @pytest.mark.asyncio diff --git a/python/aibrix/tests/batch/test_inference_client_integration.py b/python/aibrix/tests/batch/test_inference_client_integration.py new file mode 100644 index 000000000..9d179ccaa --- /dev/null +++ b/python/aibrix/tests/batch/test_inference_client_integration.py @@ -0,0 +1,90 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from aibrix.batch.job_driver import InferenceEngineClient, ProxyInferenceEngineClient + + +class TestInferenceClientIntegration: + """Test inference client functionality and retry logic.""" + + @pytest.mark.asyncio + async def test_mock_inference_client(self): + """Test inference client in mock mode.""" + client = InferenceEngineClient() # No base_url = mock mode + + request_data = { + "custom_id": "test-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + }, + } + + # Test mock response + response = await client.inference_request("/v1/chat/completions", request_data) + assert response == request_data # Mock should echo the request + + @pytest.mark.asyncio + async def test_real_inference_client_with_invalid_url(self): + """Test inference client with invalid URL to verify error handling.""" + client = ProxyInferenceEngineClient("http://invalid-host:9999") + + request_data = { + "custom_id": "test-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + }, + } + + # Should raise an exception when trying to connect to invalid host + with pytest.raises(Exception): # Will be httpx.RequestError or similar + await client.inference_request("/v1/chat/completions", request_data) + + @pytest.mark.asyncio + async def test_retry_behavior_demonstration(self): + """Demonstrate that the retry logic works by testing mock behavior.""" + + class FailingInferenceClient(InferenceEngineClient): + def __init__(self): + super().__init__() + self.attempt_count = 0 + + async def inference_request(self, endpoint, request_data): + self.attempt_count += 1 + if self.attempt_count < 3: + raise Exception(f"Simulated failure {self.attempt_count}") + return {"success": True, "attempts": self.attempt_count} + + # Note: This test shows the pattern, but retry logic is in JobDriver + # In actual usage, JobDriver would retry the inference_request calls + client = FailingInferenceClient() + + # First two calls should fail + with pytest.raises(Exception, match="Simulated failure 1"): + await client.inference_request("/test", {}) + + with pytest.raises(Exception, match="Simulated failure 2"): + await client.inference_request("/test", {}) + + # Third call should succeed + result = await client.inference_request("/test", {}) + assert result["success"] is True + assert result["attempts"] == 3 diff --git a/python/aibrix/tests/batch/test_job_cache.py b/python/aibrix/tests/batch/test_job_cache.py new file mode 100644 index 000000000..ae6da0375 --- /dev/null +++ b/python/aibrix/tests/batch/test_job_cache.py @@ -0,0 +1,84 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from pathlib import Path + +# Set required environment variable before importing +os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing") + +from aibrix.batch.job_entity import JobEntityManager +from aibrix.metadata.cache.job import JobCache + + +def test_job_cache_implements_job_entity_manager(): + """Test that JobCache properly implements JobEntityManager interface.""" + cache = JobCache() + assert isinstance(cache, JobEntityManager) + + +def test_job_cache_with_custom_patch(): + """ + Test that JobCache can be initialized with a custom template path + and loads the correct service account. + """ + from aibrix.metadata.cache.job import JobCache + + # Get the path to the unittest job template + patch_path = Path(__file__).parent / "testdata" / "k8s_job_patch_unittest.yaml" + + # Initialize JobCache with custom template + job_cache = JobCache(template_patch_path=patch_path) + + # Verify that the template was loaded + assert job_cache.job_template is not None + assert job_cache.job_template["kind"] == "Job" + assert job_cache.job_template["apiVersion"] == "batch/v1" + + # Verify that the unittest service account is used + service_account_name = job_cache.job_template["spec"]["template"]["spec"][ + "serviceAccountName" + ] + assert service_account_name == "unittest-job-reader-sa" + + print( + f"✓ JobCache loaded custom template with service account: {service_account_name}" + ) + print("✅ Template path integration works correctly!") + + +def test_job_cache_with_default_template(): + """ + Test that JobCache still works with the default template when no path is provided. + """ + from aibrix.metadata.cache.job import JobCache + + # Initialize JobCache without custom template (should use default) + job_cache = JobCache() + + # Verify that the template was loaded + assert job_cache.job_template is not None + assert job_cache.job_template["kind"] == "Job" + assert job_cache.job_template["apiVersion"] == "batch/v1" + + # Verify that the default service account is used + service_account_name = job_cache.job_template["spec"]["template"]["spec"][ + "serviceAccountName" + ] + assert service_account_name == "job-reader-sa" # Default service account + + print( + f"✓ JobCache loaded default template with service account: {service_account_name}" + ) + print("✅ Default template loading works correctly!") diff --git a/python/aibrix/tests/batch/test_job_entity.py b/python/aibrix/tests/batch/test_job_entity.py new file mode 100644 index 000000000..b3cc91463 --- /dev/null +++ b/python/aibrix/tests/batch/test_job_entity.py @@ -0,0 +1,344 @@ +"""Unit tests for BatchJobError and related job entity classes""" + +import json + +import pytest + +from aibrix.batch.job_entity import ( + BatchJobEndpoint, + BatchJobError, + BatchJobErrorCode, + BatchJobSpec, + CompletionWindow, +) + + +class TestBatchJobError: + """Test BatchJobError creation and handling.""" + + def test_batch_job_error_with_standard_exception(self): + """Test creating BatchJobError with a standard Exception converted to string.""" + # Create a standard exception + fe = Exception("Something went wrong during processing") + + # This should not throw an error + error = BatchJobError(code=BatchJobErrorCode.FINALIZING_ERROR, message=str(fe)) + + # Verify the error was created correctly + assert error.code == BatchJobErrorCode.FINALIZING_ERROR.value + assert error.message == "Something went wrong during processing" + assert error.param is None + assert error.line is None + assert str(error) == "Something went wrong during 
processing" + + def test_batch_job_error_with_nested_exception(self): + """Test creating BatchJobError with a nested exception chain.""" + try: + try: + raise ValueError("Inner error") + except ValueError as inner: + raise RuntimeError("Outer error") from inner + except RuntimeError as fe: + error = BatchJobError( + code=BatchJobErrorCode.FINALIZING_ERROR, message=str(fe) + ) + + assert error.code == BatchJobErrorCode.FINALIZING_ERROR.value + assert error.message == "Outer error" + + def test_batch_job_error_with_all_parameters(self): + """Test creating BatchJobError with all optional parameters.""" + fe = Exception("Processing failed at specific location") + + error = BatchJobError( + code=BatchJobErrorCode.FINALIZING_ERROR, + message=str(fe), + param="output_file_id", + line=42, + ) + + assert error.code == BatchJobErrorCode.FINALIZING_ERROR.value + assert error.message == "Processing failed at specific location" + assert error.param == "output_file_id" + assert error.line == 42 + + def test_batch_job_error_inheritance(self): + """Test that BatchJobError is properly an Exception subclass.""" + fe = Exception("Test exception") + + error = BatchJobError(code=BatchJobErrorCode.FINALIZING_ERROR, message=str(fe)) + + # Should be an instance of Exception + assert isinstance(error, Exception) + assert isinstance(error, BatchJobError) + + # Should be raisable + with pytest.raises(BatchJobError) as exc_info: + raise error + + assert exc_info.value.code == BatchJobErrorCode.FINALIZING_ERROR.value + assert exc_info.value.message == "Test exception" + + def test_batch_job_error_with_unicode_exception(self): + """Test creating BatchJobError with unicode characters in exception message.""" + fe = Exception("Processing failed: 文件不存在 (file not found)") + + error = BatchJobError(code=BatchJobErrorCode.FINALIZING_ERROR, message=str(fe)) + + assert error.code == BatchJobErrorCode.FINALIZING_ERROR.value + assert error.message == "Processing failed: 文件不存在 (file not found)" + + def test_batch_job_error_with_json_serialization(self): + """Test that BatchJobError can be represented in JSON-like structure.""" + fe = Exception("JSON serialization test") + + error = BatchJobError( + code=BatchJobErrorCode.FINALIZING_ERROR, + message=str(fe), + param="test_param", + line=100, + ) + + # Test that error attributes can be serialized + error_dict = { + "code": error.code, + "message": error.message, + "param": error.param, + "line": error.line, + } + + # Should be able to serialize to JSON + json_str = json.dumps(error_dict) + assert json_str is not None + + # Should be able to deserialize + parsed = json.loads(json_str) + assert parsed["code"] == BatchJobErrorCode.FINALIZING_ERROR.value + assert parsed["message"] == "JSON serialization test" + assert parsed["param"] == "test_param" + assert parsed["line"] == 100 + + +class TestBatchJobEntityCreation: + """Test creation of various batch job entities.""" + + def test_batch_job_spec_creation(self): + """Test creating BatchJobSpec.""" + spec = BatchJobSpec( + input_file_id="test-input-123", + endpoint=BatchJobEndpoint.CHAT_COMPLETIONS.value, + completion_window=CompletionWindow.TWENTY_FOUR_HOURS.expires_at(), + metadata={"priority": "high"}, + ) + + assert spec.input_file_id == "test-input-123" + assert spec.endpoint == "/v1/chat/completions" + assert spec.completion_window == 86400 + assert spec.metadata == {"priority": "high"} + + def test_batch_job_error_codes_coverage(self): + """Test that all BatchJobErrorCode values work with exceptions.""" + test_exception = 
Exception("Test exception message") + + # Test each error code + error_codes_to_test = [ + BatchJobErrorCode.INVALID_INPUT_FILE, + BatchJobErrorCode.INVALID_ENDPOINT, + BatchJobErrorCode.INVALID_COMPLETION_WINDOW, + BatchJobErrorCode.INVALID_METADATA, + BatchJobErrorCode.AUTHENTICATION_ERROR, + BatchJobErrorCode.INFERENCE_FAILED, + BatchJobErrorCode.PREPARE_OUTPUT_ERROR, + BatchJobErrorCode.FINALIZING_ERROR, + BatchJobErrorCode.UNKNOWN_ERROR, + ] + + for error_code in error_codes_to_test: + # This should not throw any errors + error = BatchJobError(code=error_code, message=str(test_exception)) + + assert error.code == error_code.value + assert error.message == "Test exception message" + assert isinstance(error, Exception) + assert isinstance(error, BatchJobError) + + +class TestExceptionMessageConversion: + """Test various ways exceptions can be converted to messages.""" + + def test_exception_with_args(self): + """Test exception with multiple arguments.""" + fe = Exception("Primary message", "Secondary info", 42) + + error = BatchJobError(code=BatchJobErrorCode.FINALIZING_ERROR, message=str(fe)) + + # str() should convert the first argument + assert "Primary message" in error.message + + def test_exception_with_no_args(self): + """Test exception with no arguments.""" + fe = Exception() + + error = BatchJobError(code=BatchJobErrorCode.FINALIZING_ERROR, message=str(fe)) + + # Should handle empty exception gracefully + assert error.message == "" + + def test_custom_exception_class(self): + """Test with custom exception class.""" + + class CustomFinalizationError(Exception): + def __init__(self, operation, details): + self.operation = operation + self.details = details + super().__init__(f"Operation '{operation}' failed: {details}") + + fe = CustomFinalizationError("file_upload", "network timeout") + + error = BatchJobError(code=BatchJobErrorCode.FINALIZING_ERROR, message=str(fe)) + + assert error.message == "Operation 'file_upload' failed: network timeout" + + def test_exception_with_special_characters(self): + """Test exception message with special characters.""" + fe = Exception("Error with special chars: !@#$%^&*()_+-={}[]|\\:;\"'<>?,./") + + error = BatchJobError(code=BatchJobErrorCode.FINALIZING_ERROR, message=str(fe)) + + assert ( + error.message + == "Error with special chars: !@#$%^&*()_+-={}[]|\\:;\"'<>?,./" + ) + + def test_exception_with_newlines_and_tabs(self): + """Test exception message with newlines and tabs.""" + fe = Exception("Multi-line\nerror\tmessage\nwith\ttabs") + + error = BatchJobError(code=BatchJobErrorCode.FINALIZING_ERROR, message=str(fe)) + + assert error.message == "Multi-line\nerror\tmessage\nwith\ttabs" + + +class TestBatchJobErrorFastAPICompatibility: + """Test BatchJobError compatibility with FastAPI serialization.""" + + def test_batch_job_error_pydantic_type_adapter_compatibility(self): + """Test BatchJobError with Pydantic TypeAdapter (FastAPI requirement).""" + from pydantic import TypeAdapter + + # Create a BatchJobError instance + fe = Exception("Pydantic TypeAdapter test") + error = BatchJobError( + code=BatchJobErrorCode.FINALIZING_ERROR, + message=str(fe), + param="test_param", + line=100, + ) + + # Create TypeAdapter for BatchJobError + adapter = TypeAdapter(BatchJobError) + + # Test serialization (this is what was failing) + serialized = adapter.dump_python(error) + + assert isinstance(serialized, dict) + assert serialized["code"] == BatchJobErrorCode.FINALIZING_ERROR.value + assert serialized["message"] == "Pydantic TypeAdapter test" + assert 
serialized["param"] == "test_param" + assert serialized["line"] == 100 + + def test_batch_job_error_pydantic_json_serialization(self): + """Test BatchJobError JSON serialization through Pydantic TypeAdapter.""" + from pydantic import TypeAdapter + + fe = Exception("JSON TypeAdapter test") + error = BatchJobError( + code=BatchJobErrorCode.AUTHENTICATION_ERROR, + message=str(fe), + param=None, + line=None, + ) + + adapter = TypeAdapter(BatchJobError) + + # Test JSON serialization + json_str = adapter.dump_json(error) + assert json_str is not None + + # Test the JSON content + parsed = json.loads(json_str) + assert parsed["code"] == BatchJobErrorCode.AUTHENTICATION_ERROR.value + assert parsed["message"] == "JSON TypeAdapter test" + assert parsed["param"] is None + assert parsed["line"] is None + + def test_batch_job_error_pydantic_validation_and_serialization_roundtrip(self): + """Test BatchJobError validation and serialization roundtrip through Pydantic.""" + from pydantic import TypeAdapter + + # Original error + fe = Exception("Roundtrip test") + original_error = BatchJobError( + code=BatchJobErrorCode.PREPARE_OUTPUT_ERROR, + message=str(fe), + param="roundtrip_param", + line=50, + ) + + adapter = TypeAdapter(BatchJobError) + + # Serialize to dict + error_dict = adapter.dump_python(original_error) + + # Validate back to BatchJobError (roundtrip) + validated_error = adapter.validate_python(error_dict) + + # Verify the roundtrip worked + assert isinstance(validated_error, BatchJobError) + assert validated_error.code == original_error.code + assert validated_error.message == original_error.message + assert validated_error.param == original_error.param + assert validated_error.line == original_error.line + + def test_batch_job_error_list_serialization_for_fastapi(self): + """Test BatchJobError list serialization for FastAPI responses with errors list.""" + from typing import List + + from pydantic import TypeAdapter + + # Create multiple errors + exceptions = [ + ValueError("First validation error"), + FileNotFoundError("Second file error"), + RuntimeError("Third runtime error"), + ] + + errors = [] + error_codes = [ + BatchJobErrorCode.INVALID_INPUT_FILE, + BatchJobErrorCode.AUTHENTICATION_ERROR, + BatchJobErrorCode.FINALIZING_ERROR, + ] + + for exc, code in zip(exceptions, error_codes): + error = BatchJobError( + code=code, message=str(exc), param=f"param_{code.value}", line=None + ) + errors.append(error) + + # Test list serialization + adapter = TypeAdapter(List[BatchJobError]) + + # Serialize list of errors + serialized_errors = adapter.dump_python(errors) + + assert isinstance(serialized_errors, list) + assert len(serialized_errors) == 3 + + # Verify each error was serialized correctly + for i, serialized_error in enumerate(serialized_errors): + assert isinstance(serialized_error, dict) + assert serialized_error["code"] == error_codes[i].value + assert serialized_error["message"] == str(exceptions[i]) + assert serialized_error["param"] == f"param_{error_codes[i].value}" + assert serialized_error["line"] is None diff --git a/python/aibrix/tests/batch/test_job_manager.py b/python/aibrix/tests/batch/test_job_manager.py index 1d3b6fb5d..f3fef8783 100644 --- a/python/aibrix/tests/batch/test_job_manager.py +++ b/python/aibrix/tests/batch/test_job_manager.py @@ -59,7 +59,7 @@ async def test_local_job_cancellation(): assert job_id not in job_manager._done_jobs # Cancel the job - result = job_manager.cancel_job(job_id) + result = await job_manager.cancel_job(job_id) assert result is True # Verify 
job moved to done state with cancelled status @@ -67,58 +67,17 @@ async def test_local_job_cancellation(): assert job_id in job_manager._done_jobs cancelled_job = job_manager._done_jobs[job_id] - assert cancelled_job.status.state == BatchJobState.CANCELED + assert cancelled_job.status.state == BatchJobState.FINALIZED + assert cancelled_job.status.cancelled @pytest.mark.asyncio -async def test_job_cancellation_race_condition(): - """Test race condition handling where job completes before cancellation.""" - job_manager = JobManager() - - # Create a job - await job_manager.create_job( - session_id="test-session-2", - input_file_id="test-file-2", - api_endpoint="/v1/embeddings", - completion_window="24h", - meta_data={"test": "race"}, - ) - - job_id = next(iter(job_manager._pending_jobs.keys())) - - # Simulate job completing by manually updating its status - job = job_manager._pending_jobs[job_id] - completed_status = BatchJobStatus( - jobID=job_id, - state=BatchJobState.COMPLETED, - createdAt=datetime.now(), - completedAt=datetime.now(), - ) - completed_job = BatchJob( - typeMeta=job.type_meta, - metadata=job.metadata, - spec=job.spec, - status=completed_status, - ) - job_manager._pending_jobs[job_id] = completed_job - - # Try to cancel already completed job - result = job_manager.cancel_job(job_id) - assert result is False # Should fail because job is already completed - - # Job is removed from pending during cancellation attempt, but since it failed, - # the job doesn't get moved to done state - it gets lost - assert job_id not in job_manager._pending_jobs - assert job_id not in job_manager._done_jobs - assert job_id not in job_manager._in_progress_jobs - - -def test_cancel_nonexistent_job(): +async def test_cancel_nonexistent_job(): """Test cancelling a job that doesn't exist.""" job_manager = JobManager() # Try to cancel non-existent job - result = job_manager.cancel_job("nonexistent-job-id") + result = await job_manager.cancel_job("nonexistent-job-id") assert result is False @@ -144,11 +103,12 @@ async def test_cancel_job_already_done(): job_manager._done_jobs[job_id] = job # Try to cancel job that's already done - result = job_manager.cancel_job(job_id) - assert result is True # Changed: done jobs now return True + result = await job_manager.cancel_job(job_id) + assert result is False # Changed: done jobs now return False -def test_job_committed_handler(): +@pytest.mark.asyncio +async def test_job_committed_handler(): """Test that job_committed_handler correctly adds jobs to pending.""" job_manager = JobManager() @@ -165,8 +125,8 @@ def test_job_committed_handler(): ), spec=BatchJobSpec( input_file_id="test-file-123", - endpoint=BatchJobEndpoint.CHAT_COMPLETIONS, - completion_window=CompletionWindow.TWENTY_FOUR_HOURS, + endpoint=BatchJobEndpoint.CHAT_COMPLETIONS.value, + completion_window=CompletionWindow.TWENTY_FOUR_HOURS.expires_at(), ), status=BatchJobStatus( jobID="test-job-id", @@ -176,14 +136,15 @@ def test_job_committed_handler(): ) # Call the handler - job_manager.job_committed_handler(batch_job) + await job_manager.job_committed_handler(batch_job) # Verify job is in pending state assert "test-job-id" in job_manager._pending_jobs assert job_manager._pending_jobs["test-job-id"] == batch_job -def test_job_deleted_handler(): +@pytest.mark.asyncio +async def test_job_deleted_handler(): """Test that job_deleted_handler correctly moves jobs to done state.""" job_manager = JobManager() @@ -200,8 +161,8 @@ def test_job_deleted_handler(): ), spec=BatchJobSpec( 
input_file_id="test-file-456", - endpoint=BatchJobEndpoint.EMBEDDINGS, - completion_window=CompletionWindow.TWENTY_FOUR_HOURS, + endpoint=BatchJobEndpoint.EMBEDDINGS.value, + completion_window=CompletionWindow.TWENTY_FOUR_HOURS.expires_at(), ), status=BatchJobStatus( jobID="test-job-id-2", @@ -214,7 +175,7 @@ def test_job_deleted_handler(): job_manager._pending_jobs["test-job-id-2"] = batch_job # Call the deleted handler - job_manager.job_deleted_handler(batch_job) + await job_manager.job_deleted_handler(batch_job) # Verify job is removed from pending (job_deleted_handler removes jobs, doesn't move them) assert "test-job-id-2" not in job_manager._pending_jobs @@ -231,15 +192,17 @@ def __init__(self, delay: float = 0.1): self.submitted_jobs: List[tuple] = [] # Track submitted jobs self.should_fail = False # Flag to simulate failures - def submit_job(self, session_id: str, job: BatchJobSpec): + async def submit_job(self, session_id: str, job: BatchJobSpec): """Mock job submission with async callback.""" + print(f"start time: {datetime.now()}") if self.should_fail: raise RuntimeError("Mock job submission failed") self.submitted_jobs.append((session_id, job)) # Simulate async job creation with a delay - asyncio.create_task(self._simulate_job_creation(session_id, job)) + await self._simulate_job_creation(session_id, job) + print(f"end time: {datetime.now()}") async def _simulate_job_creation(self, session_id: str, job_spec: BatchJobSpec): """Simulate async job creation process.""" @@ -258,36 +221,49 @@ async def _simulate_job_creation(self, session_id: str, job_spec: BatchJobSpec): spec=job_spec, status=BatchJobStatus( jobID=f"mock-job-{session_id}", - state=BatchJobState.CREATED, + state=BatchJobState.IN_PROGRESS, # Set to in_progress to skip job validation and preparetion. 
createdAt=datetime.now(), ), ) # Call the committed handler - if self._job_committed_handler: - self._job_committed_handler(batch_job) + await self.job_committed(batch_job) def get_job(self, job_id: str) -> Optional[BatchJob]: """Mock get_job implementation.""" return None + async def update_job_ready(self, job: BatchJob) -> None: + """Mock update_job_ready implementation.""" + pass + + async def update_job_status(self, job: BatchJob) -> None: + """Mock update_job_status implementation.""" + pass + def list_jobs(self) -> List[BatchJob]: """Mock list_jobs implementation.""" return [] - def cancel_job(self, job_id: str) -> bool: + async def cancel_job(self, job: BatchJob): """Mock cancel_job implementation.""" - return True + pass + + async def delete_job(self, job: BatchJob): + """Mock cancel_job implementation.""" + pass @pytest.mark.asyncio -async def test_aysnc_create_job(): +async def test_async_create_job(): """Test that JobEntityManager assigns job_id and calls handlers correctly.""" # Create mock job entity manager mock_entity_manager = MockJobEntityManager(delay=0.05) # Create job manager with entity manager - job_manager = JobManager(job_entity_manager=mock_entity_manager) + asyncio.get_running_loop().name = "test_async_create_job" + job_manager = JobManager() + await job_manager.set_job_entity_manager(mock_entity_manager) # Create a job using the async method session_id = "test-session-async-1" @@ -310,9 +286,9 @@ async def test_aysnc_create_job(): assert submitted_session_id == session_id assert submitted_spec.input_file_id == "test-input-1" - # Verify job was added to pending jobs - assert job_id in job_manager._pending_jobs - job = job_manager._pending_jobs[job_id] + # Verify job was added to progress jobs since MockJobEntityManager set initial state to in_progress + assert job_id in job_manager._in_progress_jobs + job = job_manager._in_progress_jobs[job_id] assert job.session_id == session_id assert job.status.job_id == job_id @@ -325,7 +301,10 @@ async def test_async_create_job_with_timeout(): """Test that create_job throws error when timeout occurs.""" # Create mock entity manager with long delay (longer than timeout) mock_entity_manager = MockJobEntityManager(delay=2.0) - job_manager = JobManager(job_entity_manager=mock_entity_manager) + + asyncio.get_running_loop().name = "test_async_create_job_with_timeout" + job_manager = JobManager() + await job_manager.set_job_entity_manager(mock_entity_manager) # Attempt to create job with short timeout session_id = "test-session-timeout" @@ -344,26 +323,28 @@ async def test_async_create_job_with_timeout(): assert len(mock_entity_manager.submitted_jobs) == 1 assert session_id not in job_manager._creating_jobs - # Verify no job was added to pending (since timeout occurred) - assert len(job_manager._pending_jobs) == 0 + # Verify no job was added to _in_progress_jobs (since timeout occurred) + assert len(job_manager._in_progress_jobs) == 0 # Wait for job to be added. await asyncio.sleep(3.0) - # Verify the job will still be valid and we do nothing about it. 
- assert len(job_manager._pending_jobs) == 1 - for job_id in job_manager._pending_jobs: - job = job_manager._pending_jobs[job_id] - assert job.status.state == BatchJobState.CREATED + # Verify the job will be ignore by job_manager + assert len(job_manager._in_progress_jobs) == 0 + all_jobs = await job_manager.list_jobs() + assert len(all_jobs) == 0 @pytest.mark.asyncio -async def test_aysnc_create_job_throws_error(): +async def test_async_create_job_throws_error(): """Test that create_job throws error when job submission fails.""" # Create mock entity manager that fails mock_entity_manager = MockJobEntityManager() mock_entity_manager.should_fail = True - job_manager = JobManager(job_entity_manager=mock_entity_manager) + + asyncio.get_running_loop().name = "test_async_create_job_throws_error" + job_manager = JobManager() + await job_manager.set_job_entity_manager(mock_entity_manager) # Attempt to create job session_id = "test-session-fail" @@ -388,7 +369,10 @@ async def test_aysnc_create_job_throws_error(): async def test_multiple_concurrent_job_creation(): """Test creating multiple jobs concurrently.""" mock_entity_manager = MockJobEntityManager(delay=0.1) - job_manager = JobManager(job_entity_manager=mock_entity_manager) + + asyncio.get_running_loop().name = "test_multiple_concurrent_job_creation" + job_manager = JobManager() + await job_manager.set_job_entity_manager(mock_entity_manager) # Create multiple jobs concurrently tasks = [] @@ -416,8 +400,8 @@ async def test_multiple_concurrent_job_creation(): # Verify all jobs are in pending state for i, job_id in enumerate(job_ids): - assert job_id in job_manager._pending_jobs - job = job_manager._pending_jobs[job_id] + assert job_id in job_manager._in_progress_jobs + job = job_manager._in_progress_jobs[job_id] assert job.session_id == session_ids[i] # Verify all futures were cleaned up diff --git a/python/aibrix/tests/batch/test_k8s_job_persistence.py b/python/aibrix/tests/batch/test_k8s_job_persistence.py new file mode 100644 index 000000000..2dd0f5f09 --- /dev/null +++ b/python/aibrix/tests/batch/test_k8s_job_persistence.py @@ -0,0 +1,291 @@ +"""Unit tests for job status persistence functionality.""" + +import json +from datetime import datetime, timezone + +from aibrix.batch.job_entity import ( + BatchJobError, + BatchJobErrorCode, + BatchJobState, + BatchJobStatus, + BatchJobTransformer, + Condition, + ConditionStatus, + ConditionType, + JobAnnotationKey, + RequestCountStats, +) + + +def test_create_status_annotations(): + """Test that create_status_annotations correctly handles all possible fields.""" + # Create timestamps + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + in_progress_time = datetime(2024, 1, 1, 12, 5, 0, tzinfo=timezone.utc) + cancelling_time = datetime(2024, 1, 1, 12, 10, 0, tzinfo=timezone.utc) + completed_time = datetime(2024, 1, 1, 12, 15, 0, tzinfo=timezone.utc) + finalizing_time = datetime(2024, 1, 1, 12, 16, 0, tzinfo=timezone.utc) + finalized_time = datetime(2024, 1, 1, 12, 17, 0, tzinfo=timezone.utc) + + # Create comprehensive conditions - one of each type + conditions = [ + Condition( + type=ConditionType.COMPLETED, + status=ConditionStatus.TRUE, + lastTransitionTime=completed_time, + reason="AllCompleted", + message="All requests completed successfully", + ), + Condition( + type=ConditionType.FAILED, + status=ConditionStatus.TRUE, + lastTransitionTime=completed_time, + reason="ProcessingFailed", + message="Some requests failed", + ), + Condition( + type=ConditionType.CANCELLED, + 
status=ConditionStatus.TRUE, + lastTransitionTime=cancelling_time, + reason="UserCancelled", + message="Job cancelled by user", + ), + ] + + # Create errors list that would be persisted + errors = [ + BatchJobError( + code=BatchJobErrorCode.INVALID_INPUT_FILE, + message="Input file contains invalid JSON", + param="input_file", + line=42, + ), + BatchJobError( + code=BatchJobErrorCode.INFERENCE_FAILED, + message="Model inference timeout", + param="timeout", + ), + ] + + # Create status with all possible fields that create_status_annotations touches + status = BatchJobStatus( + jobID="test-job-id", + state=BatchJobState.FINALIZED, # Required for finalized state annotation + createdAt=base_time, + inProgressAt=in_progress_time, # Touches IN_PROGRESS_AT + cancellingAt=cancelling_time, # Touches CANCELLING_AT + completedAt=completed_time, # Touches COMPLETED_AT + finalizingAt=finalizing_time, # Touches FINALIZING_AT + finalizedAt=finalized_time, # Touches FINALIZED_AT + conditions=conditions, # Touches CONDITION (priority: cancelled > failed, completed not persisted) + errors=errors, # Touches ERRORS + requestCounts=RequestCountStats( # Touches REQUEST_COUNTS + total=100, + launched=95, + completed=85, + failed=10, + ), + ) + + # Call create_status_annotations + annotations = BatchJobTransformer.create_status_annotations(status) + assert json.dumps(annotations) + + # Verify all expected annotations are created + + # 1. Finalized state annotation (only for FINALIZED state) + assert ( + annotations[JobAnnotationKey.JOB_STATE.value] == BatchJobState.FINALIZED.value + ) + + # 2. Condition annotation (should prioritize CANCELLED over FAILED; COMPLETED not persisted) + assert ( + annotations[JobAnnotationKey.CONDITION.value] == ConditionType.CANCELLED.value + ) + + # 3. Errors annotation (only if errors exist) + errors_json = annotations[JobAnnotationKey.ERRORS.value] + errors_data = json.loads(errors_json) + assert len(errors_data) == 2 + assert errors_data[0]["code"] == BatchJobErrorCode.INVALID_INPUT_FILE.value + assert errors_data[0]["message"] == "Input file contains invalid JSON" + assert errors_data[0]["param"] == "input_file" + assert errors_data[0]["line"] == 42 + assert errors_data[1]["code"] == BatchJobErrorCode.INFERENCE_FAILED.value + assert errors_data[1]["message"] == "Model inference timeout" + assert errors_data[1]["param"] == "timeout" + assert errors_data[1]["line"] is None + + # 4. Request counts annotation (only if total > 0) + request_counts_json = annotations[JobAnnotationKey.REQUEST_COUNTS.value] + request_counts_data = json.loads(request_counts_json) + assert request_counts_data == { + "total": 100, + "launched": 95, + "completed": 85, + "failed": 10, + } + + # 5. 
Timestamp annotations (only if timestamps exist) + assert ( + annotations[JobAnnotationKey.IN_PROGRESS_AT.value] + == in_progress_time.isoformat() + ) + assert ( + annotations[JobAnnotationKey.FINALIZING_AT.value] == finalizing_time.isoformat() + ) + assert ( + annotations[JobAnnotationKey.FINALIZED_AT.value] == finalized_time.isoformat() + ) + assert ( + annotations[JobAnnotationKey.CANCELLING_AT.value] == cancelling_time.isoformat() + ) + + +def test_create_status_annotations_condition_priority(): + """Test that condition annotation respects priority: cancelled > failed (completed not persisted).""" + base_time = datetime.now(timezone.utc) + + # Test priority: when multiple conditions exist, CANCELLED takes precedence + conditions_with_cancelled = [ + Condition( + type=ConditionType.COMPLETED, + status=ConditionStatus.TRUE, + lastTransitionTime=base_time, + ), + Condition( + type=ConditionType.CANCELLED, + status=ConditionStatus.TRUE, + lastTransitionTime=base_time, + ), + Condition( + type=ConditionType.FAILED, + status=ConditionStatus.TRUE, + lastTransitionTime=base_time, + ), + ] + + status = BatchJobStatus( + jobID="test-job-id", + state=BatchJobState.FINALIZED, + createdAt=base_time, + conditions=conditions_with_cancelled, + requestCounts=RequestCountStats(total=10), + ) + + annotations = BatchJobTransformer.create_status_annotations(status) + assert json.dumps(annotations) + assert ( + annotations[JobAnnotationKey.CONDITION.value] == ConditionType.CANCELLED.value + ) + + # Test priority: when no CANCELLED, FAILED takes precedence over COMPLETED + conditions_failed_completed = [ + Condition( + type=ConditionType.COMPLETED, + status=ConditionStatus.TRUE, + lastTransitionTime=base_time, + ), + Condition( + type=ConditionType.FAILED, + status=ConditionStatus.TRUE, + lastTransitionTime=base_time, + ), + ] + + status.conditions = conditions_failed_completed + annotations = BatchJobTransformer.create_status_annotations(status) + assert annotations[JobAnnotationKey.CONDITION.value] == ConditionType.FAILED.value + + # Test: when only COMPLETED exists, no condition annotation is created + conditions_only_completed = [ + Condition( + type=ConditionType.COMPLETED, + status=ConditionStatus.TRUE, + lastTransitionTime=base_time, + ), + ] + + status.conditions = conditions_only_completed + annotations = BatchJobTransformer.create_status_annotations(status) + assert ( + JobAnnotationKey.CONDITION.value not in annotations + ) # COMPLETED is not persisted + + +def test_create_status_annotations_errors(): + """Test that create_status_annotations correctly handles errors field.""" + base_time = datetime.now(timezone.utc) + + # Test with errors present + errors = [ + BatchJobError( + code=BatchJobErrorCode.AUTHENTICATION_ERROR, + message="Invalid API key", + param="api_key", + ), + ] + + status = BatchJobStatus( + jobID="test-job-id", + state=BatchJobState.FINALIZED, + createdAt=base_time, + errors=errors, + requestCounts=RequestCountStats(total=0), + ) + + annotations = BatchJobTransformer.create_status_annotations(status) + assert json.dumps(annotations) + + # Should persist errors + assert JobAnnotationKey.ERRORS.value in annotations + errors_json = annotations[JobAnnotationKey.ERRORS.value] + errors_data = json.loads(errors_json) + assert len(errors_data) == 1 + assert errors_data[0]["code"] == BatchJobErrorCode.AUTHENTICATION_ERROR.value + assert errors_data[0]["message"] == "Invalid API key" + assert errors_data[0]["param"] == "api_key" + assert errors_data[0]["line"] is None + + # Test with None 
errors + status.errors = None + annotations = BatchJobTransformer.create_status_annotations(status) + assert JobAnnotationKey.ERRORS.value not in annotations + + # Test with empty errors list + status.errors = [] + annotations = BatchJobTransformer.create_status_annotations(status) + assert JobAnnotationKey.ERRORS.value not in annotations + + +def test_create_status_annotations_edge_cases(): + """Test edge cases for create_status_annotations.""" + base_time = datetime.now(timezone.utc) + + # Test with empty conditions list + status = BatchJobStatus( + jobID="test-job-id", + state=BatchJobState.FINALIZED, + createdAt=base_time, + conditions=[], # Empty list + requestCounts=RequestCountStats(total=0), + ) + + annotations = BatchJobTransformer.create_status_annotations(status) + assert json.dumps(annotations) + + # Should persist finalized state but no condition + assert ( + annotations[JobAnnotationKey.JOB_STATE.value] == BatchJobState.FINALIZED.value + ) + assert JobAnnotationKey.CONDITION.value not in annotations + assert JobAnnotationKey.REQUEST_COUNTS.value not in annotations # total=0 + + # Test with None conditions + status.conditions = None + annotations = BatchJobTransformer.create_status_annotations(status) + + assert ( + annotations[JobAnnotationKey.JOB_STATE.value] == BatchJobState.FINALIZED.value + ) + assert JobAnnotationKey.CONDITION.value not in annotations diff --git a/python/aibrix/tests/batch/test_k8s_job_transformer.py b/python/aibrix/tests/batch/test_k8s_job_transformer.py new file mode 100644 index 000000000..aae320290 --- /dev/null +++ b/python/aibrix/tests/batch/test_k8s_job_transformer.py @@ -0,0 +1,1178 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
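The persistence tests above pin down the condition-annotation rule: a cancelled condition takes priority over a failed one, and a completed condition is never persisted. The sketch below restates only that selection rule with simplified standalone types; the real logic lives in `BatchJobTransformer.create_status_annotations` and operates on full `Condition` objects, so treat this as a minimal illustration, not the patch's code.

```python
# Illustrative sketch only: the condition-priority rule the tests assert
# (cancelled > failed; completed is not persisted). Simplified standalone
# types are used instead of the patch's Condition/ConditionType models.
from enum import Enum
from typing import List, Optional


class ConditionType(str, Enum):
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


def select_condition_annotation(
    conditions: Optional[List[ConditionType]],
) -> Optional[str]:
    """Return the condition value to persist as an annotation, or None."""
    if not conditions:
        return None
    if ConditionType.CANCELLED in conditions:
        return ConditionType.CANCELLED.value
    if ConditionType.FAILED in conditions:
        return ConditionType.FAILED.value
    # COMPLETED (or anything else) is intentionally not persisted.
    return None


assert (
    select_condition_annotation(
        [ConditionType.COMPLETED, ConditionType.CANCELLED, ConditionType.FAILED]
    )
    == "cancelled"
)
assert select_condition_annotation([ConditionType.COMPLETED]) is None
```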
+ +from datetime import datetime +from types import SimpleNamespace +from typing import Any + +import kopf +import pytest + +from aibrix.batch.job_entity import ( + BatchJobEndpoint, + BatchJobErrorCode, + BatchJobState, + CompletionWindow, + ConditionStatus, + ConditionType, + k8s_job_to_batch_job, +) +from aibrix.metadata.cache.utils import merge_yaml_object + + +class MockK8sJob: + """Mock Kubernetes Job object for testing.""" + + def __init__( + self, + metadata=None, + annotations=None, + status=None, + api_version="batch/v1", + kind="Job", + ): + self.metadata = metadata or {} + self.status = status or {} + self.spec = {"template": {"metadata": {"annotations": annotations or {}}}} + self.api_version = api_version + self.kind = kind + + +creation_time = datetime.fromisoformat("2025-08-05T05:26:10+00:00") +start_time = datetime.fromisoformat("2025-08-05T05:26:13+00:00") +cancel_time = datetime.fromisoformat("2025-08-05T05:26:20+00:00") +update_time = datetime.fromisoformat("2025-08-05T05:26:25+00:00") +end_time = datetime.fromisoformat("2025-08-05T05:26:30+00:00") + + +def _get_job_base_obj(): + return { + "apiVersion": "batch/v1", + "kind": "Job", + "metadata": { + "name": "test-batch-job", + "namespace": "default", + "uid": "test-uid-123", + "creationTimestamp": creation_time.isoformat(), + }, + "spec": { + "template": { + "metadata": { + "annotations": { + "batch.job.aibrix.ai/input-file-id": "file-123", + "batch.job.aibrix.ai/endpoint": "/v1/chat/completions", + "batch.job.aibrix.ai/metadata.priority": "high", + "batch.job.aibrix.ai/metadata.customer": "test-customer", + }, + }, + }, + "activeDeadlineSeconds": 86400, + }, + "status": { + "startTime": creation_time.isoformat(), + "active": 0, + "terminating": 0, + "uncountedTerminatedPods": {}, + "ready": 0, + }, + } + + +def _get_job_created_obj(): + return merge_yaml_object( + _get_job_base_obj(), + { + "spec": { + "suspend": True, + }, + }, + ) + + +def _get_job_in_progress_obj(): + return merge_yaml_object( + _get_job_base_obj(), + { + "metadata": { + "annotations": { + "batch.job.aibrix.ai/state": "in_progress", + "batch.job.aibrix.ai/in-progress-at": start_time.isoformat(), + }, + }, + "spec": { + "template": { + "metadata": { + "annotations": { + "batch.job.aibrix.ai/output-file-id": "output-123", + "batch.job.aibrix.ai/error-file-id": "error-123", + "batch.job.aibrix.ai/temp-output-file-id": "temp-output-123", + "batch.job.aibrix.ai/temp-error-file-id": "temp-error-123", + }, + }, + }, + }, + "status": { + "conditions": None, + "active": 1, + }, + }, + ) + + +def _get_job_succees_with_finalizing_obj(): + return merge_yaml_object( + _get_job_in_progress_obj(), + { + "status": { + "conditions": [ + { + "type": "SuccessCriteriaMet", + "status": "True", + "lastProbeTime": update_time.isoformat(), + "lastTransitionTime": update_time.isoformat(), + "reason": "CompletionsReached", + "message": "Reached expected number of succeeded pods", + }, + { + "type": "Complete", + "status": "True", + "lastProbeTime": update_time.isoformat(), + "lastTransitionTime": update_time.isoformat(), + "reason": "CompletionsReached", + "message": "Reached expected number of succeeded pods", + }, + ], + "completionTime": update_time.isoformat(), + "succeeded": 1, + }, + }, + ) + + +def _get_job_failed_with_finalizing_obj(): + return merge_yaml_object( + _get_job_in_progress_obj(), + { + "status": { + "conditions": [ + { + "type": "Failed", + "status": "True", + "lastTransitionTime": update_time.isoformat(), + "reason": "BackoffLimitExceeded", + 
"message": "Job has reached the specified backoff limit", + }, + ], + "completionTime": update_time.isoformat(), + "failed": 1, + }, + }, + ) + + +def _get_job_exprired_with_finalizing_obj(): + return merge_yaml_object( + _get_job_in_progress_obj(), + { + "status": { + "conditions": [ + { + "type": "Failed", + "status": "True", + "lastProbeTime": update_time.isoformat(), + "lastTransitionTime": update_time.isoformat(), + "reason": "DeadlineExceeded", + "message": "Job was active longer than specified deadline", + }, + ], + "completionTime": update_time.isoformat(), + "terminating": 1, + }, + }, + ) + + +def _get_job_cancelled_with_finalizing_obj(): + return merge_yaml_object( + _get_job_in_progress_obj(), + { + "metadata": { + "annotations": { + "batch.job.aibrix.ai/state": "finalizing", + "batch.job.aibrix.ai/condition": "cancelled", + "batch.job.aibrix.ai/cancelling-at": cancel_time.isoformat(), + }, + }, + "spec": { + "suspend": True, + }, + "status": { + "conditions": [ + { + "type": "Suspended", + "status": "True", + "lastProbeTime": update_time.isoformat(), + "lastTransitionTime": update_time.isoformat(), + "reason": "JobSuspended", + "message": "Job suspended", + }, + ], + }, + }, + ) + + +def _get_job_validation_failed_obj(): + return merge_yaml_object( + _get_job_created_obj(), + { + "metadata": { + "annotations": { + "batch.job.aibrix.ai/state": "finalized", + "batch.job.aibrix.ai/condition": "failed", + "batch.job.aibrix.ai/errors": '[{"code": "authentication_error", "message": "Simulated authentication failure", "param": "authentication", "line": null}]', + "batch.job.aibrix.ai/finalized-at": end_time.isoformat(), + }, + }, + "spec": { + "suspend": True, + }, + "status": { + "conditions": [ + { + "type": "Suspended", + "status": "True", + "lastProbeTime": start_time.isoformat(), + "lastTransitionTime": start_time.isoformat(), + "reason": "JobSuspended", + "message": "Job suspended", + }, + ], + }, + }, + ) + + +def _get_job_validation_cancelled_obj(): + return merge_yaml_object( + _get_job_created_obj(), + { + "metadata": { + "annotations": { + "batch.job.aibrix.ai/state": "finalized", + "batch.job.aibrix.ai/condition": "cancelled", + "batch.job.aibrix.ai/cancelling-at": cancel_time.isoformat(), + "batch.job.aibrix.ai/finalized-at": end_time.isoformat(), + }, + }, + "spec": { + "suspend": True, + }, + "status": { + "conditions": [ + { + "type": "Suspended", + "status": "True", + "lastProbeTime": start_time.isoformat(), + "lastTransitionTime": start_time.isoformat(), + "reason": "JobSuspended", + "message": "Job suspended", + }, + ], + }, + }, + ) + + +def _get_job_completed_obj(): + return merge_yaml_object( + _get_job_succees_with_finalizing_obj(), + { + "metadata": { + "annotations": { + "batch.job.aibrix.ai/state": "finalized", + "batch.job.aibrix.ai/finalizing-at": update_time.isoformat(), + "batch.job.aibrix.ai/finalized-at": end_time.isoformat(), + "batch.job.aibrix.ai/request-counts": '{"total":10, "completed":10, "failed":0}', + }, + }, + }, + ) + + +def _get_job_failed_obj(): + return merge_yaml_object( + _get_job_failed_with_finalizing_obj(), + { + "metadata": { + "annotations": { + "batch.job.aibrix.ai/state": "finalized", + "batch.job.aibrix.ai/finalizing-at": update_time.isoformat(), + "batch.job.aibrix.ai/finalized-at": end_time.isoformat(), + "batch.job.aibrix.ai/request-counts": '{"total":10, "completed":0, "failed":10}', + }, + }, + }, + ) + + +def _get_job_failed_during_finalizing_obj(): + return merge_yaml_object( + _get_job_succees_with_finalizing_obj(), + { 
+ "metadata": { + "annotations": { + "batch.job.aibrix.ai/state": "finalized", + "batch.job.aibrix.ai/condition": "failed", + "batch.job.aibrix.ai/finalizing-at": update_time.isoformat(), + "batch.job.aibrix.ai/finalized-at": end_time.isoformat(), + "batch.job.aibrix.ai/request-counts": '{"total":10, "completed":0, "failed":10}', + }, + }, + }, + ) + + +def _get_job_expired_obj(): + return merge_yaml_object( + _get_job_exprired_with_finalizing_obj(), + { + "metadata": { + "annotations": { + "batch.job.aibrix.ai/state": "finalized", + "batch.job.aibrix.ai/finalizing-at": update_time.isoformat(), + "batch.job.aibrix.ai/finalized-at": end_time.isoformat(), + "batch.job.aibrix.ai/request-counts": '{"total":10, "completed":9, "failed":1}', + }, + }, + }, + ) + + +def _get_job_cancelled_obj(): + return merge_yaml_object( + _get_job_cancelled_with_finalizing_obj(), + { + "metadata": { + "annotations": { + "batch.job.aibrix.ai/state": "finalized", + "batch.job.aibrix.ai/finalizing-at": update_time.isoformat(), + "batch.job.aibrix.ai/finalized-at": end_time.isoformat(), + "batch.job.aibrix.ai/request-counts": '{"total":10, "completed":5, "failed":0}', + }, + }, + }, + ) + + +def dict_to_obj(d: dict) -> Any: + """Recursively converts a dictionary to a multi-level object.""" + # Convert nested dictionaries recursively + for key, value in d.items(): + if isinstance(value, dict) and key != "annotations": + d[key] = dict_to_obj(value) + + # Convert the top-level dictionary to a SimpleNamespace object + return SimpleNamespace(**d) + + +def test_k8s_job_created(): + """Test successful transformation of Kubernetes job to BatchJob.""" + # Create mock Kubernetes job with required annotations + k8s_job = _get_job_created_obj() + batch_job = k8s_job_to_batch_job(k8s_job) + + # Verify transformation + assert batch_job.type_meta.api_version == "batch/v1" + assert batch_job.type_meta.kind == "Job" + + assert batch_job.metadata.name == "test-batch-job" + assert batch_job.metadata.namespace == "default" + assert batch_job.metadata.uid == "test-uid-123" + + assert batch_job.spec.input_file_id == "file-123" + assert batch_job.spec.endpoint == BatchJobEndpoint.CHAT_COMPLETIONS + assert ( + batch_job.spec.completion_window + == CompletionWindow.TWENTY_FOUR_HOURS.expires_at() + ) + assert batch_job.spec.metadata == {"priority": "high", "customer": "test-customer"} + + assert batch_job.status.job_id == "test-uid-123" + assert batch_job.status.state == BatchJobState.CREATED + + assert batch_job.status.created_at == creation_time + assert batch_job.status.in_progress_at is None + assert batch_job.status.finalizing_at is None + assert batch_job.status.completed_at is None + assert batch_job.status.failed_at is None + assert batch_job.status.expired_at is None + assert batch_job.status.cancelling_at is None + assert batch_job.status.cancelled_at is None + + assert not batch_job.status.finished + assert batch_job.status.condition is None + + +def test_k8s_job_missing_required_annotation(): + """Test transformation fails when required annotation is missing.""" + k8s_job = MockK8sJob( + metadata={ + "name": "test-batch-job", + }, + annotations={ + # Missing required input-file-id annotation + "batch.job.aibrix.ai/endpoint": "/v1/chat/completions" + }, + ) + + with pytest.raises( + ValueError, match="Required annotation.*input-file-id.*not found" + ): + k8s_job_to_batch_job(k8s_job) + + +def test_k8s_job_invalid_endpoint(): + """Test transformation fails with invalid endpoint.""" + k8s_job = MockK8sJob( + annotations={ + 
"batch.job.aibrix.ai/input-file-id": "file-123", + "batch.job.aibrix.ai/endpoint": "/invalid/endpoint", + } + ) + + # We don't check validity of k8s job obj + batch_job = k8s_job_to_batch_job(k8s_job) + assert batch_job.spec.endpoint == "/invalid/endpoint" + + +def test_k8s_job_in_progress(): + """Test successful transformation of Kubernetes job to BatchJob.""" + # Create mock Kubernetes job with required annotations + k8s_job = _get_job_in_progress_obj() + batch_job = k8s_job_to_batch_job(k8s_job) + + # Verify transformation + assert batch_job.type_meta.api_version == "batch/v1" + assert batch_job.type_meta.kind == "Job" + + assert batch_job.metadata.name == "test-batch-job" + assert batch_job.metadata.namespace == "default" + assert batch_job.metadata.uid == "test-uid-123" + + assert batch_job.spec.input_file_id == "file-123" + assert batch_job.spec.endpoint == BatchJobEndpoint.CHAT_COMPLETIONS + assert ( + batch_job.spec.completion_window + == CompletionWindow.TWENTY_FOUR_HOURS.expires_at() + ) + assert batch_job.spec.metadata == {"priority": "high", "customer": "test-customer"} + + assert batch_job.status.job_id == "test-uid-123" + assert batch_job.status.state == BatchJobState.IN_PROGRESS + + assert batch_job.status.created_at == creation_time + assert batch_job.status.in_progress_at == start_time + assert batch_job.status.finalizing_at is None + assert batch_job.status.completed_at is None + assert batch_job.status.failed_at is None + assert batch_job.status.expired_at is None + assert batch_job.status.cancelling_at is None + assert batch_job.status.cancelled_at is None + + assert not batch_job.status.finished + assert batch_job.status.condition is None + + +def test_k8s_job_success_with_finalizing(): + """Test successful transformation of Kubernetes job to BatchJob.""" + # Create mock Kubernetes job with required annotations + k8s_job = _get_job_succees_with_finalizing_obj() + batch_job = k8s_job_to_batch_job(k8s_job) + + # Skip type_meta, metadata, and spec testing + + assert batch_job.status.job_id == "test-uid-123" + assert batch_job.status.state == BatchJobState.FINALIZING + + assert batch_job.status.created_at == creation_time + assert batch_job.status.in_progress_at == start_time + assert batch_job.status.finalizing_at == update_time + assert batch_job.status.completed_at is None + assert batch_job.status.failed_at is None + assert batch_job.status.expired_at is None + assert batch_job.status.cancelling_at is None + assert batch_job.status.cancelled_at is None + + assert not batch_job.status.finished + assert not batch_job.status.completed + assert batch_job.status.condition == ConditionType.COMPLETED + + +def test_k8s_job_completed(): + """Test transformation of successfully completed and finalized job.""" + batch_job = k8s_job_to_batch_job(_get_job_completed_obj()) + + # Skip type_meta, metadata, and spec testing + + # Should be FINALIZED based on annotations + assert batch_job.status.state == BatchJobState.FINALIZED + + # Should have completed condition + assert batch_job.status.conditions is not None + assert len(batch_job.status.conditions) > 0 + assert batch_job.status.finished + assert batch_job.status.completed + assert not batch_job.status.cancelled + assert not batch_job.status.failed + assert not batch_job.status.expired + assert batch_job.status.condition == ConditionType.COMPLETED + + # Verify timestamp requirements + assert batch_job.status.created_at == creation_time + assert batch_job.status.in_progress_at == start_time + assert batch_job.status.finalizing_at 
== update_time + assert batch_job.status.completed_at == end_time + assert batch_job.status.failed_at is None + assert batch_job.status.expired_at is None + assert batch_job.status.cancelling_at is None + assert batch_job.status.cancelled_at is None + + +def test_k8s_job_validation_failed(): + """Test transformation of validation failed job.""" + batch_job = k8s_job_to_batch_job(_get_job_validation_failed_obj()) + import json + + print(json.dumps(_get_job_validation_failed_obj(), indent=2)) + + # Skip type_meta, metadata, and spec testing + + # Should be FINALIZED based on annotations + assert batch_job.status.state == BatchJobState.FINALIZED + + # Should have failed condition + assert batch_job.status.conditions is not None + assert len(batch_job.status.conditions) > 0 + assert batch_job.status.finished + assert not batch_job.status.completed + assert not batch_job.status.cancelled + assert batch_job.status.failed + assert not batch_job.status.expired + assert batch_job.status.condition == ConditionType.FAILED + assert batch_job.status.errors is not None + assert len(batch_job.status.errors) > 0 + assert batch_job.status.errors[0].code == BatchJobErrorCode.AUTHENTICATION_ERROR + + # Verify timestamp requirements + assert batch_job.status.created_at == creation_time + assert batch_job.status.in_progress_at is None + assert batch_job.status.finalizing_at is None + assert batch_job.status.completed_at is None + assert batch_job.status.failed_at == end_time + assert batch_job.status.expired_at is None + assert batch_job.status.cancelling_at is None + assert batch_job.status.cancelled_at is None + + +def test_k8s_job_failed_with_finalizing(): + """Test successful transformation of Kubernetes job to BatchJob.""" + # Create mock Kubernetes job with required annotations + k8s_job = _get_job_failed_with_finalizing_obj() + batch_job = k8s_job_to_batch_job(k8s_job) + + # Skip type_meta, metadata, and spec testing + + assert batch_job.status.job_id == "test-uid-123" + assert batch_job.status.state == BatchJobState.FINALIZING + + # Should have failed condition + assert batch_job.status.conditions is not None + assert len(batch_job.status.conditions) > 0 + assert not batch_job.status.finished + assert not batch_job.status.completed + assert not batch_job.status.cancelled + assert not batch_job.status.failed + assert not batch_job.status.expired + assert batch_job.status.condition == ConditionType.FAILED + + # Verify timestamp requirements + assert batch_job.status.created_at == creation_time + assert batch_job.status.in_progress_at == start_time + assert batch_job.status.finalizing_at == update_time + assert batch_job.status.completed_at is None + assert batch_job.status.failed_at is None + assert batch_job.status.expired_at is None + assert batch_job.status.cancelling_at is None + assert batch_job.status.cancelled_at is None + + +def test_k8s_job_failed(): + """Test transformation of successfully completed and finalized job.""" + batch_job = k8s_job_to_batch_job(_get_job_failed_obj()) + + # Skip type_meta, metadata, and spec testing + + # Should be FINALIZED based on annotations + assert batch_job.status.state == BatchJobState.FINALIZED + + # Should have completed condition + assert batch_job.status.conditions is not None + assert len(batch_job.status.conditions) > 0 + assert batch_job.status.finished + assert not batch_job.status.completed + assert not batch_job.status.cancelled + assert batch_job.status.failed + assert not batch_job.status.expired + assert batch_job.status.condition == 
ConditionType.FAILED + + # Verify timestamp requirements + assert batch_job.status.created_at == creation_time + assert batch_job.status.in_progress_at == start_time + assert batch_job.status.finalizing_at == update_time + assert batch_job.status.completed_at is None + assert batch_job.status.failed_at == end_time + assert batch_job.status.expired_at is None + assert batch_job.status.cancelling_at is None + assert batch_job.status.cancelled_at is None + + +def test_k8s_job_failed_during_finalizing(): + """Test transformation of successfully completed and finalized job.""" + batch_job = k8s_job_to_batch_job(_get_job_failed_during_finalizing_obj()) + + # Skip type_meta, metadata, and spec testing + print(str(batch_job.status.dict())) + + # Should be FINALIZED based on annotations + assert batch_job.status.state == BatchJobState.FINALIZED + + # Should have completed condition + assert batch_job.status.conditions is not None + assert len(batch_job.status.conditions) > 0 + assert batch_job.status.finished + assert batch_job.status.completed + assert not batch_job.status.cancelled + assert batch_job.status.failed + assert not batch_job.status.expired + assert batch_job.status.condition == ConditionType.FAILED + + # Verify timestamp requirements + assert batch_job.status.created_at == creation_time + assert batch_job.status.in_progress_at == start_time + assert batch_job.status.finalizing_at == update_time + assert batch_job.status.completed_at is None + assert batch_job.status.failed_at == end_time + assert batch_job.status.expired_at is None + assert batch_job.status.cancelling_at is None + assert batch_job.status.cancelled_at is None + + +def test_k8s_job_expired_with_finalizing(): + """Test successful transformation of Kubernetes job to BatchJob.""" + # Create mock Kubernetes job with required annotations + k8s_job = _get_job_exprired_with_finalizing_obj() + batch_job = k8s_job_to_batch_job(k8s_job) + + # Skip type_meta, metadata, and spec testing + + assert batch_job.status.job_id == "test-uid-123" + assert batch_job.status.state == BatchJobState.FINALIZING + + # Should have expired condition + assert batch_job.status.conditions is not None + assert len(batch_job.status.conditions) > 0 + assert not batch_job.status.finished + assert not batch_job.status.completed + assert not batch_job.status.cancelled + assert not batch_job.status.failed + assert not batch_job.status.expired + assert batch_job.status.condition == ConditionType.EXPIRED + + # Verify timestamp requirements + assert batch_job.status.created_at == creation_time + assert batch_job.status.in_progress_at == start_time + assert batch_job.status.finalizing_at == update_time + assert batch_job.status.completed_at is None + assert batch_job.status.failed_at is None + assert batch_job.status.expired_at is None + assert batch_job.status.cancelling_at is None + assert batch_job.status.cancelled_at is None + + +def test_k8s_job_expired(): + """Test transformation of successfully completed and finalized job.""" + batch_job = k8s_job_to_batch_job(_get_job_expired_obj()) + + # Skip type_meta, metadata, and spec testing + + # Should be FINALIZED based on annotations + assert batch_job.status.state == BatchJobState.FINALIZED + + # Should have completed condition + assert batch_job.status.conditions is not None + assert len(batch_job.status.conditions) > 0 + assert batch_job.status.finished + assert not batch_job.status.completed + assert not batch_job.status.cancelled + assert not batch_job.status.failed + assert batch_job.status.expired + 
assert batch_job.status.condition == ConditionType.EXPIRED + + # Verify timestamp requirements + assert batch_job.status.created_at == creation_time + assert batch_job.status.in_progress_at == start_time + assert batch_job.status.finalizing_at == update_time + assert batch_job.status.completed_at is None + assert batch_job.status.failed_at is None + assert batch_job.status.expired_at == end_time + assert batch_job.status.cancelling_at is None + assert batch_job.status.cancelled_at is None + + +def test_k8s_job_validation_cancelled(): + """Test transformation of cancelled job in finalized state.""" + batch_job = k8s_job_to_batch_job(_get_job_validation_cancelled_obj()) + + # Skip type_meta, metadata, and spec testing + + # Should be FINALIZED based on annotations + assert batch_job.status.state == BatchJobState.FINALIZED + + # Should have cancelled condition + assert batch_job.status.conditions is not None + assert len(batch_job.status.conditions) > 0 + assert batch_job.status.finished + assert not batch_job.status.completed + assert batch_job.status.cancelled + assert not batch_job.status.failed + assert not batch_job.status.expired + assert batch_job.status.condition == ConditionType.CANCELLED + + # Verify timestamp requirements + assert batch_job.status.created_at == creation_time + assert batch_job.status.in_progress_at is None + assert batch_job.status.finalizing_at is None + assert batch_job.status.completed_at is None + assert batch_job.status.failed_at is None + assert batch_job.status.expired_at is None + assert batch_job.status.cancelling_at == cancel_time + assert batch_job.status.cancelled_at == end_time + + +def test_k8s_job_cancelled_with_finalizing(): + """Test transformation of cancelled job in finalizing state.""" + batch_job = k8s_job_to_batch_job(_get_job_cancelled_with_finalizing_obj()) + + # Skip type_meta, metadata, and spec testing + + # Should be FINALIZING since it has conditions + assert batch_job.status.state == BatchJobState.FINALIZING + + # Should have cancelled condition + assert batch_job.status.conditions is not None + assert len(batch_job.status.conditions) > 0 + assert not batch_job.status.finished + assert not batch_job.status.completed + assert not batch_job.status.cancelled + assert not batch_job.status.failed + assert not batch_job.status.expired + assert batch_job.status.condition == ConditionType.CANCELLED + + # Verify timestamp requirements + assert batch_job.status.created_at == creation_time + assert batch_job.status.in_progress_at == start_time + assert batch_job.status.finalizing_at == update_time + assert batch_job.status.completed_at is None + assert batch_job.status.failed_at is None + assert batch_job.status.expired_at is None + assert batch_job.status.cancelling_at == cancel_time + assert batch_job.status.cancelled_at is None + + +def test_k8s_job_cancelled(): + """Test transformation of cancelled job in finalizing state.""" + batch_job = k8s_job_to_batch_job(_get_job_cancelled_obj()) + + # Skip type_meta, metadata, and spec testing + + # Should be FINALIZING since it has conditions + assert batch_job.status.state == BatchJobState.FINALIZED + + # Should have cancelled condition + assert batch_job.status.conditions is not None + assert len(batch_job.status.conditions) > 0 + assert batch_job.status.finished + assert not batch_job.status.completed + assert batch_job.status.cancelled + assert not batch_job.status.failed + assert not batch_job.status.expired + assert batch_job.status.condition == ConditionType.CANCELLED + + # Verify timestamp 
requirements + assert batch_job.status.created_at == creation_time + assert batch_job.status.in_progress_at == start_time + assert batch_job.status.finalizing_at == update_time + assert batch_job.status.completed_at is None + assert batch_job.status.failed_at is None + assert batch_job.status.expired_at is None + assert batch_job.status.cancelling_at == cancel_time + assert batch_job.status.cancelled_at == end_time + + +def test_k8s_job_obj_access(): + """Test transformer works with object-style access.""" + obj = dict_to_obj(_get_job_created_obj()) + batch_job = k8s_job_to_batch_job(obj) + + assert batch_job.type_meta.api_version == "batch/v1" + assert batch_job.type_meta.kind == "Job" + assert batch_job.metadata.name == "test-batch-job" + assert batch_job.spec.endpoint == BatchJobEndpoint.CHAT_COMPLETIONS + assert batch_job.status.state == BatchJobState.CREATED + assert not batch_job.status.finished + assert batch_job.status.condition is None + + obj = dict_to_obj(_get_job_in_progress_obj()) + batch_job = k8s_job_to_batch_job(obj) + + assert batch_job.type_meta.api_version == "batch/v1" + assert batch_job.type_meta.kind == "Job" + assert batch_job.metadata.name == "test-batch-job" + assert batch_job.spec.endpoint == BatchJobEndpoint.CHAT_COMPLETIONS + assert batch_job.status.state == BatchJobState.IN_PROGRESS + assert not batch_job.status.finished + assert batch_job.status.condition is None + + obj = dict_to_obj(_get_job_succees_with_finalizing_obj()) + batch_job = k8s_job_to_batch_job(obj) + + assert batch_job.type_meta.api_version == "batch/v1" + assert batch_job.type_meta.kind == "Job" + assert batch_job.metadata.name == "test-batch-job" + assert batch_job.spec.endpoint == BatchJobEndpoint.CHAT_COMPLETIONS + assert batch_job.status.state == BatchJobState.FINALIZING + assert not batch_job.status.finished + assert not batch_job.status.completed + assert batch_job.status.condition == ConditionType.COMPLETED + + obj = dict_to_obj(_get_job_exprired_with_finalizing_obj()) + batch_job = k8s_job_to_batch_job(obj) + assert batch_job.status.state == BatchJobState.FINALIZING + assert not batch_job.status.finished + assert not batch_job.status.expired + assert batch_job.status.condition == ConditionType.EXPIRED + + +def test_k8s_job_kopf_access(): + """Test transformer works with dict-style access (e.g., from kopf).""" + + kopf_body = kopf.Body(_get_job_created_obj()) + batch_job = k8s_job_to_batch_job(kopf_body) + + assert batch_job.type_meta.api_version == "batch/v1" + assert batch_job.type_meta.kind == "Job" + assert batch_job.metadata.name == "test-batch-job" + assert batch_job.spec.endpoint == BatchJobEndpoint.CHAT_COMPLETIONS + assert batch_job.status.state == BatchJobState.CREATED + assert not batch_job.status.finished + assert batch_job.status.condition is None + + kopf_body = kopf.Body(_get_job_in_progress_obj()) + batch_job = k8s_job_to_batch_job(kopf_body) + + assert batch_job.type_meta.api_version == "batch/v1" + assert batch_job.type_meta.kind == "Job" + assert batch_job.metadata.name == "test-batch-job" + assert batch_job.spec.endpoint == BatchJobEndpoint.CHAT_COMPLETIONS + assert batch_job.status.state == BatchJobState.IN_PROGRESS + assert not batch_job.status.finished + assert batch_job.status.condition is None + + kopf_body = kopf.Body(_get_job_succees_with_finalizing_obj()) + batch_job = k8s_job_to_batch_job(kopf_body) + + assert batch_job.type_meta.api_version == "batch/v1" + assert batch_job.type_meta.kind == "Job" + assert batch_job.metadata.name == "test-batch-job" + assert
batch_job.spec.endpoint == BatchJobEndpoint.CHAT_COMPLETIONS + assert batch_job.status.state == BatchJobState.FINALIZING + assert not batch_job.status.finished + assert not batch_job.status.completed + assert batch_job.status.condition == ConditionType.COMPLETED + + kopf_body = kopf.Body(_get_job_exprired_with_finalizing_obj()) + batch_job = k8s_job_to_batch_job(kopf_body) + batch_job.status.state == BatchJobState.FINALIZING + assert not batch_job.status.finished + assert not batch_job.status.expired + assert batch_job.status.condition == ConditionType.EXPIRED + + +def test_k8s_job_s3_integration_case(): + """Test transformer with real S3 integration job object structure.""" + k8s_job = { + "apiVersion": "batch/v1", + "kind": "Job", + "metadata": { + "name": "s3-batch-job", + "namespace": "default", + "uid": "b08af167-8d56-41e9-92b6-efebe4a859ab", + "creationTimestamp": "2025-07-28T21:13:41Z", + "resourceVersion": "767483", + "generation": 1, + "labels": { + "app": "aibrix-batch", + "component": "batch-processor", + }, + }, + "spec": { + "activeDeadlineSeconds": 3600, + "backoffLimit": 3, + "completions": 1, + "parallelism": 1, + "selector": { + "matchLabels": { + "batch.kubernetes.io/controller-uid": "b08af167-8d56-41e9-92b6-efebe4a859ab" + } + }, + "template": { + "metadata": { + "labels": { + "app": "aibrix-batch", + "component": "batch-processor", + "batch.kubernetes.io/controller-uid": "b08af167-8d56-41e9-92b6-efebe4a859ab", + "batch.kubernetes.io/job-name": "s3-batch-job", + "controller-uid": "b08af167-8d56-41e9-92b6-efebe4a859ab", + "job-name": "s3-batch-job", + }, + "annotations": { + "batch.job.aibrix.ai/endpoint": "/v1/chat/completions", + "batch.job.aibrix.ai/input-file-id": "s3-test-input-db5ada19.jsonl", + }, + }, + "spec": { + "restartPolicy": "OnFailure", + "serviceAccountName": "job-reader-sa", + "automountServiceAccountToken": True, + "containers": [ + { + "name": "batch-worker", + "image": "aibrix/runtime:nightly", + "command": ["aibrix_batch_worker"], + "env": [ + { + "name": "JOB_NAME", + "valueFrom": { + "fieldRef": { + "apiVersion": "v1", + "fieldPath": "metadata.labels['job-name']", + } + }, + }, + { + "name": "JOB_NAMESPACE", + "valueFrom": { + "fieldRef": { + "apiVersion": "v1", + "fieldPath": "metadata.namespace", + } + }, + }, + { + "name": "STORAGE_AWS_REGION", + "value": "us-west-1", + }, + { + "name": "STORAGE_AWS_BUCKET", + "value": "tianium.aibrix", + }, + { + "name": "REDIS_HOST", + "value": "aibrix-redis-master.aibrix-system.svc.cluster.local", + }, + ], + }, + { + "name": "llm-engine", + "image": "aibrix/vllm-mock:nightly", + "ports": [{"containerPort": 8000, "protocol": "TCP"}], + "readinessProbe": { + "httpGet": { + "path": "/ready", + "port": 8000, + "scheme": "HTTP", + }, + "periodSeconds": 5, + "timeoutSeconds": 1, + "successThreshold": 1, + "failureThreshold": 3, + }, + }, + ], + }, + }, + }, + "status": { + "active": 1, + "ready": 0, + "startTime": "2025-07-28T21:13:41Z", + "terminating": 0, + }, + } + + batch_job = k8s_job_to_batch_job(k8s_job) + + # Verify transformation results + assert batch_job.type_meta.api_version == "batch/v1" + assert batch_job.type_meta.kind == "Job" + + assert batch_job.metadata.name == "s3-batch-job" + assert batch_job.metadata.namespace == "default" + assert batch_job.metadata.uid == "b08af167-8d56-41e9-92b6-efebe4a859ab" + assert batch_job.metadata.resource_version == "767483" + assert batch_job.metadata.generation == 1 + + assert batch_job.spec.input_file_id == "s3-test-input-db5ada19.jsonl" + assert 
batch_job.spec.endpoint == BatchJobEndpoint.CHAT_COMPLETIONS + assert batch_job.spec.completion_window == 3600 + + assert batch_job.status.job_id == "b08af167-8d56-41e9-92b6-efebe4a859ab" + assert batch_job.status.state == BatchJobState.CREATED + + +def test_condition_mapping_completed(): + """Test that K8s Complete condition maps to ConditionType.COMPLETED.""" + batch_job = k8s_job_to_batch_job(_get_job_succees_with_finalizing_obj()) + + # Should be FINALIZING since conditions exist + assert batch_job.status.state == BatchJobState.FINALIZING + + # Should have conditions mapped + assert batch_job.status.conditions is not None + assert ( + len(batch_job.status.conditions) == 1 + ) # Only Complete maps, SuccessCriteriaMet doesn't + + condition = batch_job.status.conditions[0] + assert condition.type == ConditionType.COMPLETED + + +def test_condition_mapping_expired(): + """Test that K8s Failed condition with DeadlineExceeded maps to ConditionType.EXPIRED.""" + batch_job = k8s_job_to_batch_job(_get_job_exprired_with_finalizing_obj()) + + # Should be FINALIZING since conditions exist + assert batch_job.status.state == BatchJobState.FINALIZING + + # Should have conditions mapped + assert batch_job.status.conditions is not None + assert len(batch_job.status.conditions) == 1 + + condition = batch_job.status.conditions[0] + assert condition.type == ConditionType.EXPIRED + assert condition.status == ConditionStatus.TRUE + assert condition.reason == "DeadlineExceeded" + assert condition.last_transition_time is not None + + +def test_condition_mapping_failed(): + """Test that K8s Failed condition (non-deadline) maps to ConditionType.FAILED.""" + batch_job = k8s_job_to_batch_job(_get_job_failed_with_finalizing_obj()) + + # Should be FINALIZING since conditions exist + assert batch_job.status.state == BatchJobState.FINALIZING + + # Should have conditions mapped + assert batch_job.status.conditions is not None + assert len(batch_job.status.conditions) == 1 + + condition = batch_job.status.conditions[0] + assert condition.type == ConditionType.FAILED + assert condition.status == ConditionStatus.TRUE + assert condition.reason == "BackoffLimitExceeded" + assert condition.last_transition_time is not None + + +def test_no_conditions_legacy_behavior(): + """Test that jobs without conditions use legacy state mapping.""" + batch_job = k8s_job_to_batch_job(_get_job_in_progress_obj()) + + # Should be IN_PROGRESS (legacy mapping) since no conditions exist + assert batch_job.status.state == BatchJobState.IN_PROGRESS + + # Should have no conditions + assert batch_job.status.conditions is None + + +def test_unknown_conditions_ignored(): + """Test that unknown K8s condition types are ignored.""" + k8s_job = { + "apiVersion": "batch/v1", + "kind": "Job", + "metadata": { + "name": "test-batch-job", + "namespace": "default", + "uid": "test-uid-123", + "creationTimestamp": "2024-01-01T12:00:00Z", + "annotations": { + "batch.job.aibrix.ai/state": "in_progress", + "batch.job.aibrix.ai/in-progress-at": start_time.isoformat(), + }, + }, + "spec": { + "template": { + "metadata": { + "annotations": { + "batch.job.aibrix.ai/input-file-id": "file-123", + "batch.job.aibrix.ai/endpoint": "/v1/embeddings", + "batch.job.aibrix.ai/output-file-id": "output-123", + "batch.job.aibrix.ai/error-file-id": "error-123", + "batch.job.aibrix.ai/temp-output-file-id": "temp-output-123", + "batch.job.aibrix.ai/temp-error-file-id": "temp-error-123", + }, + }, + }, + }, + "status": { + "conditions": [ + { + "type": "ProgressDeadlineExceeded", # 
Unknown condition type + "status": "True", + "lastTransitionTime": "2025-08-05T05:26:25Z", + "reason": "UnknownReason", + "message": "Unknown condition message", + }, + { + "type": "Complete", # Known condition type + "status": "True", + "lastTransitionTime": "2025-08-05T05:26:25Z", + "reason": "CompletionsReached", + "message": "Job completed successfully", + }, + ], + }, + } + + batch_job = k8s_job_to_batch_job(k8s_job) + + # Should be FINALIZING since valid conditions exist + assert batch_job.status.state == BatchJobState.FINALIZING + + # Should have only the known condition mapped + assert batch_job.status.conditions is not None + assert len(batch_job.status.conditions) == 1 + + condition = batch_job.status.conditions[0] + assert condition.type == ConditionType.COMPLETED diff --git a/python/aibrix/tests/batch/test_rbac_setup.py b/python/aibrix/tests/batch/test_rbac_setup.py new file mode 100644 index 000000000..e605d7118 --- /dev/null +++ b/python/aibrix/tests/batch/test_rbac_setup.py @@ -0,0 +1,98 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test to verify that RBAC setup works correctly for tests that depend on create_test_app. +""" + +import pytest +from kubernetes import client + + +@pytest.mark.asyncio +async def test_rbac_resources_exist(k8s_config, ensure_job_rbac, test_namespace): + """ + Test that verifies the RBAC resources are properly created and accessible. + This test ensures that: + 1. unittest-job-reader-sa service account exists + 2. 
Required roles and bindings are in place + """ + core_v1 = client.CoreV1Api() + rbac_v1 = client.RbacAuthorizationV1Api() + + # The fixture returns the service account name + service_account_name = ensure_job_rbac + expected_sa_name = "unittest-job-reader-sa" + + # Verify the fixture returns the correct service account name + assert ( + service_account_name == expected_sa_name + ), f"Expected {expected_sa_name}, got {service_account_name}" + print(f"✓ Fixture returned correct service account name: {service_account_name}") + + # Check that unittest-job-reader-sa service account exists + try: + sa = core_v1.read_namespaced_service_account( + name=service_account_name, namespace=test_namespace + ) + assert sa.metadata.name == service_account_name + print( + f"✓ {service_account_name} service account exists in namespace {test_namespace}" + ) + except client.ApiException as e: + pytest.fail(f"{service_account_name} service account not found: {e}") + + # Check that unittest-job-reader role exists + try: + role = rbac_v1.read_namespaced_role( + name="unittest-job-reader-role", namespace=test_namespace + ) + assert role.metadata.name == "unittest-job-reader-role" + print(f"✓ unittest-job-reader-role role exists in namespace {test_namespace}") + except client.ApiException as e: + pytest.fail(f"unittest-job-reader-role role not found: {e}") + + # Check that unittest-job-reader role binding exists + try: + role_binding = rbac_v1.read_namespaced_role_binding( + name="unittest-job-reader-binding", namespace=test_namespace + ) + assert role_binding.metadata.name == "unittest-job-reader-binding" + print( + f"✓ unittest-job-reader-binding role binding exists in namespace {test_namespace}" + ) + except client.ApiException as e: + pytest.fail(f"unittest-job-reader-binding role binding not found: {e}") + + print("✅ All RBAC resources are properly configured!") + + +@pytest.mark.asyncio +async def test_create_test_app_with_rbac(ensure_job_rbac): + """ + Test that create_test_app can be called with enable_k8s_job=True + when RBAC resources are available. + """ + from aibrix.storage import StorageType + from tests.batch.conftest import create_test_app + + # This should not raise any errors if RBAC is properly set up + app = create_test_app( + enable_k8s_job=True, + storage_type=StorageType.LOCAL, + metastore_type=StorageType.LOCAL, + ) + + assert app is not None + print("✅ create_test_app with enable_k8s_job=True works correctly!") diff --git a/python/aibrix/tests/batch/test_worker_s3_integration.py b/python/aibrix/tests/batch/test_worker_s3_integration.py new file mode 100644 index 000000000..bb2f011b0 --- /dev/null +++ b/python/aibrix/tests/batch/test_worker_s3_integration.py @@ -0,0 +1,536 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
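+
+# End-to-end flow exercised by the tests in this module (see the helpers below):
+#   1. Upload JSONL batch requests to the test S3 bucket.
+#   2. Submit a Kubernetes Job through the kopf-powered JobCache and wait for the
+#      committed/updated handlers to fire.
+#   3. Emulate the job driver: mark the job IN_PROGRESS and prepare its output files.
+#   4. Once the job reaches FINALIZING, verify per-request outputs in S3 and the
+#      completion keys recorded in the Redis metastore.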
+ +import asyncio +import copy +import json +import uuid +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional + +import boto3 +import pytest +from kubernetes import client + +import aibrix.batch.storage as _storage +import aibrix.batch.storage.batch_metastore as _metastore +from aibrix.batch.job_entity import BatchJob, BatchJobSpec, BatchJobState +from aibrix.logger import init_logger +from aibrix.metadata.cache.job import JobCache +from aibrix.storage.types import StorageType + +logger = init_logger(__name__) + + +class TestWorkerS3Integration: + """Test worker with S3 storage and Redis metadata integration.""" + + @pytest.fixture + def init_storage(self, test_s3_bucket): + """Get or create test S3 bucket.""" + try: + _storage.initialize_storage(StorageType.S3, {"bucket_name": test_s3_bucket}) + _metastore.initialize_batch_metastore(StorageType.REDIS) + except Exception as e: + pytest.skip(f"Cannot initialize S3 storage and redis metastore: {e}") + + @pytest.fixture + def test_input_data(self): + """Load test input data from sample_job_input.jsonl.""" + import json + + sample_file_path = Path(__file__).parent / "testdata" / "sample_job_input.jsonl" + + test_data = [] + with open(sample_file_path, "r") as f: + for line in f: + line = line.strip() + if line: + test_data.append(json.loads(line)) + + return test_data + + async def _upload_test_data_to_s3( + self, test_s3_bucket: str, input_file_id: str, test_data: list + ) -> Any: + """ + Upload test data to S3 bucket. + + Returns: + S3 client for cleanup operations + """ + session = boto3.Session() + s3_client = session.client("s3") + + # Create JSONL content + jsonl_content = "\n".join(json.dumps(item) for item in test_data) + s3_key = input_file_id + + s3_client.put_object( + Bucket=test_s3_bucket, + Key=s3_key, + Body=jsonl_content.encode("utf-8"), + ContentType="application/jsonl", + ) + logger.info( + f"Uploaded {len(test_data)} requests to S3: s3://{test_s3_bucket}/{s3_key}" + ) + + return s3_client + + async def _submit_patch_and_monitor_job( + self, + test_namespace: str, + input_file_id: str, + test_s3_bucket: str, + job_name: str, + timeout: int, + job_cache: JobCache, + is_parallel: bool = False, + ) -> BatchJob: + """ + Submit job to Kubernetes and monitor until completion. 
+ + Args: + test_namespace: Kubernetes namespace + input_file_id: Input file identifier in storage + test_s3_bucket: S3 bucket name + job_name: Name of the job + timeout: Timeout in seconds + job_cache: Kopf-powered JobCache used to submit and monitor the job + is_parallel: Whether to log parallel-specific progress + """ + # Submit job to Kubernetes + job_type = "parallel" if is_parallel else "single" + logger.info(f"Submitting S3 {job_type} batch job to Kubernetes...") + + # Set up job monitoring with kopf-powered JobCache + main_loop = asyncio.get_running_loop() + main_loop.name = "test" # type: ignore[attr-defined] + job_committed = main_loop.create_future() + job_done = main_loop.create_future() + + async def job_committed_handler(job: BatchJob) -> bool: + if not job_committed.done(): + main_loop.call_soon_threadsafe(job_committed.set_result, job) + return True + + return False + + async def job_updated_handler(_: BatchJob, newJob: BatchJob) -> bool: + if job_done.done(): + return False + + if newJob.status.state == BatchJobState.FINALIZING: + main_loop.call_soon_threadsafe(job_done.set_result, newJob) + + return True + + # Register handlers with the kopf-powered JobCache + job_cache.on_job_committed(job_committed_handler) + job_cache.on_job_updated(job_updated_handler) + + job_spec = BatchJobSpec.from_strings( + input_file_id, "/v1/chat/completions", "24h", None + ) + + # Create job + await job_cache.submit_job( + "test-session-id", + job_spec, + job_name=job_name, + parallelism=3 if is_parallel else 1, + ) + + # Wait for kopf to detect the job creation and trigger handlers + logger.info("Waiting for kopf to detect job creation...") + created_job = await asyncio.wait_for(job_committed, timeout=timeout) + + try: + assert created_job.metadata.name == job_name + assert created_job.status.state == BatchJobState.CREATED + logger.info( + f"S3 {job_type} batch job submitted successfully:", job_name=job_name + ) # type:ignore[call-arg] + + # Emulate job_driver behavior and create tempfiles + created_job.status.in_progress_at = datetime.now(timezone.utc) + created_job.status.state = BatchJobState.IN_PROGRESS + prepared_job = await _storage.prepare_job_ouput_files(created_job) + + await job_cache.update_job_ready(prepared_job) + logger.info( + "Batch job patched with file ids successfully:", + job_name=job_name, + output_file_id=prepared_job.status.output_file_id, + temp_output_file_id=prepared_job.status.temp_output_file_id, + error_file_id=prepared_job.status.error_file_id, + temp_error_file_id=prepared_job.status.temp_error_file_id, + ) # type:ignore[call-arg] + + # Wait for job completion + finished_job = await asyncio.wait_for(job_done, timeout=timeout) + assert finished_job.status.state == BatchJobState.FINALIZING + + return finished_job + + except Exception as e: + await self._log_pod_details(test_namespace, job_name) + pytest.fail( + f"S3 {job_type} job did not complete within {timeout} seconds, error={str(e)}" + ) + raise + + async def _cleanup_s3_and_k8s_resources( + self, + s3_client: Any, + test_s3_bucket: str, + input_file_id: str, + job: Any, + test_namespace: str, + job_name: str, + ) -> None: + """ + Clean up S3 objects and Kubernetes job.
+ + Args: + s3_client: S3 client + test_s3_bucket: S3 bucket name + input_file_id: Input file identifier + job: BatchJob object + test_namespace: Kubernetes namespace + job_name: Job name + """ + # Cleanup S3 objects + try: + # Delete input files + response = s3_client.list_objects_v2( + Bucket=test_s3_bucket, Prefix=input_file_id + ) + + if "Contents" in response: + for obj in response["Contents"]: + s3_client.delete_object(Bucket=test_s3_bucket, Key=obj["Key"]) + logger.info(f"Deleted S3 object: {obj['Key']}") + + # Delete output files + if job and job.status.temp_output_file_id: + output_prefix = f".multipart/{job.status.temp_output_file_id}/" + output_response = s3_client.list_objects_v2( + Bucket=test_s3_bucket, Prefix=output_prefix + ) + + if "Contents" in output_response: + for obj in output_response["Contents"]: + s3_client.delete_object(Bucket=test_s3_bucket, Key=obj["Key"]) + logger.info(f"Deleted S3 output object: {obj['Key']}") + + # Delete error files + error_prefix = f".multipart/{job.status.temp_error_file_id}/" + error_response = s3_client.list_objects_v2( + Bucket=test_s3_bucket, Prefix=error_prefix + ) + + if "Contents" in error_response: + for obj in error_response["Contents"]: + s3_client.delete_object(Bucket=test_s3_bucket, Key=obj["Key"]) + logger.info(f"Deleted S3 error object: {obj['Key']}") + except Exception as e: + logger.warning(f"Failed to cleanup S3 objects: {e}") + + # Delete the Kubernetes job + try: + batch_client = client.BatchV1Api() + batch_client.delete_namespaced_job( + name=job_name, + namespace=test_namespace, + propagation_policy="Background", + ) + logger.info(f"Deleted S3 batch job: {job_name}") + except client.ApiException as e: + if e.status != 404: + logger.warning(f"Failed to delete job {job_name}: {e}") + + def _expand_test_data(self, test_input_data: list, multiplier: int) -> list: + """ + Expand test data by duplicating and modifying custom_id.
+ + Args: + test_input_data: Original test data + multiplier: How many times to duplicate the data + + Returns: + Expanded test data list + """ + expanded_test_data = [] + for i in range(multiplier): + for item in test_input_data: + expanded_item = copy.deepcopy(item) + expanded_item["custom_id"] = ( + f"{item.get('custom_id', 'test')}-batch-{i}" + ) + expanded_test_data.append(expanded_item) + + logger.info(f"Created {len(expanded_test_data)} test requests for processing") + return expanded_test_data + + @pytest.mark.asyncio + async def test_single_worker( + self, + k8s_config, + test_namespace, + s3_credentials_secret, + test_s3_bucket, + init_storage, + test_input_data, + job_cache, + ) -> None: + """Test worker using S3 storage and Redis metadata.""" + + # Generate unique job name and setup + job_name = "s3-batch-job" + input_file_id = f"s3-test-input-{str(uuid.uuid4())[:8]}.jsonl" + + # Upload test data to S3 + s3_client = await self._upload_test_data_to_s3( + test_s3_bucket, input_file_id, test_input_data + ) + + job: Optional[BatchJob] = None + try: + # Submit job and monitor until completion + job = await self._submit_patch_and_monitor_job( + test_namespace, + input_file_id, + test_s3_bucket, + job_name, + 60, + job_cache, + is_parallel=False, + ) + + # Verify S3 outputs exist + await self._verify_s3_outputs( + s3_client, + test_s3_bucket, + job.status.temp_output_file_id, + len(test_input_data), + ) + + # Verify Redis locking worked correctly by checking completion keys + await self._verify_redis_completion_keys(job, len(test_input_data)) + + logger.info("S3 integration test completed successfully!") + + finally: + # Cleanup all resources + await self._cleanup_s3_and_k8s_resources( + s3_client, + test_s3_bucket, + input_file_id, + job, + test_namespace, + job_name, + ) + + @pytest.mark.asyncio + async def test_parallel_workers( + self, + k8s_config, + test_namespace, + s3_credentials_secret, + test_s3_bucket, + init_storage, + test_input_data, + job_cache, + ) -> None: + """Test 3 concurrent workers using S3 storage and Redis metadata with request locking.""" + + # Generate unique job name and setup + job_name = "s3-parallel-batch-job" + input_file_id = f"s3-parallel-test-input-{str(uuid.uuid4())[:8]}.jsonl" + + # Create expanded test data for parallel processing + expanded_test_data = self._expand_test_data(test_input_data, multiplier=3) + + # Upload expanded test data to S3 + s3_client = await self._upload_test_data_to_s3( + test_s3_bucket, input_file_id, expanded_test_data + ) + + job: Optional[BatchJob] = None + try: + # Submit job and monitor until completion + job = await self._submit_patch_and_monitor_job( + test_namespace, + input_file_id, + test_s3_bucket, + job_name, + 60, + job_cache, + is_parallel=True, + ) + + # Verify S3 outputs exist for all requests + await self._verify_s3_outputs( + s3_client, + test_s3_bucket, + job.status.temp_output_file_id, + len(expanded_test_data), + ) + + # Verify Redis locking worked correctly by checking completion keys + await self._verify_redis_completion_keys(job, len(expanded_test_data)) + + logger.info( + "S3 parallel integration test with 3 workers completed successfully!" 
+ ) + + finally: + # Cleanup all resources + await self._cleanup_s3_and_k8s_resources( + s3_client, + test_s3_bucket, + input_file_id, + job, + test_namespace, + job_name, + ) + + async def _log_pod_details(self, namespace, job_name): + """Log pod details for debugging.""" + core_v1 = client.CoreV1Api() + + try: + pods = core_v1.list_namespaced_pod( + namespace=namespace, label_selector=f"job-name={job_name}" + ) + + for pod in pods.items: + logger.info(f"S3 Pod: {pod.metadata.name}, Status: {pod.status.phase}") + + # Log container statuses + if pod.status.container_statuses: + for container_status in pod.status.container_statuses: + logger.info( + f"S3 Container: {container_status.name}, Ready: {container_status.ready}" + ) + + # Get pod logs + for container in ["batch-worker", "llm-engine"]: + try: + logs = core_v1.read_namespaced_pod_log( + name=pod.metadata.name, + namespace=namespace, + container=container, + tail_lines=100, # More logs for S3 debugging + ) + print( + f"####################### S3 {container} logs starts #######################" + ) + for line in logs.splitlines(): + print(line) + print( + f"####################### S3 {container} logs ends #######################" + ) + except client.ApiException: + logger.warning( + f"Could not get S3 logs for container {container}" + ) + + except client.ApiException as e: + logger.error(f"Failed to get S3 pod details: {e}") + + async def _verify_s3_outputs( + self, s3_client, bucket, temp_output_file_id, expected_count + ): + """Verify that output files were created in S3.""" + + # Check for output files in S3 + output_prefix = f".multipart/{temp_output_file_id}/" # See BaseStorage::_multipart_upload_key + response = s3_client.list_objects_v2(Bucket=bucket, Prefix=output_prefix) + + assert "Contents" in response + assert ( + len(response["Contents"]) == expected_count + 1 + ) # Including .multipart metadata + logger.info(f"Found {len(response['Contents'])} output files in S3:") + + for obj in response["Contents"]: + logger.info("Loading request output", key=obj["Key"], size=obj["Size"]) # type:ignore[call-arg] + + # check file content + content = s3_client.get_object(Bucket=bucket, Key=obj["Key"]) + raw_output = content["Body"].read().decode("utf-8") + logger.info("Loaded request output", key=obj["Key"], output=raw_output) # type:ignore[call-arg] + # Skip metadata + if obj["Key"].endswith("/metadata"): + continue + output = json.loads(raw_output) + assert "id" in output + assert "custom_id" in output + assert "response" in output + + response = output["response"] + assert "status_code" in response + assert "request_id" in response + assert "body" in response + + body = response["body"] + assert "id" in body + assert "model" in body + assert "object" in body + + async def _verify_redis_completion_keys(self, job, expected_count): + """Verify that all requests have completion keys in Redis metastore.""" + try: + # Initialize Redis metastore to check completion status + + import aibrix.batch.storage.batch_metastore as metastore + from aibrix.storage import StorageType + + metastore.initialize_batch_metastore(StorageType.REDIS) + + completed_count = 0 + missing_keys = [] + + # Check each request's completion status + for i in range(expected_count): + completion_key = f"batch:{job.job_id}:done/{i}" + status, exists = await metastore.get_metadata(completion_key) + + if exists: + completed_count += 1 + logger.info(f"Found completion key {completion_key}: {status}") + else: + missing_keys.append(completion_key) + + logger.info( + f"Found 
{completed_count}/{expected_count} completion keys in Redis" + ) + + if missing_keys: + logger.warning( + f"Missing completion keys: {missing_keys[:10]}..." + ) # Show first 10 + + # We expect all requests to be completed + assert ( + completed_count == expected_count + ), f"Expected {expected_count} completed requests, but found {completed_count}" + + except Exception as e: + logger.warning(f"Could not verify Redis completion keys: {e}") + # Don't fail the test if Redis verification fails diff --git a/python/aibrix/tests/batch/testdata/job_rbac.yaml b/python/aibrix/tests/batch/testdata/job_rbac.yaml new file mode 100644 index 000000000..6b326e7b2 --- /dev/null +++ b/python/aibrix/tests/batch/testdata/job_rbac.yaml @@ -0,0 +1,37 @@ +# RBAC resources for api-extension testing +# This file contains the necessary Kubernetes RBAC resources for tests that depend on create_test_app +# with enable_k8s_job=True + +--- +# Service Account for job reader (as referenced in config/api-extension/rbac.yaml) +apiVersion: v1 +kind: ServiceAccount +metadata: + name: unittest-job-reader-sa + namespace: default +--- +# Role for job reader +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: unittest-job-reader-role + namespace: default +rules: +- apiGroups: ["batch"] + resources: ["jobs"] + verbs: ["get"] # Get permissions only +--- +# RoleBinding for job reader +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: unittest-job-reader-binding + namespace: default +subjects: +- kind: ServiceAccount + name: unittest-job-reader-sa + namespace: default +roleRef: + kind: Role + name: unittest-job-reader-role + apiGroup: rbac.authorization.k8s.io \ No newline at end of file diff --git a/python/aibrix/tests/batch/testdata/k8s_job_patch_unittest.yaml b/python/aibrix/tests/batch/testdata/k8s_job_patch_unittest.yaml new file mode 100644 index 000000000..dcec64fb8 --- /dev/null +++ b/python/aibrix/tests/batch/testdata/k8s_job_patch_unittest.yaml @@ -0,0 +1,14 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: batch-job-template + namespace: default +spec: + template: + spec: + serviceAccountName: unittest-job-reader-sa # Use unittest-specific service account + containers: + - name: batch-worker + image: aibrix/runtime:nightly + - name: llm-engine + image: aibrix/vllm-mock:nightly \ No newline at end of file diff --git a/python/aibrix/tests/batch/testdata/s3_secret.yaml b/python/aibrix/tests/batch/testdata/s3_secret.yaml new file mode 100644 index 000000000..d42bbca9f --- /dev/null +++ b/python/aibrix/tests/batch/testdata/s3_secret.yaml @@ -0,0 +1,14 @@ +# Kubernetes Secret template for S3 credentials +# This is a template that will be populated by the test with actual values +apiVersion: v1 +kind: Secret +metadata: + name: aibrix-s3-credentials + namespace: default +type: Opaque +data: + # Base64 encoded values will be populated by the test + access-key-id: "" + secret-access-key: "" + region: "" + bucket-name: "" \ No newline at end of file diff --git a/python/aibrix/tests/batch/testdata/sample_job_input.jsonl b/python/aibrix/tests/batch/testdata/sample_job_input.jsonl new file mode 100644 index 000000000..63550a9cb --- /dev/null +++ b/python/aibrix/tests/batch/testdata/sample_job_input.jsonl @@ -0,0 +1,10 @@ +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Explain quantum computing in simple 
terms."}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a creative writing assistant."},{"role": "user", "content": "Write a short story about a robot discovering emotions."}],"max_tokens": 1000}} +{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a code reviewer."},{"role": "user", "content": "Review this Python function: def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)"}],"max_tokens": 1000}} +{"custom_id": "request-4", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a cooking instructor."},{"role": "user", "content": "How do I make perfect scrambled eggs?"}],"max_tokens": 1000}} +{"custom_id": "request-5", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a travel advisor."},{"role": "user", "content": "What are the top 5 must-see attractions in Tokyo for first-time visitors?"}],"max_tokens": 1000}} +{"custom_id": "request-6", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a fitness coach."},{"role": "user", "content": "Design a 30-minute beginner workout routine that requires no equipment."}],"max_tokens": 1000}} +{"custom_id": "request-7", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a history teacher."},{"role": "user", "content": "Explain the causes and consequences of the Industrial Revolution."}],"max_tokens": 1000}} +{"custom_id": "request-8", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a language tutor."},{"role": "user", "content": "Teach me the most important Spanish phrases for ordering food at a restaurant."}],"max_tokens": 1000}} +{"custom_id": "request-9", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a science explainer."},{"role": "user", "content": "Why do leaves change color in autumn? Explain the biological process."}],"max_tokens": 1000}} +{"custom_id": "request-10", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a financial advisor."},{"role": "user", "content": "What are the basic principles of investing for a complete beginner?"}],"max_tokens": 1000}} \ No newline at end of file diff --git a/python/aibrix/tests/metadata/test_app_integration.py b/python/aibrix/tests/metadata/test_app_integration.py new file mode 100644 index 000000000..d8d844a98 --- /dev/null +++ b/python/aibrix/tests/metadata/test_app_integration.py @@ -0,0 +1,176 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from unittest.mock import patch + +from fastapi.testclient import TestClient + +# Set required environment variable before importing +os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing") + +from aibrix.metadata.app import build_app + + +def test_build_app_without_k8s_job(): + """Test building app without K8s job support.""" + args = argparse.Namespace( + enable_fastapi_docs=False, + disable_batch_api=True, + disable_file_api=True, + enable_k8s_job=False, + e2e_test=False, + ) + + app = build_app(args) + + # App should not have kopf operator wrapper + assert not hasattr(app.state, "kopf_operator_wrapper") + assert hasattr(app.state, "httpx_client_wrapper") + + +def test_build_app_with_k8s_job(): + """Test building app with K8s job support.""" + args = argparse.Namespace( + enable_fastapi_docs=False, + disable_batch_api=False, + disable_file_api=True, + enable_k8s_job=True, + k8s_namespace="test-namespace", + k8s_job_patch=None, + kopf_startup_timeout=5.0, + kopf_shutdown_timeout=2.0, + e2e_test=False, + ) + + with patch("aibrix.metadata.app.JobCache"): + app = build_app(args) + + # App should have kopf operator wrapper + assert hasattr(app.state, "kopf_operator_wrapper") + assert hasattr(app.state, "httpx_client_wrapper") + assert hasattr(app.state, "batch_driver") + + # Check kopf operator wrapper configuration + kopf_wrapper = app.state.kopf_operator_wrapper + assert kopf_wrapper.namespace == "test-namespace" + assert kopf_wrapper.startup_timeout == 5.0 + assert kopf_wrapper.shutdown_timeout == 2.0 + + +def test_status_endpoint_without_k8s(): + """Test /status endpoint without K8s support.""" + args = argparse.Namespace( + enable_fastapi_docs=False, + disable_batch_api=True, + disable_file_api=True, + enable_k8s_job=False, + e2e_test=False, + ) + + app = build_app(args) + client = TestClient(app) + + response = client.get("/status") + assert response.status_code == 200 + + data = response.json() + assert "httpx_client" in data + assert "kopf_operator" in data + assert "batch_driver" in data + + assert data["httpx_client"]["available"] is True + assert data["kopf_operator"]["available"] is False + assert data["batch_driver"]["available"] is False + + +def test_status_endpoint_with_k8s(): + """Test /status endpoint with K8s support.""" + args = argparse.Namespace( + enable_fastapi_docs=False, + disable_batch_api=False, + disable_file_api=True, + enable_k8s_job=True, + k8s_job_patch=None, + k8s_namespace="test-namespace", + kopf_startup_timeout=5.0, + kopf_shutdown_timeout=2.0, + e2e_test=False, + ) + + with patch("aibrix.metadata.app.JobCache"): + app = build_app(args) + + client = TestClient(app) + + response = client.get("/status") + assert response.status_code == 200 + + data = response.json() + assert "httpx_client" in data + assert "kopf_operator" in data + assert "batch_driver" in data + + assert data["httpx_client"]["available"] is True + assert data["kopf_operator"]["available"] is True + assert data["batch_driver"]["available"] is True + + # Check kopf operator status details + kopf_status = data["kopf_operator"] + assert 
"is_running" in kopf_status + assert "namespace" in kopf_status + assert kopf_status["namespace"] == "test-namespace" + assert kopf_status["startup_timeout"] == 5.0 + assert kopf_status["shutdown_timeout"] == 2.0 + + +def test_healthz_endpoint(): + """Test /healthz endpoint.""" + args = argparse.Namespace( + enable_fastapi_docs=False, + disable_batch_api=True, + disable_file_api=True, + enable_k8s_job=False, + e2e_test=False, + ) + + app = build_app(args) + client = TestClient(app) + + response = client.get("/healthz") + assert response.status_code == 200 + + data = response.json() + assert data["status"] == "ok" + + +def test_ready_endpoint(): + """Test /ready endpoint.""" + args = argparse.Namespace( + enable_fastapi_docs=False, + disable_batch_api=True, + disable_file_api=True, + enable_k8s_job=False, + e2e_test=False, + ) + + app = build_app(args) + client = TestClient(app) + + response = client.get("/ready") + assert response.status_code == 200 + + data = response.json() + assert data["status"] == "ready" diff --git a/python/aibrix/tests/metadata/test_kopf_integration.py b/python/aibrix/tests/metadata/test_kopf_integration.py new file mode 100644 index 000000000..946fd1876 --- /dev/null +++ b/python/aibrix/tests/metadata/test_kopf_integration.py @@ -0,0 +1,236 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import threading +from unittest.mock import patch + +import pytest + +# Set required environment variable before importing +os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing") + +from aibrix.metadata.core.kopf_operator import KopfOperatorWrapper + + +@pytest.mark.asyncio +async def test_kopf_operator_wrapper_lifecycle(): + """Test the basic lifecycle of KopfOperatorWrapper.""" + # Use short timeouts for testing + wrapper = KopfOperatorWrapper( + namespace="test-namespace", + startup_timeout=2.0, + shutdown_timeout=2.0, + ) + + # Initial state should be not running + assert not wrapper.is_running() + + status = wrapper.get_status() + assert status["is_running"] is False + assert status["namespace"] == "test-namespace" + assert status["startup_timeout"] == 2.0 + assert status["shutdown_timeout"] == 2.0 + + +@pytest.mark.asyncio +async def test_kopf_operator_wrapper_start_stop_mock(): + """Test kopf operator wrapper start/stop with mocked kopf.run.""" + + # Mock kopf.run to avoid actual operator startup + with patch("aibrix.metadata.core.kopf_operator.kopf.run") as mock_kopf_run: + # Mock kopf.run to signal ready immediately and then wait for stop + def mock_run(**kwargs): + ready_flag = kwargs.get("ready_flag") + stop_flag = kwargs.get("stop_flag") + + # Signal ready immediately + if ready_flag: + ready_flag.set() + + # Wait for stop signal + if stop_flag: + stop_flag.wait() + + mock_kopf_run.side_effect = mock_run + + wrapper = KopfOperatorWrapper( + namespace="test-namespace", + startup_timeout=1.0, + shutdown_timeout=1.0, + ) + + # Test start + wrapper.start() + + # Should be running after start + assert wrapper.is_running() + + status = wrapper.get_status() + assert status["is_running"] is True + assert "thread_name" in status + assert "thread_id" in status + assert status["thread_alive"] is True + + # Test stop + wrapper.stop() + + # Should not be running after stop + assert not wrapper.is_running() + + # Verify kopf.run was called with correct parameters + mock_kopf_run.assert_called_once() + call_args = mock_kopf_run.call_args + assert call_args.kwargs["standalone"] is True + assert call_args.kwargs["namespace"] == "test-namespace" + assert call_args.kwargs["peering_name"] is None + + +@pytest.mark.asyncio +async def test_kopf_operator_wrapper_startup_timeout(): + """Test that startup timeout works correctly.""" + + # Mock kopf.run to never signal ready + with patch("aibrix.metadata.core.kopf_operator.kopf.run") as mock_kopf_run: + + def mock_run(**kwargs): + # Never signal ready, just wait + stop_flag = kwargs.get("stop_flag") + if stop_flag: + stop_flag.wait() + + mock_kopf_run.side_effect = mock_run + + wrapper = KopfOperatorWrapper( + namespace="test-namespace", + startup_timeout=0.1, # Very short timeout + shutdown_timeout=0.1, + ) + + # Start should timeout and raise RuntimeError + with pytest.raises(RuntimeError, match="did not start within 0.1s"): + wrapper.start() + + # Should not be running after failed start + assert not wrapper.is_running() + + +@pytest.mark.asyncio +async def test_kopf_operator_wrapper_startup_error(): + """Test that startup errors are properly handled.""" + + # Mock kopf.run to raise an exception + with patch("aibrix.metadata.core.kopf_operator.kopf.run") as mock_kopf_run: + + def mock_run(**kwargs): + ready_flag = kwargs.get("ready_flag") + if ready_flag: + ready_flag.set() # Signal ready first + raise RuntimeError("Mock startup error") + + mock_kopf_run.side_effect = mock_run + + wrapper = KopfOperatorWrapper( + 
namespace="test-namespace", + startup_timeout=1.0, + shutdown_timeout=1.0, + ) + + # Start should re-raise the exception + with pytest.raises(RuntimeError, match="Mock startup error"): + wrapper.start() + + # Should not be running after failed start + assert not wrapper.is_running() + + +@pytest.mark.asyncio +async def test_kopf_operator_wrapper_double_start(): + """Test that calling start twice doesn't cause issues.""" + + with patch("aibrix.metadata.core.kopf_operator.kopf.run") as mock_kopf_run: + + def mock_run(**kwargs): + ready_flag = kwargs.get("ready_flag") + stop_flag = kwargs.get("stop_flag") + + if ready_flag: + ready_flag.set() + if stop_flag: + stop_flag.wait() + + mock_kopf_run.side_effect = mock_run + + wrapper = KopfOperatorWrapper( + namespace="test-namespace", + startup_timeout=1.0, + shutdown_timeout=1.0, + ) + + # First start should work + wrapper.start() + assert wrapper.is_running() + + # Second start should be ignored (no exception) + wrapper.start() + assert wrapper.is_running() + + # Cleanup + wrapper.stop() + assert not wrapper.is_running() + + +@pytest.mark.asyncio +async def test_kopf_operator_wrapper_stop_not_running(): + """Test that calling stop when not running doesn't cause issues.""" + + wrapper = KopfOperatorWrapper() + + # Stop should work even when not running + wrapper.stop() + assert not wrapper.is_running() + + +def test_kopf_operator_wrapper_threading(): + """Test that the kopf operator runs in a separate thread.""" + + main_thread_id = threading.get_ident() + operator_thread_id = None + + with patch("aibrix.metadata.core.kopf_operator.kopf.run") as mock_kopf_run: + + def mock_run(**kwargs): + nonlocal operator_thread_id + operator_thread_id = threading.get_ident() + + ready_flag = kwargs.get("ready_flag") + stop_flag = kwargs.get("stop_flag") + + if ready_flag: + ready_flag.set() + if stop_flag: + stop_flag.wait() + + mock_kopf_run.side_effect = mock_run + + wrapper = KopfOperatorWrapper(startup_timeout=1.0, shutdown_timeout=1.0) + + wrapper.start() + + # Verify kopf runs in a different thread + assert operator_thread_id is not None + assert operator_thread_id != main_thread_id + + # Cleanup + wrapper.stop() diff --git a/python/aibrix/tests/metadata/test_secret_gen.py b/python/aibrix/tests/metadata/test_secret_gen.py new file mode 100644 index 000000000..3063ac178 --- /dev/null +++ b/python/aibrix/tests/metadata/test_secret_gen.py @@ -0,0 +1,212 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the secret_gen module. 
+""" + +import base64 +import os +from unittest.mock import Mock, patch + +import pytest + +from aibrix.metadata.secret_gen import SecretGenerator + + +class TestSecretGenerator: + """Test cases for SecretGenerator class.""" + + def test_init(self): + """Test SecretGenerator initialization.""" + generator = SecretGenerator(namespace="test-namespace") + assert generator.namespace == "test-namespace" + assert generator.core_v1 is not None + assert generator.setting_dir.name == "setting" + + def test_encode_data(self): + """Test base64 encoding of secret data.""" + generator = SecretGenerator() + + data = { + "key1": "value1", + "key2": "value2", + "key3": None, # Should be filtered out + } + + encoded = generator._encode_data(data) + + assert "key1" in encoded + assert "key2" in encoded + assert "key3" not in encoded + + # Verify base64 encoding + assert encoded["key1"] == base64.b64encode("value1".encode()).decode() + assert encoded["key2"] == base64.b64encode("value2".encode()).decode() + + @patch("aibrix.metadata.secret_gen.open") + @patch("aibrix.metadata.secret_gen.yaml.safe_load") + def test_load_template(self, mock_yaml_load, mock_open): + """Test loading secret templates.""" + generator = SecretGenerator() + + mock_template = {"apiVersion": "v1", "kind": "Secret"} + mock_yaml_load.return_value = mock_template + + result = generator._load_template("test_template.yaml") + + assert result == mock_template + mock_open.assert_called_once() + mock_yaml_load.assert_called_once() + + @patch("aibrix.metadata.secret_gen.boto3.Session") + def test_get_s3_credentials_success(self, mock_session): + """Test successful S3 credentials retrieval.""" + generator = SecretGenerator() + + # Mock boto3 session and credentials + mock_credentials = Mock() + mock_credentials.access_key = "test_access_key" + mock_credentials.secret_key = "test_secret_key" + + mock_session_instance = Mock() + mock_session_instance.get_credentials.return_value = mock_credentials + mock_session_instance.region_name = "us-west-2" + mock_session.return_value = mock_session_instance + + credentials = generator._get_s3_credentials() + + assert credentials["access_key"] == "test_access_key" + assert credentials["secret_key"] == "test_secret_key" + assert credentials["region"] == "us-west-2" + + @patch("aibrix.metadata.secret_gen.boto3.Session") + def test_get_s3_credentials_no_credentials(self, mock_session): + """Test S3 credentials retrieval when no credentials found.""" + generator = SecretGenerator() + + mock_session_instance = Mock() + mock_session_instance.get_credentials.return_value = None + mock_session.return_value = mock_session_instance + + with pytest.raises(RuntimeError, match="No AWS credentials found"): + generator._get_s3_credentials() + + @patch.dict( + os.environ, + { + "TOS_ACCESS_KEY": "tos_access", + "TOS_SECRET_KEY": "tos_secret", + "TOS_ENDPOINT": "https://tos.example.com", + "TOS_REGION": "us-east-1", + }, + ) + def test_get_tos_credentials_success(self): + """Test successful TOS credentials retrieval.""" + generator = SecretGenerator() + + credentials = generator._get_tos_credentials() + + assert credentials["access_key"] == "tos_access" + assert credentials["secret_key"] == "tos_secret" + assert credentials["endpoint"] == "https://tos.example.com" + assert credentials["region"] == "us-east-1" + + @patch.dict( + os.environ, + { + "TOS_ACCESS_KEY": "tos_access", + # Missing other required variables + }, + clear=True, + ) + def test_get_tos_credentials_missing_vars(self): + """Test TOS credentials retrieval 
with missing environment variables.""" + generator = SecretGenerator() + + with pytest.raises(RuntimeError, match="Missing TOS environment variables"): + generator._get_tos_credentials() + + @patch("aibrix.metadata.secret_gen.client.CoreV1Api") + def test_secret_exists_true(self, mock_core_v1_class): + """Test secret_exists when secret exists.""" + mock_core_v1 = Mock() + mock_core_v1_class.return_value = mock_core_v1 + + generator = SecretGenerator() + generator.core_v1 = mock_core_v1 + + # Mock successful read + mock_core_v1.read_namespaced_secret.return_value = Mock() + + result = generator.secret_exists("test-secret") + + assert result is True + mock_core_v1.read_namespaced_secret.assert_called_once_with( + name="test-secret", namespace="default" + ) + + @patch("aibrix.metadata.secret_gen.client.CoreV1Api") + def test_secret_exists_false(self, mock_core_v1_class): + """Test secret_exists when secret doesn't exist.""" + from kubernetes import client + + mock_core_v1 = Mock() + mock_core_v1_class.return_value = mock_core_v1 + + generator = SecretGenerator() + generator.core_v1 = mock_core_v1 + + # Mock 404 exception + mock_exception = client.ApiException(status=404) + mock_core_v1.read_namespaced_secret.side_effect = mock_exception + + result = generator.secret_exists("test-secret") + + assert result is False + + @patch("aibrix.metadata.secret_gen.client.CoreV1Api") + def test_delete_secret_success(self, mock_core_v1_class): + """Test successful secret deletion.""" + mock_core_v1 = Mock() + mock_core_v1_class.return_value = mock_core_v1 + + generator = SecretGenerator() + generator.core_v1 = mock_core_v1 + + result = generator.delete_secret("test-secret") + + assert result is True + mock_core_v1.delete_namespaced_secret.assert_called_once_with( + name="test-secret", namespace="default" + ) + + @patch("aibrix.metadata.secret_gen.client.CoreV1Api") + def test_delete_secret_not_found(self, mock_core_v1_class): + """Test secret deletion when secret doesn't exist.""" + from kubernetes import client + + mock_core_v1 = Mock() + mock_core_v1_class.return_value = mock_core_v1 + + generator = SecretGenerator() + generator.core_v1 = mock_core_v1 + + # Mock 404 exception + mock_exception = client.ApiException(status=404) + mock_core_v1.delete_namespaced_secret.side_effect = mock_exception + + result = generator.delete_secret("test-secret") + + assert result is False diff --git a/python/aibrix/tests/storage/test_reader.py b/python/aibrix/tests/storage/test_reader.py index fd2a699e6..ffbbad121 100644 --- a/python/aibrix/tests/storage/test_reader.py +++ b/python/aibrix/tests/storage/test_reader.py @@ -254,7 +254,7 @@ async def test_reader_with_custom_async_file_object(self): """Test Reader with custom async file-like objects.""" test_data = b"Hello, async custom world!" 
- init_storage_loop_thread() + init_storage_loop_thread("test_reader_with_custom_async_file_object") class AsyncFileObject: def __init__(self, data: bytes): @@ -919,7 +919,7 @@ async def test_reader_size_limiter_with_async_file_object(self): def size_limiter(bytes_read, bytes_to_read): return bytes_read + bytes_to_read <= 10 - init_storage_loop_thread() + init_storage_loop_thread("test_reader_size_limiter_with_async_file_object") class AsyncFileObject: def __init__(self, data: bytes): diff --git a/python/aibrix/tests/storage/test_redis_storage.py b/python/aibrix/tests/storage/test_redis_storage.py index ababfd16f..77da700fc 100644 --- a/python/aibrix/tests/storage/test_redis_storage.py +++ b/python/aibrix/tests/storage/test_redis_storage.py @@ -342,3 +342,210 @@ async def test_redis_hierarchical_token_pagination(): finally: await storage.close() + + +def test_feature_detection(): + """Test feature detection methods.""" + storage = RedisStorage() + + # Redis should support all advanced features + assert storage.is_ttl_supported() is True + assert storage.is_set_if_not_exists_supported() is True + assert storage.is_set_if_exists_supported() is True + + +def test_put_object_options_validation(): + """Test PutObjectOptions validation.""" + from aibrix.storage.base import PutObjectOptions + + # Valid options + options = PutObjectOptions() + assert options.ttl_seconds is None + assert options.ttl_milliseconds is None + assert options.set_if_not_exists is False + assert options.set_if_exists is False + + # Valid options with TTL seconds + options = PutObjectOptions(ttl_seconds=60) + assert options.ttl_seconds == 60 + + # Valid conditional options + options = PutObjectOptions(set_if_not_exists=True) + assert options.set_if_not_exists is True + + # Invalid: both conditions + with pytest.raises( + ValueError, match="Cannot specify both set_if_not_exists and set_if_exists" + ): + PutObjectOptions(set_if_not_exists=True, set_if_exists=True) + + # Invalid: both TTL types + with pytest.raises( + ValueError, match="Cannot specify both ttl_seconds and ttl_milliseconds" + ): + PutObjectOptions(ttl_seconds=60, ttl_milliseconds=60000) + + +def test_put_object_options_builder(): + """Test PutObjectOptionsBuilder helper class.""" + from aibrix.storage.base import PutObjectOptionsBuilder + + # Test building with TTL seconds + options = PutObjectOptionsBuilder().ttl_seconds(60).build() + assert options.ttl_seconds == 60 + assert options.ttl_milliseconds is None + + # Test building with TTL milliseconds + options = PutObjectOptionsBuilder().ttl_milliseconds(60000).build() + assert options.ttl_milliseconds == 60000 + assert options.ttl_seconds is None + + # Test building with conditional operations + options = PutObjectOptionsBuilder().if_not_exists().build() + assert options.set_if_not_exists is True + assert options.set_if_exists is False + + options = PutObjectOptionsBuilder().if_exists().build() + assert options.set_if_exists is True + assert options.set_if_not_exists is False + + # Test chaining + options = PutObjectOptionsBuilder().ttl_seconds(300).if_not_exists().build() + assert options.ttl_seconds == 300 + assert options.set_if_not_exists is True + + +@requires_redis +@pytest.mark.asyncio +async def test_redis_put_object_with_ttl(): + """Test Redis put_object with TTL options (requires Redis running).""" + storage = get_redis_storage() + try: + from aibrix.storage.base import PutObjectOptions + + # Test TTL in seconds + options = PutObjectOptions(ttl_seconds=1) # 1 second TTL + result = await 
storage.put_object("test_ttl_key", b"test_data", options=options) + assert result is True + + # Verify data exists initially + data = await storage.get_object("test_ttl_key") + assert data == b"test_data" + + # Wait for TTL to expire + await asyncio.sleep(1.1) + + # Verify data expired + with pytest.raises(FileNotFoundError): + await storage.get_object("test_ttl_key") + + # Test TTL in milliseconds + options = PutObjectOptions(ttl_milliseconds=500) # 500ms TTL + result = await storage.put_object( + "test_ttl_ms_key", b"test_data_ms", options=options + ) + assert result is True + + # Verify data exists initially + data = await storage.get_object("test_ttl_ms_key") + assert data == b"test_data_ms" + + # Wait for TTL to expire + await asyncio.sleep(0.6) + + # Verify data expired + with pytest.raises(FileNotFoundError): + await storage.get_object("test_ttl_ms_key") + + finally: + await storage.close() + + +@requires_redis +@pytest.mark.asyncio +async def test_redis_put_object_conditional(): + """Test Redis put_object conditional operations (requires Redis running).""" + storage = get_redis_storage() + try: + from aibrix.storage.base import PutObjectOptions + + key = "test_conditional_key" + + # Ensure key doesn't exist + await storage.delete_object(key) + + # Test SET IF NOT EXISTS (NX) - should succeed + options = PutObjectOptions(set_if_not_exists=True) + result = await storage.put_object(key, b"first_value", options=options) + assert result is True + + # Verify data was set + data = await storage.get_object(key) + assert data == b"first_value" + + # Test SET IF NOT EXISTS again - should fail since key exists + result = await storage.put_object(key, b"second_value", options=options) + assert result is False + + # Verify data unchanged + data = await storage.get_object(key) + assert data == b"first_value" + + # Test SET IF EXISTS (XX) - should succeed since key exists + options = PutObjectOptions(set_if_exists=True) + result = await storage.put_object(key, b"updated_value", options=options) + assert result is True + + # Verify data was updated + data = await storage.get_object(key) + assert data == b"updated_value" + + # Delete key and test SET IF EXISTS - should fail + await storage.delete_object(key) + result = await storage.put_object(key, b"should_fail", options=options) + assert result is False + + # Verify key doesn't exist + with pytest.raises(FileNotFoundError): + await storage.get_object(key) + + finally: + await storage.close() + + +@requires_redis +@pytest.mark.asyncio +async def test_redis_put_object_combined_options(): + """Test Redis put_object with combined TTL and conditional options (requires Redis running).""" + storage = get_redis_storage() + try: + from aibrix.storage.base import PutObjectOptionsBuilder + + key = "test_combined_key" + + # Ensure key doesn't exist + await storage.delete_object(key) + + # Test NX with TTL + options = PutObjectOptionsBuilder().ttl_seconds(2).if_not_exists().build() + + result = await storage.put_object(key, b"ttl_nx_value", options=options) + assert result is True + + # Verify data exists + data = await storage.get_object(key) + assert data == b"ttl_nx_value" + + # Try to set again with NX - should fail + result = await storage.put_object(key, b"should_fail", options=options) + assert result is False + + # Wait for TTL to expire + await asyncio.sleep(2.1) + + # Verify data expired + with pytest.raises(FileNotFoundError): + await storage.get_object(key) + + finally: + await storage.close() diff --git 
a/python/aibrix/tests/storage/test_storage.py b/python/aibrix/tests/storage/test_storage.py index 8903cfbd0..77daf7db4 100644 --- a/python/aibrix/tests/storage/test_storage.py +++ b/python/aibrix/tests/storage/test_storage.py @@ -39,12 +39,13 @@ async def test_put_and_get_string(self, storage: BaseStorage): key = "test/string.txt" content = "Hello, World! 🌍" - # Store string - await storage.put_object(key, content) + # Store string - should return True + result = await storage.put_object(key, content) + assert result is True # Retrieve and verify - result = await storage.get_object(key) - assert result.decode("utf-8") == content + data = await storage.get_object(key) + assert data.decode("utf-8") == content # Cleanup await storage.delete_object(key) @@ -58,11 +59,12 @@ async def test_put_and_get_text_io(self, storage: BaseStorage): # Store string data = io.StringIO(content) assert isinstance(data, io.TextIOBase) - await storage.put_object(key, data) + result = await storage.put_object(key, data) + assert result is True # Retrieve and verify - result = await storage.get_object(key) - assert result.decode("utf-8") == content + data_result = await storage.get_object(key) + assert data_result.decode("utf-8") == content # Cleanup await storage.delete_object(key) @@ -797,6 +799,80 @@ async def test_multipart_upload_edge_cases(self, storage: BaseStorage): await storage.delete_object(key2) await storage.delete_object(key3) + @pytest.mark.asyncio + async def test_put_object_options_unsupported(self, storage: BaseStorage): + """Test that non-Redis storage backends reject unsupported options.""" + from aibrix.storage.base import PutObjectOptions + from aibrix.storage.redis import RedisStorage + + key = "test/options_unsupported.txt" + content = "test content" + + # Skip this test for Redis storage since it supports all options + if isinstance(storage, RedisStorage): + pytest.skip("Redis storage supports all options") + + # Test TTL not supported + options = PutObjectOptions(ttl_seconds=60) + with pytest.raises(ValueError, match="TTL not supported"): + await storage.put_object(key, content, options=options) + + # Test SET IF NOT EXISTS not supported + options = PutObjectOptions(set_if_not_exists=True) + with pytest.raises(ValueError, match="SET IF NOT EXISTS not supported"): + await storage.put_object(key, content, options=options) + + # Test SET IF EXISTS not supported + options = PutObjectOptions(set_if_exists=True) + with pytest.raises(ValueError, match="SET IF EXISTS not supported"): + await storage.put_object(key, content, options=options) + + @pytest.mark.asyncio + async def test_put_object_options_none(self, storage: BaseStorage): + """Test that put_object works with options=None.""" + key = "test/options_none.txt" + content = "test content with no options" + + # Should work without options + result = await storage.put_object(key, content, options=None) + assert result is True + + # Verify content + data = await storage.get_object(key) + assert data.decode("utf-8") == content + + # Cleanup + await storage.delete_object(key) + + def test_feature_detection_default(self, storage: BaseStorage): + """Test that storage backends correctly report feature support.""" + from aibrix.storage.redis import RedisStorage + + if isinstance(storage, RedisStorage): + # Redis should support all features + assert storage.is_ttl_supported() is True + assert storage.is_set_if_not_exists_supported() is True + assert storage.is_set_if_exists_supported() is True + else: + # Other storage backends should not support 
advanced features + assert storage.is_ttl_supported() is False + assert storage.is_set_if_not_exists_supported() is False + assert storage.is_set_if_exists_supported() is False + + @pytest.mark.asyncio + async def test_put_object_return_type(self, storage: BaseStorage): + """Test that put_object returns boolean values correctly.""" + key = "test/return_type.txt" + content = "test return type" + + # Normal put should return True + result = await storage.put_object(key, content) + assert isinstance(result, bool) + assert result is True + + # Cleanup + await storage.delete_object(key) + # Conditionally add S3 and TOS storage to test parameters if available def pytest_generate_tests(metafunc): diff --git a/python/aibrix/tests/test_metadata_logger.py b/python/aibrix/tests/test_logger.py similarity index 99% rename from python/aibrix/tests/test_metadata_logger.py rename to python/aibrix/tests/test_logger.py index 8320b147f..0909330ad 100644 --- a/python/aibrix/tests/test_metadata_logger.py +++ b/python/aibrix/tests/test_logger.py @@ -15,7 +15,7 @@ import pytest -from aibrix.metadata.logger import init_logger +from aibrix.logger import init_logger def test_init_logger_basic_functionality(): From 375a42b4fde397cf9f304c164026bd6c77ceda94 Mon Sep 17 00:00:00 2001 From: Jingyuan Zhang Date: Thu, 2 Oct 2025 14:27:25 -0700 Subject: [PATCH 02/11] envoy integration Signed-off-by: Jingyuan Zhang --- config/api-extension/deployment.yaml | 88 +++++ config/api-extension/kustomization.yaml | 8 + config/api-extension/rbac.yaml | 46 +++ config/api-extension/service.yaml | 13 + config/default/kustomization.yaml | 1 + .../gateway-plugin/gateway-plugin.yaml | 21 ++ .../templates/api-extension/deployment.yaml | 52 +++ dist/chart/templates/api-extension/rbac.yaml | 54 +++ .../templates/api-extension/service.yaml | 16 + .../templates/gateway-plugin/httproute.yaml | 23 ++ dist/chart/values.yaml | 20 ++ python/aibrix/.python-version | 1 + python/aibrix/tests/e2e/README.md | 91 +++++ python/aibrix/tests/e2e/__init__.py | 21 ++ python/aibrix/tests/e2e/test_batch_api.py | 321 ++++++++++++++++++ 15 files changed, 776 insertions(+) create mode 100644 config/api-extension/deployment.yaml create mode 100644 config/api-extension/kustomization.yaml create mode 100644 config/api-extension/rbac.yaml create mode 100644 config/api-extension/service.yaml create mode 100644 dist/chart/templates/api-extension/deployment.yaml create mode 100644 dist/chart/templates/api-extension/rbac.yaml create mode 100644 dist/chart/templates/api-extension/service.yaml create mode 100644 python/aibrix/.python-version create mode 100644 python/aibrix/tests/e2e/README.md create mode 100644 python/aibrix/tests/e2e/__init__.py create mode 100644 python/aibrix/tests/e2e/test_batch_api.py diff --git a/config/api-extension/deployment.yaml b/config/api-extension/deployment.yaml new file mode 100644 index 000000000..c0444ee0e --- /dev/null +++ b/config/api-extension/deployment.yaml @@ -0,0 +1,88 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: api-extension + namespace: system +spec: + replicas: 1 + selector: + matchLabels: + app: api-extension + template: + metadata: + labels: + app: api-extension + spec: + serviceAccountName: api-extension-sa + automountServiceAccountToken: true + containers: + - name: api-extension + image: aibrix/runtime:nightly + command: + - aibrix_api_extension + - --enable-k8s-job + ports: + - containerPort: 8100 + resources: + limits: + cpu: 1 + memory: 1Gi + requests: + cpu: 1 + memory: 1Gi + env: + # Metadata store 
configuration + - name: REDIS_HOST + value: "aibrix-redis-master.aibrix-system.svc.cluster.local" + - name: REDIS_PORT + value: "6379" + - name: REDIS_DB + value: "0" + # Object store configuration + # Comment the following lines to disable S3 as the object store + - name: STORAGE_AWS_ACCESS_KEY_ID + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: access-key-id + - name: STORAGE_AWS_SECRET_ACCESS_KEY + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: secret-access-key + - name: STORAGE_AWS_REGION + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: region + - name: STORAGE_AWS_BUCKET + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: bucket-name + # Uncomment the following lines to enable TOS as the object store + # - name: STORAGE_TOS_ACCESS_KEY + # valueFrom: + # secretKeyRef: + # name: aibrix-tos-credentials + # key: access-key + # - name: STORAGE_TOS_SECRET_KEY + # valueFrom: + # secretKeyRef: + # name: aibrix-tos-credentials + # key: secret-key + # - name: STORAGE_TOS_ENDPOINT + # valueFrom: + # secretKeyRef: + # name: aibrix-tos-credentials + # key: endpoint + # - name: STORAGE_TOS_REGION + # valueFrom: + # secretKeyRef: + # name: aibrix-tos-credentials + # key: region + # - name: STORAGE_TOS_BUCKET + # valueFrom: + # secretKeyRef: + # name: aibrix-tos-credentials + # key: bucket-name \ No newline at end of file diff --git a/config/api-extension/kustomization.yaml b/config/api-extension/kustomization.yaml new file mode 100644 index 000000000..0cfe867fb --- /dev/null +++ b/config/api-extension/kustomization.yaml @@ -0,0 +1,8 @@ +resources: +- deployment.yaml +- service.yaml +- rbac.yaml + +labels: + - pairs: + app.kubernetes.io/component: aibrix-api-extension \ No newline at end of file diff --git a/config/api-extension/rbac.yaml b/config/api-extension/rbac.yaml new file mode 100644 index 000000000..4da22e329 --- /dev/null +++ b/config/api-extension/rbac.yaml @@ -0,0 +1,46 @@ +# Service Account for api extension +apiVersion: v1 +kind: ServiceAccount +metadata: + name: api-extension-sa + namespace: system +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: api-extension-clusterrole +rules: + - apiGroups: ["batch"] # for batch job watching + resources: ["jobs"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] + - apiGroups: ["coordination.k8s.io"] # for kopf high availability + resources: ["leases"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] + - apiGroups: ["apiextensions.k8s.io"] # required by kopf + resources: ["customresourcedefinitions"] + verbs: ["get", "list", "watch"] + - apiGroups: [""] # required by kopf + resources: ["namespaces"] + verbs: ["list", "watch"] + - apiGroups: [""] # for ServiceAccount management + resources: ["serviceaccounts"] + verbs: ["get", "create", "update", "patch", "delete"] + - apiGroups: ["rbac.authorization.k8s.io"] # for Role management + resources: ["roles"] + verbs: ["get", "create", "update", "patch", "delete"] + - apiGroups: ["rbac.authorization.k8s.io"] # for RoleBinding management + resources: ["rolebindings"] + verbs: ["get", "create", "update", "patch", "delete"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: api-extension-clusterrole-binding +subjects: + - kind: ServiceAccount + name: api-extension-sa + namespace: system +roleRef: + kind: ClusterRole + name: api-extension-clusterrole + apiGroup: rbac.authorization.k8s.io diff --git 
a/config/api-extension/service.yaml b/config/api-extension/service.yaml new file mode 100644 index 000000000..36a9f177d --- /dev/null +++ b/config/api-extension/service.yaml @@ -0,0 +1,13 @@ +apiVersion: v1 +kind: Service +metadata: + name: api-extension + namespace: system +spec: + selector: + app: api-extension + ports: + - protocol: TCP + port: 8100 + targetPort: 8100 + type: ClusterIP \ No newline at end of file diff --git a/config/default/kustomization.yaml b/config/default/kustomization.yaml index aa90415d1..fb255b933 100644 --- a/config/default/kustomization.yaml +++ b/config/default/kustomization.yaml @@ -25,6 +25,7 @@ resources: - ../gateway - ../metadata - ../gpu-optimizer +- ../api-extension - ../dependency/kuberay-operator # [WEBHOOK] To enable webhook, uncomment all the sections with [WEBHOOK] prefix including the one in # crd/kustomization.yaml diff --git a/config/gateway/gateway-plugin/gateway-plugin.yaml b/config/gateway/gateway-plugin/gateway-plugin.yaml index 2c0387939..e62d82d13 100644 --- a/config/gateway/gateway-plugin/gateway-plugin.yaml +++ b/config/gateway/gateway-plugin/gateway-plugin.yaml @@ -212,3 +212,24 @@ spec: response: body: Streamed messageTimeout: 60s +--- +# this is a route for extension APIs implemented in python/aibrix/metadata +apiVersion: gateway.networking.k8s.io/v1 +kind: HTTPRoute +metadata: + name: reserved-router-extensions-endpoint + namespace: system +spec: + parentRefs: + - name: aibrix-eg + rules: + - matches: + - path: + type: PathPrefix + value: /v1/files + - path: + type: PathPrefix + value: /v1/batches + backendRefs: + - name: aibrix-api-extension + port: 8100 diff --git a/dist/chart/templates/api-extension/deployment.yaml b/dist/chart/templates/api-extension/deployment.yaml new file mode 100644 index 000000000..f8e0b270d --- /dev/null +++ b/dist/chart/templates/api-extension/deployment.yaml @@ -0,0 +1,52 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: aibrix-api-extension + namespace: {{ .Release.Namespace }} + labels: + {{- include "chart.labels" . | nindent 4 }} + app.kubernetes.io/component: aibrix-api-extension +spec: + replicas: {{ .Values.apiExtension.replicaCount }} + selector: + matchLabels: + {{- include "chart.selectorLabels" . | nindent 6 }} + app.kubernetes.io/component: aibrix-api-extension + template: + metadata: + labels: + {{- include "chart.labels" . 
| nindent 8 }} + app.kubernetes.io/component: aibrix-api-extension + spec: + serviceAccountName: aibrix-api-extension-sa + automountServiceAccountToken: true + {{- include "chart.imagePullSecrets" (dict "componentSecrets" .Values.apiExtension.imagePullSecrets "globalSecrets" .Values.global.imagePullSecrets) | nindent 6 }} + containers: + - name: api-extension + image: {{ .Values.apiExtension.container.image.repository }}:{{ .Values.apiExtension.container.image.tag }} + command: + - aibrix_api_extension + - --enable-k8s-job + ports: + - containerPort: 8100 + resources: {{ toYaml .Values.apiExtension.container.resources | nindent 12 }} + # TODO: Add liveness and readiness probes + env: + # Metadata store configuration + - name: REDIS_HOST + value: {{ .Values.apiExtension.dependencies.redis.host }} + - name: REDIS_PORT + value: "{{ .Values.apiExtension.dependencies.redis.port }}" + {{- if .Values.metadata.redis.enablePassword }} + - name: REDIS_PASSWORD + valueFrom: + secretKeyRef: + name: aibrix-redis + key: redis-password + {{- end }} + # Object store configuration + # TODO: Add object store config + {{- with .Values.apiExtension.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/dist/chart/templates/api-extension/rbac.yaml b/dist/chart/templates/api-extension/rbac.yaml new file mode 100644 index 000000000..0950b5538 --- /dev/null +++ b/dist/chart/templates/api-extension/rbac.yaml @@ -0,0 +1,54 @@ +apiVersion: v1 +kind: ServiceAccount +metadata: + name: aibrix-api-extension-sa + namespace: {{ .Release.Namespace }} + labels: + {{- include "chart.labels" . | nindent 4 }} + app.kubernetes.io/component: aibrix-api-extension +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: aibrix-api-extension-clusterrole + labels: + {{- include "chart.labels" . | nindent 4 }} + app.kubernetes.io/component: aibrix-api-extension +rules: + - apiGroups: ["batch"] # for batch job watching + resources: ["jobs"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] + - apiGroups: ["coordination.k8s.io"] # for kopf high availability + resources: ["leases"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] + - apiGroups: ["apiextensions.k8s.io"] # required by kopf + resources: ["customresourcedefinitions"] + verbs: ["get", "list", "watch"] + - apiGroups: [""] # required by kopf + resources: ["namespaces"] + verbs: ["list", "watch"] + - apiGroups: [""] # for ServiceAccount management + resources: ["serviceaccounts"] + verbs: ["get", "create", "update", "patch", "delete"] + - apiGroups: ["rbac.authorization.k8s.io"] # for Role management + resources: ["roles"] + verbs: ["get", "create", "update", "patch", "delete"] + - apiGroups: ["rbac.authorization.k8s.io"] # for RoleBinding management + resources: ["rolebindings"] + verbs: ["get", "create", "update", "patch", "delete"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: aibrix-api-extension-clusterrole-binding + labels: + {{- include "chart.labels" . 
| nindent 4 }} + app.kubernetes.io/component: aibrix-api-extension +subjects: + - kind: ServiceAccount + name: aibrix-api-extension-sa + namespace: {{ .Release.Namespace }} +roleRef: + kind: ClusterRole + name: aibrix-api-extension-clusterrole + apiGroup: rbac.authorization.k8s.io diff --git a/dist/chart/templates/api-extension/service.yaml b/dist/chart/templates/api-extension/service.yaml new file mode 100644 index 000000000..aa2e006d2 --- /dev/null +++ b/dist/chart/templates/api-extension/service.yaml @@ -0,0 +1,16 @@ +apiVersion: v1 +kind: Service +metadata: + name: aibrix-api-extension + namespace: {{ .Release.Namespace }} + labels: + {{- include "chart.labels" . | nindent 4 }} + app.kubernetes.io/component: aibrix-api-extension +spec: + selector: + app.kubernetes.io/component: aibrix-api-extension + ports: + - protocol: TCP + port: 8100 + targetPort: 8100 + type: ClusterIP diff --git a/dist/chart/templates/gateway-plugin/httproute.yaml b/dist/chart/templates/gateway-plugin/httproute.yaml index 1ea8fefa9..5ea3a1b9a 100644 --- a/dist/chart/templates/gateway-plugin/httproute.yaml +++ b/dist/chart/templates/gateway-plugin/httproute.yaml @@ -49,3 +49,26 @@ spec: backendRefs: - name: aibrix-metadata-service port: 8090 +--- +apiVersion: gateway.networking.k8s.io/v1 +kind: HTTPRoute +metadata: + name: aibrix-reserved-router-extensions-endpoint + namespace: {{ .Release.Namespace }} + labels: + {{- include "chart.labels" . | nindent 4 }} + app.kubernetes.io/component: aibrix-api-extension +spec: + parentRefs: + - name: aibrix-eg + rules: + - matches: + - path: + type: PathPrefix + value: /v1/files + - path: + type: PathPrefix + value: /v1/batches + backendRefs: + - name: aibrix-api-extension + port: 8100 diff --git a/dist/chart/values.yaml b/dist/chart/values.yaml index 38c52d95e..4a5340b48 100644 --- a/dist/chart/values.yaml +++ b/dist/chart/values.yaml @@ -169,6 +169,26 @@ metadata: enablePassword: false password: "" +apiExtension: + replicaCount: 1 + imagePullSecrets: [] + container: + image: + repository: aibrix/runtime + tag: nightly + resources: + limits: + cpu: "1" + memory: 1Gi + requests: + cpu: "1" + memory: 1Gi + dependencies: + redis: + host: aibrix-redis-master + port: 6379 + tolerations: [] + # [CRDs]: To enable the CRDs crd: # This option determines whether the CRDs are included diff --git a/python/aibrix/.python-version b/python/aibrix/.python-version new file mode 100644 index 000000000..2c0733315 --- /dev/null +++ b/python/aibrix/.python-version @@ -0,0 +1 @@ +3.11 diff --git a/python/aibrix/tests/e2e/README.md b/python/aibrix/tests/e2e/README.md new file mode 100644 index 000000000..8ebba1fdc --- /dev/null +++ b/python/aibrix/tests/e2e/README.md @@ -0,0 +1,91 @@ +# End-to-End Tests + +This directory contains end-to-end tests for Aibrix services that run against real service instances. + +## Files + +- `test_batch_api.py` - E2E tests for OpenAI Batch API endpoints + +## Prerequisites + +1. **Running Aibrix Service**: Ensure the Aibrix service is accessible at `http://localhost:8888/` + ```bash + # Example: Access the service using port-forward + kubectl -n envoy-gateway-system port-forward service/envoy-aibrix-system-aibrix-eg-903790dc 8888:80 + ``` + +2. **Generate Credentials**: Ensure object store is acceesible. 
Using S3 as an example: + ```bash + python ../../scripts/generate_secrets.py s3 --bucket + ``` + The script will read s3 credentials setup using ```aws configure``` + +## Running Tests + +### All E2E Tests +```bash +cd /path/to/aibrix/python/aibrix +pytest tests/e2e/ -v +``` + +### Batch API Tests Only +```bash +cd /path/to/aibrix/python/aibrix +pytest tests/e2e/test_batch_api.py -v +``` + +### Specific Test +```bash +cd /path/to/aibrix/python/aibrix +pytest tests/e2e/test_batch_api.py::test_batch_api_e2e_real_service -v +``` + +### With Detailed Output +```bash +cd /path/to/aibrix/python/aibrix +pytest tests/e2e/test_batch_api.py -v -s +``` + +## Test Structure + +### Service Health Fixture +- `service_health` - Session-scoped fixture that checks service availability +- Tests `/healthz` endpoint for basic health checking +- Automatically skips all tests if service is not available + +### Service Availability Test +- `test_batch_api_service_availability()` - Verifies service is running and accessible +- Tests basic API endpoint accessibility + +### Complete Workflow Test +- `test_batch_api_e2e_real_service()` - Tests the complete batch processing workflow +- Upload → Create → Poll → Download → Verify + +### Error Handling Test +- `test_batch_api_error_handling_real_service()` - Tests error scenarios +- Invalid inputs, non-existent resources + +## API Endpoints + +### Health Endpoints +- `/healthz` - General service health check + +## Configuration + +Tests connect to `http://localhost:8888` by default. The service URL is hardcoded in the test functions. + +## Expected Output + +Successful test run: +``` +tests/e2e/test_batch_api.py::test_batch_api_service_availability PASSED +tests/e2e/test_batch_api.py::test_batch_api_e2e_real_service PASSED +tests/e2e/test_batch_api.py::test_batch_api_error_handling_real_service PASSED +``` + +If service is not available: +``` +tests/e2e/test_batch_api.py::test_batch_api_service_availability SKIPPED +tests/e2e/test_batch_api.py::test_batch_api_e2e_real_service SKIPPED +tests/e2e/test_batch_api.py::test_batch_api_error_handling_real_service SKIPPED +``` \ No newline at end of file diff --git a/python/aibrix/tests/e2e/__init__.py b/python/aibrix/tests/e2e/__init__.py new file mode 100644 index 000000000..413d7bed8 --- /dev/null +++ b/python/aibrix/tests/e2e/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +End-to-end tests for Aibrix services. + +This package contains tests that validate complete workflows against real +running services, as opposed to unit tests that test individual components +in isolation. +""" diff --git a/python/aibrix/tests/e2e/test_batch_api.py b/python/aibrix/tests/e2e/test_batch_api.py new file mode 100644 index 000000000..eabfe8c00 --- /dev/null +++ b/python/aibrix/tests/e2e/test_batch_api.py @@ -0,0 +1,321 @@ +# Copyright 2024 The Aibrix Team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +End-to-end test for OpenAI Batch API against real service. + +This test validates the complete batch processing workflow against a real +Aibrix service running at http://localhost:8888/. + +Test workflow: +1. Upload sample input file via Files API +2. Create batch job via Batch API +3. Poll job status until completion +4. Download and verify output via Files API +5. Verify batch list API works + +Prerequisites: +- Aibrix service running at http://localhost:8888/ +- Service configured with proper storage backend +- Network connectivity to the service + +Usage: + pytest tests/e2e/test_batch_api.py -v + pytest tests/e2e/test_batch_api.py::test_batch_api_e2e_real_service -v +""" + +import asyncio +import copy +import json +from typing import Any, Dict + +import httpx +import pytest + + +def generate_batch_input_data(num_requests: int = 3) -> str: + """Generate test batch input data and return the content as string.""" + base_request: Dict[str, Any] = { + "custom_id": "request-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "gpt-3.5-turbo-0125", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello world!"}, + ], + "max_tokens": 1000, + }, + } + + lines = [] + for i in range(num_requests): + request = copy.deepcopy(base_request) + request["custom_id"] = f"request-{i+1}" + request["body"]["messages"][1]["content"] = f"Hello from request {i+1}!" 
+ lines.append(json.dumps(request)) + + return "\n".join(lines) + + +def verify_batch_output_content(output_content: str, expected_requests: int) -> bool: + """Verify that batch output content has the expected structure.""" + lines = output_content.strip().split("\n") + + if len(lines) != expected_requests: + print(f"Expected {expected_requests} output lines, got {len(lines)}") + return False + + for i, line in enumerate(lines): + try: + output = json.loads(line) + + # Check required fields in OpenAI batch response format + required_fields = ["id", "custom_id", "response"] + for field in required_fields: + if field not in output: + print(f"Missing required field '{field}' in response {i+1}") + return False + + # Verify custom_id matches expected pattern + expected_custom_id = f"request-{i+1}" + if output["custom_id"] != expected_custom_id: + print( + f"Expected custom_id '{expected_custom_id}', got '{output['custom_id']}'" + ) + return False + + response = output["response"] + required_fields = ["status_code", "request_id", "body"] + for field in required_fields: + if field not in response: + print( + f"Missing required field 'response.{field}' in response {i+1}" + ) + return False + + # Check that we got a successful response + if response["status_code"] != 200: + print(f"Expected status_code 200, got {response['status_code']}") + return False + + body = response["body"] + required_fields = ["model", "choices"] + for field in required_fields: + if field not in body: + print( + f"Missing required field 'response.body.{field}' in response {i+1}" + ) + return False + + except json.JSONDecodeError as e: + print(f"Invalid JSON in output line {i+1}: {e}") + return False + + return True + + +async def check_service_health(base_url: str) -> bool: + """Check if the service is running and healthy.""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + # Check general health endpoint + health_response = await client.get(f"{base_url}/v1/batches") + return health_response.status_code == 200 + except Exception as e: + print(f"Health check failed: {e}") + return False + + +@pytest.fixture(scope="session") +def service_health(): + """Fixture to check service health and skip tests if service is not available.""" + base_url = "http://localhost:8888" + + print(f"🔍 Checking service health at {base_url}...") + + # Run the async health check in a sync context + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + is_healthy = loop.run_until_complete(check_service_health(base_url)) + finally: + loop.close() + + if not is_healthy: + pytest.skip(f"Service at {base_url} is not available or healthy") + + print(f"✅ Service at {base_url} is healthy") + return base_url + + +@pytest.mark.asyncio +async def test_batch_api_e2e_real_service(service_health): + """ + End-to-end test for OpenAI Batch API against real service: + 1. Upload sample input file via Files API + 2. Create batch job via Batch API + 3. Poll job status until completion + 4. Download and verify output via Files API + 5. 
Verify batch list API works + """ + base_url = service_health + + async with httpx.AsyncClient(timeout=60.0) as client: + # Step 1: Upload sample input file via Files API + print("Step 1: Uploading batch input file...") + + input_data = generate_batch_input_data(3) + files = {"file": ("batch_input.jsonl", input_data, "application/jsonl")} + data = {"purpose": "batch"} + + upload_response = await client.post( + f"{base_url}/v1/files", files=files, data=data + ) + assert ( + upload_response.status_code == 200 + ), f"File upload failed: {upload_response.text}" + + upload_result = upload_response.json() + assert upload_result["object"] == "file" + assert upload_result["purpose"] == "batch" + assert upload_result["status"] == "uploaded" + + input_file_id = upload_result["id"] + print(f"✅ File uploaded successfully with ID: {input_file_id}") + + # Step 2: Create batch job via Batch API + print("Step 2: Creating batch job...") + + batch_request = { + "input_file_id": input_file_id, + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + } + + batch_response = await client.post(f"{base_url}/v1/batches", json=batch_request) + assert ( + batch_response.status_code == 200 + ), f"Batch creation failed: {batch_response.text}" + + batch_result = batch_response.json() + assert batch_result["object"] == "batch" + assert batch_result["input_file_id"] == input_file_id + assert batch_result["endpoint"] == "/v1/chat/completions" + + batch_id = batch_result["id"] + print(f"✅ Batch created successfully with ID: {batch_id}") + + # Step 3: Poll job status until completion + print("Step 3: Polling job status until completion...") + + max_polls = ( + 60 # Maximum number of polling attempts (increased for real service) + ) + poll_interval = 5 # seconds (increased for real service) + + for attempt in range(max_polls): + status_response = await client.get(f"{base_url}/v1/batches/{batch_id}") + assert ( + status_response.status_code == 200 + ), f"Status check failed: {status_response.text}" + + status_result = status_response.json() + current_status = status_result["status"] + + print(f" Attempt {attempt + 1}: Status = {current_status}") + + if current_status == "completed": + print("✅ Batch job completed successfully!") + output_file_id = status_result["output_file_id"] + assert ( + output_file_id is not None + ), "Expected output_file_id for completed batch" + + request_counts = status_result.get("request_counts") + if request_counts: + print(f" Request counts: {request_counts}") + assert request_counts["total"] == 3 + assert request_counts["completed"] == 3 + assert request_counts["failed"] == 0 + + break + elif current_status == "failed": + error_info = status_result.get("errors", "Unknown error") + pytest.fail(f"Batch job failed: {error_info}") + elif current_status in ["cancelled", "expired"]: + pytest.fail(f"Batch job was {current_status}") + elif current_status in ["validating", "in_progress", "finalizing"]: + # These are expected intermediate states + pass + else: + print(f" Unknown status: {current_status}") + + # Wait before next poll + await asyncio.sleep(poll_interval) + else: + pytest.fail( + f"Batch job did not complete within {max_polls * poll_interval} seconds" + ) + + # Step 4: Download and verify output via Files API + print("Step 4: Downloading and verifying output...") + + output_response = await client.get( + f"{base_url}/v1/files/{output_file_id}/content" + ) + assert ( + output_response.status_code == 200 + ), f"Output download failed: {output_response.text}" + + output_content = 
output_response.content.decode("utf-8") + assert output_content, "Output file is empty" + + # Verify output content structure + is_valid = verify_batch_output_content(output_content, 3) + assert ( + is_valid + ), f"Output content verification failed. Content:\n{output_content}" + + print("✅ Output downloaded and verified successfully!") + print(f"Output content preview:\n{output_content[:500]}...") + + # Step 5: Verify batch list API works + print("Step 5: Testing batch list API...") + + list_response = await client.get(f"{base_url}/v1/batches") + assert ( + list_response.status_code == 200 + ), f"Batch list failed: {list_response.text}" + + list_result = list_response.json() + assert list_result["object"] == "list" + assert len(list_result["data"]) >= 1, "Expected at least one batch in the list" + + # Find our batch in the list + our_batch = None + for batch in list_result["data"]: + if batch["id"] == batch_id: + our_batch = batch + break + + assert our_batch is not None, f"Batch {batch_id} not found in list" + assert our_batch["status"] == "completed" + + print("✅ Batch list API verified successfully!") + + print( + "\n🎉 E2E test completed successfully! All OpenAI Batch API endpoints working correctly." + ) From cf173d8d97dabef9c543f4454e1a1778fb4ee4ce Mon Sep 17 00:00:00 2001 From: Jingyuan Zhang Date: Tue, 14 Oct 2025 10:17:06 -0700 Subject: [PATCH 03/11] Fix envoy routes. Pass e2e test, and update document Signed-off-by: Jingyuan Zhang --- build/container/Dockerfile.metadata | 4 +- config/api-extension/deployment.yaml | 88 -------- config/api-extension/kustomization.yaml | 8 - config/api-extension/rbac.yaml | 46 ---- config/api-extension/service.yaml | 13 -- config/default/kustomization.yaml | 1 - .../gateway-plugin/gateway-plugin.yaml | 31 +-- config/metadata/job_template_patch.yaml | 16 ++ config/metadata/kustomization.yaml | 6 + config/metadata/metadata.yaml | 89 +++++++- .../templates/api-extension/deployment.yaml | 52 ----- dist/chart/templates/api-extension/rbac.yaml | 54 ----- .../templates/api-extension/service.yaml | 16 -- .../envoy_extension_policy.yaml | 2 +- .../templates/gateway-plugin/httproute.yaml | 23 +- .../metadata-service/deployment.yaml | 2 - .../templates/metadata-service/rbac.yaml | 24 +++ dist/chart/values.yaml | 20 -- docs/source/features/batch-api.rst | 202 +++++++++++++++++- python/aibrix/aibrix/metadata/app.py | 2 +- python/aibrix/aibrix/metadata/cache/job.py | 10 +- .../metadata/setting/k8s_job_template.yaml | 12 +- python/aibrix/pyproject.toml | 3 +- python/aibrix/scripts/__init__.py | 13 ++ python/aibrix/scripts/generate_secrets.py | 14 +- python/aibrix/tests/batch/conftest.py | 2 +- .../aibrix/tests/batch/testdata/job_rbac.yaml | 6 +- python/aibrix/tests/test_files_api.py | 2 +- 28 files changed, 383 insertions(+), 378 deletions(-) delete mode 100644 config/api-extension/deployment.yaml delete mode 100644 config/api-extension/kustomization.yaml delete mode 100644 config/api-extension/rbac.yaml delete mode 100644 config/api-extension/service.yaml create mode 100644 config/metadata/job_template_patch.yaml delete mode 100644 dist/chart/templates/api-extension/deployment.yaml delete mode 100644 dist/chart/templates/api-extension/rbac.yaml delete mode 100644 dist/chart/templates/api-extension/service.yaml create mode 100644 python/aibrix/scripts/__init__.py diff --git a/build/container/Dockerfile.metadata b/build/container/Dockerfile.metadata index d53d42ccb..a6dfcf64d 100644 --- a/build/container/Dockerfile.metadata +++ b/build/container/Dockerfile.metadata 
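For reference, the JSONL contract exercised by the e2e test above looks roughly like this. The field names mirror `generate_batch_input_data()` and the assertions in `verify_batch_output_content()`; the values inside the response body are placeholders, not captured output.

```python
# One batch input line and the matching output line the e2e test expects.
input_line = {
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "gpt-3.5-turbo-0125",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello from request 1!"},
        ],
        "max_tokens": 1000,
    },
}

output_line = {
    "id": "batch-req-...",       # assigned by the batch worker
    "custom_id": "request-1",    # must echo the input custom_id
    "response": {
        "status_code": 200,
        "request_id": "...",
        "body": {
            "model": "gpt-3.5-turbo-0125",
            "choices": [
                {"index": 0, "message": {"role": "assistant", "content": "..."}}
            ],
        },
    },
}
```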
@@ -45,6 +45,6 @@ RUN apt-get update \ # Expose metadata service port EXPOSE 8090 -# Run metadata server -CMD ["python", "-m", "aibrix.metadata.app", "--host", "0.0.0.0", "--port", "8090"] +# Set entrypoint for Metadata service +ENTRYPOINT ["aibrix_metadata", "--enable-k8s-job", "--host", "0.0.0.0"] diff --git a/config/api-extension/deployment.yaml b/config/api-extension/deployment.yaml deleted file mode 100644 index c0444ee0e..000000000 --- a/config/api-extension/deployment.yaml +++ /dev/null @@ -1,88 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - name: api-extension - namespace: system -spec: - replicas: 1 - selector: - matchLabels: - app: api-extension - template: - metadata: - labels: - app: api-extension - spec: - serviceAccountName: api-extension-sa - automountServiceAccountToken: true - containers: - - name: api-extension - image: aibrix/runtime:nightly - command: - - aibrix_api_extension - - --enable-k8s-job - ports: - - containerPort: 8100 - resources: - limits: - cpu: 1 - memory: 1Gi - requests: - cpu: 1 - memory: 1Gi - env: - # Metadata store configuration - - name: REDIS_HOST - value: "aibrix-redis-master.aibrix-system.svc.cluster.local" - - name: REDIS_PORT - value: "6379" - - name: REDIS_DB - value: "0" - # Object store configuration - # Comment the following lines to disable S3 as the object store - - name: STORAGE_AWS_ACCESS_KEY_ID - valueFrom: - secretKeyRef: - name: aibrix-s3-credentials - key: access-key-id - - name: STORAGE_AWS_SECRET_ACCESS_KEY - valueFrom: - secretKeyRef: - name: aibrix-s3-credentials - key: secret-access-key - - name: STORAGE_AWS_REGION - valueFrom: - secretKeyRef: - name: aibrix-s3-credentials - key: region - - name: STORAGE_AWS_BUCKET - valueFrom: - secretKeyRef: - name: aibrix-s3-credentials - key: bucket-name - # Uncomment the following lines to enable TOS as the object store - # - name: STORAGE_TOS_ACCESS_KEY - # valueFrom: - # secretKeyRef: - # name: aibrix-tos-credentials - # key: access-key - # - name: STORAGE_TOS_SECRET_KEY - # valueFrom: - # secretKeyRef: - # name: aibrix-tos-credentials - # key: secret-key - # - name: STORAGE_TOS_ENDPOINT - # valueFrom: - # secretKeyRef: - # name: aibrix-tos-credentials - # key: endpoint - # - name: STORAGE_TOS_REGION - # valueFrom: - # secretKeyRef: - # name: aibrix-tos-credentials - # key: region - # - name: STORAGE_TOS_BUCKET - # valueFrom: - # secretKeyRef: - # name: aibrix-tos-credentials - # key: bucket-name \ No newline at end of file diff --git a/config/api-extension/kustomization.yaml b/config/api-extension/kustomization.yaml deleted file mode 100644 index 0cfe867fb..000000000 --- a/config/api-extension/kustomization.yaml +++ /dev/null @@ -1,8 +0,0 @@ -resources: -- deployment.yaml -- service.yaml -- rbac.yaml - -labels: - - pairs: - app.kubernetes.io/component: aibrix-api-extension \ No newline at end of file diff --git a/config/api-extension/rbac.yaml b/config/api-extension/rbac.yaml deleted file mode 100644 index 4da22e329..000000000 --- a/config/api-extension/rbac.yaml +++ /dev/null @@ -1,46 +0,0 @@ -# Service Account for api extension -apiVersion: v1 -kind: ServiceAccount -metadata: - name: api-extension-sa - namespace: system ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: api-extension-clusterrole -rules: - - apiGroups: ["batch"] # for batch job watching - resources: ["jobs"] - verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] - - apiGroups: ["coordination.k8s.io"] # for kopf high availability - resources: ["leases"] - 
verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] - - apiGroups: ["apiextensions.k8s.io"] # required by kopf - resources: ["customresourcedefinitions"] - verbs: ["get", "list", "watch"] - - apiGroups: [""] # required by kopf - resources: ["namespaces"] - verbs: ["list", "watch"] - - apiGroups: [""] # for ServiceAccount management - resources: ["serviceaccounts"] - verbs: ["get", "create", "update", "patch", "delete"] - - apiGroups: ["rbac.authorization.k8s.io"] # for Role management - resources: ["roles"] - verbs: ["get", "create", "update", "patch", "delete"] - - apiGroups: ["rbac.authorization.k8s.io"] # for RoleBinding management - resources: ["rolebindings"] - verbs: ["get", "create", "update", "patch", "delete"] ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRoleBinding -metadata: - name: api-extension-clusterrole-binding -subjects: - - kind: ServiceAccount - name: api-extension-sa - namespace: system -roleRef: - kind: ClusterRole - name: api-extension-clusterrole - apiGroup: rbac.authorization.k8s.io diff --git a/config/api-extension/service.yaml b/config/api-extension/service.yaml deleted file mode 100644 index 36a9f177d..000000000 --- a/config/api-extension/service.yaml +++ /dev/null @@ -1,13 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: api-extension - namespace: system -spec: - selector: - app: api-extension - ports: - - protocol: TCP - port: 8100 - targetPort: 8100 - type: ClusterIP \ No newline at end of file diff --git a/config/default/kustomization.yaml b/config/default/kustomization.yaml index fb255b933..aa90415d1 100644 --- a/config/default/kustomization.yaml +++ b/config/default/kustomization.yaml @@ -25,7 +25,6 @@ resources: - ../gateway - ../metadata - ../gpu-optimizer -- ../api-extension - ../dependency/kuberay-operator # [WEBHOOK] To enable webhook, uncomment all the sections with [WEBHOOK] prefix including the one in # crd/kustomization.yaml diff --git a/config/gateway/gateway-plugin/gateway-plugin.yaml b/config/gateway/gateway-plugin/gateway-plugin.yaml index e62d82d13..c85f8c369 100644 --- a/config/gateway/gateway-plugin/gateway-plugin.yaml +++ b/config/gateway/gateway-plugin/gateway-plugin.yaml @@ -134,7 +134,7 @@ spec: apiVersion: gateway.networking.k8s.io/v1 kind: HTTPRoute metadata: - name: reserved-router-models-endpoint + name: reserved-router-metadata-endpoint namespace: system spec: parentRefs: @@ -144,6 +144,12 @@ spec: - path: type: PathPrefix value: /v1/models + - path: + type: PathPrefix + value: /v1/files + - path: + type: PathPrefix + value: /v1/batches backendRefs: - name: aibrix-metadata-service port: 8090 @@ -157,7 +163,7 @@ spec: targetRef: group: gateway.networking.k8s.io kind: HTTPRoute - name: aibrix-reserved-router-models-endpoint + name: aibrix-reserved-router-metadata-endpoint --- # this is a dummy route for incoming request and, # then request is routed to httproute using model name OR @@ -212,24 +218,3 @@ spec: response: body: Streamed messageTimeout: 60s ---- -# this is a route for extension APIs implemented in python/aibrix/metadata -apiVersion: gateway.networking.k8s.io/v1 -kind: HTTPRoute -metadata: - name: reserved-router-extensions-endpoint - namespace: system -spec: - parentRefs: - - name: aibrix-eg - rules: - - matches: - - path: - type: PathPrefix - value: /v1/files - - path: - type: PathPrefix - value: /v1/batches - backendRefs: - - name: aibrix-api-extension - port: 8100 diff --git a/config/metadata/job_template_patch.yaml b/config/metadata/job_template_patch.yaml new file mode 
100644 index 000000000..b1418687b --- /dev/null +++ b/config/metadata/job_template_patch.yaml @@ -0,0 +1,16 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: batch-job-template + namespace: default +spec: + parallelism: 1 # Customizable. The number of parallel workers. + completions: 1 # Customizable. Must equal to the parallelism. + backoffLimit: 2 # Customizable, but usually no need to change. + template: + spec: + containers: + - name: batch-worker + image: aibrix/runtime:nightly # Customizable, runtime image + - name: llm-engine + image: aibrix/vllm-mock:nightly # Customizable, LLM engine image \ No newline at end of file diff --git a/config/metadata/kustomization.yaml b/config/metadata/kustomization.yaml index 74e8d2ac5..35e726cf7 100644 --- a/config/metadata/kustomization.yaml +++ b/config/metadata/kustomization.yaml @@ -2,6 +2,12 @@ resources: - metadata.yaml - redis.yaml +configMapGenerator: +- name: metadata-config + namespace: aibrix-system + files: + - job_template_patch.yaml + labels: - pairs: app.kubernetes.io/component: aibrix-metadata-service \ No newline at end of file diff --git a/config/metadata/metadata.yaml b/config/metadata/metadata.yaml index dd2360fa0..1c3d1a22a 100644 --- a/config/metadata/metadata.yaml +++ b/config/metadata/metadata.yaml @@ -30,6 +30,30 @@ rules: - apiGroups: ["model.aibrix.ai"] resources: ["modeladapters"] verbs: ["get", "list"] + # For batch job watching + - apiGroups: ["batch"] + resources: ["jobs"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] + # For batch job ServiceAccount management + - apiGroups: [""] + resources: ["serviceaccounts"] + verbs: ["get", "create", "update", "patch", "delete"] + - apiGroups: ["rbac.authorization.k8s.io"] # for Role management + resources: ["roles"] + verbs: ["get", "create", "update", "patch", "delete"] + - apiGroups: ["rbac.authorization.k8s.io"] # for RoleBinding management + resources: ["rolebindings"] + verbs: ["get", "create", "update", "patch", "delete"] + # For kopf high availability + - apiGroups: ["coordination.k8s.io"] + resources: ["leases"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] + - apiGroups: ["apiextensions.k8s.io"] # required by kopf + resources: ["customresourcedefinitions"] + verbs: ["get", "list", "watch"] + - apiGroups: [""] # required by kopf + resources: ["namespaces"] + verbs: ["list", "watch"] --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding @@ -64,14 +88,27 @@ spec: - name: init-redis image: busybox command: ['sh', '-c', 'until echo "ping" | nc aibrix-redis-master 6379 -w 1 | grep -c PONG; do echo waiting for redis; sleep 2; done'] + volumes: + - name: config-volume + configMap: + name: metadata-config containers: - name: metadata-service image: metadata-service:latest imagePullPolicy: IfNotPresent - command: ["python", "-m", "aibrix.metadata.app"] - args: ["--host=0.0.0.0", "--port=8090"] + command: + - aibrix_metadata + - --host + - "0.0.0.0" + - --enable-k8s-job + - --k8s-job-patch + - /app/config/job_template_patch.yaml ports: - containerPort: 8090 + volumeMounts: + - name: config-volume + mountPath: /app/config + readOnly: true resources: limits: cpu: 500m @@ -92,6 +129,54 @@ spec: valueFrom: fieldRef: fieldPath: metadata.namespace + # Object store configuration + # Comment the following lines to disable S3 as the object store + - name: STORAGE_AWS_ACCESS_KEY_ID + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: access-key-id + - name: STORAGE_AWS_SECRET_ACCESS_KEY + valueFrom: + 
secretKeyRef: + name: aibrix-s3-credentials + key: secret-access-key + - name: STORAGE_AWS_REGION + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: region + - name: STORAGE_AWS_BUCKET + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: bucket-name + # Uncomment the following lines to enable TOS as the object store + # - name: STORAGE_TOS_ACCESS_KEY + # valueFrom: + # secretKeyRef: + # name: aibrix-tos-credentials + # key: access-key + # - name: STORAGE_TOS_SECRET_KEY + # valueFrom: + # secretKeyRef: + # name: aibrix-tos-credentials + # key: secret-key + # - name: STORAGE_TOS_ENDPOINT + # valueFrom: + # secretKeyRef: + # name: aibrix-tos-credentials + # key: endpoint + # - name: STORAGE_TOS_REGION + # valueFrom: + # secretKeyRef: + # name: aibrix-tos-credentials + # key: region + # - name: STORAGE_TOS_BUCKET + # valueFrom: + # secretKeyRef: + # name: aibrix-tos-credentials + # key: bucket-name livenessProbe: httpGet: path: /healthz diff --git a/dist/chart/templates/api-extension/deployment.yaml b/dist/chart/templates/api-extension/deployment.yaml deleted file mode 100644 index f8e0b270d..000000000 --- a/dist/chart/templates/api-extension/deployment.yaml +++ /dev/null @@ -1,52 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - name: aibrix-api-extension - namespace: {{ .Release.Namespace }} - labels: - {{- include "chart.labels" . | nindent 4 }} - app.kubernetes.io/component: aibrix-api-extension -spec: - replicas: {{ .Values.apiExtension.replicaCount }} - selector: - matchLabels: - {{- include "chart.selectorLabels" . | nindent 6 }} - app.kubernetes.io/component: aibrix-api-extension - template: - metadata: - labels: - {{- include "chart.labels" . | nindent 8 }} - app.kubernetes.io/component: aibrix-api-extension - spec: - serviceAccountName: aibrix-api-extension-sa - automountServiceAccountToken: true - {{- include "chart.imagePullSecrets" (dict "componentSecrets" .Values.apiExtension.imagePullSecrets "globalSecrets" .Values.global.imagePullSecrets) | nindent 6 }} - containers: - - name: api-extension - image: {{ .Values.apiExtension.container.image.repository }}:{{ .Values.apiExtension.container.image.tag }} - command: - - aibrix_api_extension - - --enable-k8s-job - ports: - - containerPort: 8100 - resources: {{ toYaml .Values.apiExtension.container.resources | nindent 12 }} - # TODO: Add liveness and readiness probes - env: - # Metadata store configuration - - name: REDIS_HOST - value: {{ .Values.apiExtension.dependencies.redis.host }} - - name: REDIS_PORT - value: "{{ .Values.apiExtension.dependencies.redis.port }}" - {{- if .Values.metadata.redis.enablePassword }} - - name: REDIS_PASSWORD - valueFrom: - secretKeyRef: - name: aibrix-redis - key: redis-password - {{- end }} - # Object store configuration - # TODO: Add object store config - {{- with .Values.apiExtension.tolerations }} - tolerations: - {{- toYaml . | nindent 8 }} - {{- end }} diff --git a/dist/chart/templates/api-extension/rbac.yaml b/dist/chart/templates/api-extension/rbac.yaml deleted file mode 100644 index 0950b5538..000000000 --- a/dist/chart/templates/api-extension/rbac.yaml +++ /dev/null @@ -1,54 +0,0 @@ -apiVersion: v1 -kind: ServiceAccount -metadata: - name: aibrix-api-extension-sa - namespace: {{ .Release.Namespace }} - labels: - {{- include "chart.labels" . 
| nindent 4 }} - app.kubernetes.io/component: aibrix-api-extension ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: aibrix-api-extension-clusterrole - labels: - {{- include "chart.labels" . | nindent 4 }} - app.kubernetes.io/component: aibrix-api-extension -rules: - - apiGroups: ["batch"] # for batch job watching - resources: ["jobs"] - verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] - - apiGroups: ["coordination.k8s.io"] # for kopf high availability - resources: ["leases"] - verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] - - apiGroups: ["apiextensions.k8s.io"] # required by kopf - resources: ["customresourcedefinitions"] - verbs: ["get", "list", "watch"] - - apiGroups: [""] # required by kopf - resources: ["namespaces"] - verbs: ["list", "watch"] - - apiGroups: [""] # for ServiceAccount management - resources: ["serviceaccounts"] - verbs: ["get", "create", "update", "patch", "delete"] - - apiGroups: ["rbac.authorization.k8s.io"] # for Role management - resources: ["roles"] - verbs: ["get", "create", "update", "patch", "delete"] - - apiGroups: ["rbac.authorization.k8s.io"] # for RoleBinding management - resources: ["rolebindings"] - verbs: ["get", "create", "update", "patch", "delete"] ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRoleBinding -metadata: - name: aibrix-api-extension-clusterrole-binding - labels: - {{- include "chart.labels" . | nindent 4 }} - app.kubernetes.io/component: aibrix-api-extension -subjects: - - kind: ServiceAccount - name: aibrix-api-extension-sa - namespace: {{ .Release.Namespace }} -roleRef: - kind: ClusterRole - name: aibrix-api-extension-clusterrole - apiGroup: rbac.authorization.k8s.io diff --git a/dist/chart/templates/api-extension/service.yaml b/dist/chart/templates/api-extension/service.yaml deleted file mode 100644 index aa2e006d2..000000000 --- a/dist/chart/templates/api-extension/service.yaml +++ /dev/null @@ -1,16 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: aibrix-api-extension - namespace: {{ .Release.Namespace }} - labels: - {{- include "chart.labels" . | nindent 4 }} - app.kubernetes.io/component: aibrix-api-extension -spec: - selector: - app.kubernetes.io/component: aibrix-api-extension - ports: - - protocol: TCP - port: 8100 - targetPort: 8100 - type: ClusterIP diff --git a/dist/chart/templates/gateway-plugin/envoy_extension_policy.yaml b/dist/chart/templates/gateway-plugin/envoy_extension_policy.yaml index 489b2c2a0..4829f10d7 100644 --- a/dist/chart/templates/gateway-plugin/envoy_extension_policy.yaml +++ b/dist/chart/templates/gateway-plugin/envoy_extension_policy.yaml @@ -35,4 +35,4 @@ spec: targetRef: group: gateway.networking.k8s.io kind: HTTPRoute - name: aibrix-reserved-router-models-endpoint \ No newline at end of file + name: aibrix-reserved-router-metadata-endpoint \ No newline at end of file diff --git a/dist/chart/templates/gateway-plugin/httproute.yaml b/dist/chart/templates/gateway-plugin/httproute.yaml index 5ea3a1b9a..b83afe172 100644 --- a/dist/chart/templates/gateway-plugin/httproute.yaml +++ b/dist/chart/templates/gateway-plugin/httproute.yaml @@ -33,7 +33,7 @@ spec: apiVersion: gateway.networking.k8s.io/v1 kind: HTTPRoute metadata: - name: aibrix-reserved-router-models-endpoint + name: aibrix-reserved-router-metadata-endpoint namespace: {{ .Release.Namespace }} labels: {{- include "chart.labels" . 
| nindent 4 }} @@ -46,23 +46,6 @@ spec: - path: type: PathPrefix value: /v1/models - backendRefs: - - name: aibrix-metadata-service - port: 8090 ---- -apiVersion: gateway.networking.k8s.io/v1 -kind: HTTPRoute -metadata: - name: aibrix-reserved-router-extensions-endpoint - namespace: {{ .Release.Namespace }} - labels: - {{- include "chart.labels" . | nindent 4 }} - app.kubernetes.io/component: aibrix-api-extension -spec: - parentRefs: - - name: aibrix-eg - rules: - - matches: - path: type: PathPrefix value: /v1/files @@ -70,5 +53,5 @@ spec: type: PathPrefix value: /v1/batches backendRefs: - - name: aibrix-api-extension - port: 8100 + - name: aibrix-metadata-service + port: 8090 diff --git a/dist/chart/templates/metadata-service/deployment.yaml b/dist/chart/templates/metadata-service/deployment.yaml index 7c5e8c25b..0c1f8c0e2 100644 --- a/dist/chart/templates/metadata-service/deployment.yaml +++ b/dist/chart/templates/metadata-service/deployment.yaml @@ -31,8 +31,6 @@ spec: - name: metadata-service image: {{ .Values.metadata.service.container.image.repository }}:{{ .Values.metadata.service.container.image.tag }} imagePullPolicy: {{ .Values.metadata.service.container.image.imagePullPolicy | default "IfNotPresent" }} - command: ["python", "-m", "aibrix.metadata.app"] - args: ["--host=0.0.0.0", "--port=8090"] ports: - containerPort: 8090 resources: {{ toYaml .Values.metadata.service.container.resources | nindent 12 }} diff --git a/dist/chart/templates/metadata-service/rbac.yaml b/dist/chart/templates/metadata-service/rbac.yaml index 3fa004362..ea6f33283 100644 --- a/dist/chart/templates/metadata-service/rbac.yaml +++ b/dist/chart/templates/metadata-service/rbac.yaml @@ -24,6 +24,30 @@ rules: - apiGroups: ["model.aibrix.ai"] resources: ["modeladapters"] verbs: ["get", "list"] + # For batch job watching + - apiGroups: ["batch"] + resources: ["jobs"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] + # For batch job ServiceAccount management + - apiGroups: [""] + resources: ["serviceaccounts"] + verbs: ["get", "create", "update", "patch", "delete"] + - apiGroups: ["rbac.authorization.k8s.io"] # for Role management + resources: ["roles"] + verbs: ["get", "create", "update", "patch", "delete"] + - apiGroups: ["rbac.authorization.k8s.io"] # for RoleBinding management + resources: ["rolebindings"] + verbs: ["get", "create", "update", "patch", "delete"] + # For kopf high availability + - apiGroups: ["coordination.k8s.io"] + resources: ["leases"] + verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] + - apiGroups: ["apiextensions.k8s.io"] # required by kopf + resources: ["customresourcedefinitions"] + verbs: ["get", "list", "watch"] + - apiGroups: [""] # required by kopf + resources: ["namespaces"] + verbs: ["list", "watch"] --- apiVersion: rbac.authorization.k8s.io/v1 diff --git a/dist/chart/values.yaml b/dist/chart/values.yaml index dcfe32bb4..1f4dc835d 100644 --- a/dist/chart/values.yaml +++ b/dist/chart/values.yaml @@ -168,26 +168,6 @@ metadata: enablePassword: false password: "" -apiExtension: - replicaCount: 1 - imagePullSecrets: [] - container: - image: - repository: aibrix/runtime - tag: nightly - resources: - limits: - cpu: "1" - memory: 1Gi - requests: - cpu: "1" - memory: 1Gi - dependencies: - redis: - host: aibrix-redis-master - port: 6379 - tolerations: [] - # [CRDs]: To enable the CRDs crd: # This option determines whether the CRDs are included diff --git a/docs/source/features/batch-api.rst b/docs/source/features/batch-api.rst index 
24cf17282..ce7f9a1ed 100644 --- a/docs/source/features/batch-api.rst +++ b/docs/source/features/batch-api.rst @@ -358,24 +358,206 @@ Using the OpenAI Python SDK (works with AIBrix as a drop-in replacement): else: print(f"Batch failed with status: {batch.status}") -Configuration +Customization ------------- -Metadata Server Configuration -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Customizing Job Executor +^^^^^^^^^^^^^^^^^^^^^^^^^ -The metadata server requires configuration for batch processing: +You can customize the batch job execution environment by modifying the job template patch configuration. This allows you to specify custom container images, resource requirements, and other Kubernetes Job specifications. + +**Job Template Patch Configuration** + +The job executor behavior is controlled by the ``config/metadata/job_template_patch.yaml`` file. This file defines the Kubernetes Job template that will be used for batch processing: + +.. code-block:: yaml + + apiVersion: batch/v1 + kind: Job + metadata: + name: batch-job-template + namespace: default + spec: + parallelism: 1 # Customizable. The number of parallel workers. + completions: 1 # Customizable. Must equal to the parallelism. + backoffLimit: 2 # Customizable, but usually no need to change. + template: + spec: + containers: + - name: batch-worker + image: aibrix/runtime:nightly # Customizable, runtime image + - name: llm-engine + image: aibrix/vllm-mock:nightly # Customizable, LLM engine image + +**Customization Options:** + +- **parallelism**: Number of parallel worker pods (affects throughput) +- **completions**: Must match parallelism for proper job completion +- **backoffLimit**: Number of retries for failed worker pods +- **batch-worker image**: Runtime container that coordinates batch processing +- **llm-engine image**: LLM inference engine container (e.g., vLLM, TensorRT-LLM) + +**Common Customizations:** + +1. **Use Custom LLM Engine:** + + .. code-block:: yaml + + containers: + - name: llm-engine + image: your-registry/custom-vllm:latest + +2. **Increase Parallelism:** + + .. code-block:: yaml + + spec: + parallelism: 4 + completions: 4 + +3. **Add Resource Requirements:** + + .. code-block:: yaml + + containers: + - name: llm-engine + image: aibrix/vllm-mock:nightly + resources: + requests: + nvidia.com/gpu: 1 + memory: "8Gi" + limits: + nvidia.com/gpu: 1 + memory: "16Gi" + +4. **Add Environment Variables:** + + .. code-block:: yaml + + containers: + - name: llm-engine + image: aibrix/vllm-mock:nightly + env: + - name: CUDA_VISIBLE_DEVICES + value: "0" + - name: MODEL_PATH + value: "/models/your-model" + +**Applying Changes:** + +After modifying ``job_template_patch.yaml``, apply the changes using: + +.. code-block:: bash + + kubectl apply -k config/default + +Verification and Testing +------------------------ + +Verifying Batch API Functionality +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Follow these steps to verify that the Batch API is working correctly in your AIBrix deployment: + +**Step 1: Set Up Port Forwarding** + +First, create a port-forward to access the AIBrix services: .. 
code-block:: bash - # Enable Batch API (note: set to false to enable) - --disable-batch-api=false + # Port-forward the gateway service to access AIBrix APIs + kubectl -n envoy-gateway-system port-forward service/envoy-aibrix-system-aibrix-eg-903790dc 8888:80 1>/dev/null 2>&1 & + + # Verify the port-forward is working + curl -s http://localhost:8888/v1/batches + +**Step 2: Set Up Object Store Credentials** + +Configure S3 credentials for batch file storage: + +.. code-block:: bash + + # Navigate to the Python package directory + cd python/aibrix + + # Install the AIBrix package in development mode + pip install -e . + + # Generate S3 credentials secret (replace with your S3 bucket) + aibrix_gen_secrets s3 --bucket your-s3-bucket-name + + # Example with specific bucket: + # aibrix_gen_secrets s3 --bucket my-aibrix-batch-storage + +This command will: + +- Create the necessary Kubernetes secrets for S3 access +- Configure the metadata service to use your S3 bucket for file storage +- Set up proper IAM credentials for batch job file operations + +**Step 3: Run End-to-End Tests** + +Execute the comprehensive batch API test suite: + +.. code-block:: bash + + # Navigate to the Python package directory (if not already there) + cd python/aibrix + + # Run the batch API end-to-end tests + pytest tests/e2e/test_batch_api.py -v + +**Expected Test Output:** + +.. code-block:: text + + tests/e2e/test_batch_api.py::test_batch_api_e2e_real_service PASSED + + ========================= 1 passed in 10.78s ========================= + +**Test Coverage:** + +The test suite verifies: + +- **File Upload/Download**: Files API functionality with S3 backend +- **Batch Job Creation**: Proper batch job submission and validation +- **Kubernetes Job Execution**: Worker pod creation and execution +- **Status Monitoring**: Real-time batch status tracking +- **Result Collection**: Output file generation and retrieval + +**Troubleshooting Common Issues:** + +1. **Port-forward Connection Issues:** + + .. code-block:: bash + + # Check if port-forward is running + ps aux | grep port-forward + + # Kill existing port-forwards and restart + pkill -f "port-forward.*8888" + kubectl -n envoy-gateway-system port-forward service/envoy-aibrix-system-aibrix-eg-903790dc 8888:80 & + +2. **S3 Credentials Issues:** + + .. code-block:: bash + + # Verify S3 secret was created + kubectl get secret aibrix-s3-credentials -n aibrix-system + + # Check secret contents + kubectl get secret aibrix-s3-credentials -n aibrix-system -o yaml + +3. **Test Failures:** + + .. code-block:: bash + + # Run tests with more verbose output + pytest tests/e2e/test_batch_api.py -v -s --tb=long - # Enable Kubernetes Job execution - --enable-k8s-job=true +**Manual Verification:** - # Optional: Specify custom job template patch - --k8s-job-patch=/path/to/job_patch.yaml +You can also manually verify the batch API using curl commands as shown in the Examples section above, using ``localhost:8888`` as your endpoint after setting up the port-forward. 
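For readers who prefer the SDK route over raw ``curl``, the following minimal sketch walks the same manual-verification flow with the OpenAI Python client pointed at the port-forwarded gateway from Step 1. The model name, request payload, and temporary input file are placeholders chosen for illustration; substitute a model that is actually served by your cluster.

.. code-block:: python

    # Illustrative manual check of the Batch API through the localhost:8888
    # port-forward. "your-model" is a placeholder model name.
    import json
    import tempfile
    import time

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8888/v1", api_key="unused")

    # Build a tiny JSONL input file with a single chat completion request.
    line = {
        "custom_id": "req-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "your-model",
            "messages": [{"role": "user", "content": "ping"}],
        },
    }
    with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
        f.write(json.dumps(line) + "\n")
        input_path = f.name

    # Upload the file and submit a batch against it.
    with open(input_path, "rb") as fh:
        uploaded = client.files.create(file=fh, purpose="batch")
    batch = client.batches.create(
        input_file_id=uploaded.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )
    print("submitted batch:", batch.id, "status:", batch.status)

    # Poll until the batch reaches a terminal state, then fetch the output file.
    while batch.status not in ("completed", "failed", "expired", "cancelled"):
        time.sleep(5)
        batch = client.batches.retrieve(batch.id)
    if batch.status == "completed" and batch.output_file_id:
        print(client.files.content(batch.output_file_id).text)

A successful run prints the batch ID, a ``completed`` status, and one JSONL result line per request, which mirrors what the pytest suite above asserts.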
API Reference diff --git a/python/aibrix/aibrix/metadata/app.py b/python/aibrix/aibrix/metadata/app.py index 951e5746a..50e6513b6 100644 --- a/python/aibrix/aibrix/metadata/app.py +++ b/python/aibrix/aibrix/metadata/app.py @@ -218,7 +218,7 @@ def nullable_str(val: str): def main(): parser = argparse.ArgumentParser(description=f"Run {settings.PROJECT_NAME}") parser.add_argument("--host", type=nullable_str, default=None, help="host name") - parser.add_argument("--port", type=int, default=8100, help="port number") + parser.add_argument("--port", type=int, default=8090, help="port number") parser.add_argument( "--enable-fastapi-docs", action="store_true", diff --git a/python/aibrix/aibrix/metadata/cache/job.py b/python/aibrix/aibrix/metadata/cache/job.py index f7860a34d..dbfd94f89 100644 --- a/python/aibrix/aibrix/metadata/cache/job.py +++ b/python/aibrix/aibrix/metadata/cache/job.py @@ -314,7 +314,7 @@ async def submit_job( session_id: str, job_spec: BatchJobSpec, job_name: Optional[str] = None, - parallelism: int = 1, + parallelism: Optional[int] = None, prepared_job: Optional[BatchJob] = None, ) -> None: """Submit job by creating a Kubernetes Job. @@ -322,6 +322,7 @@ async def submit_job( Args: job_spec: BatchJobSpec to submit to Kubernetes. job_name: Optional job name, will generate one if not provided. + parallelism: Optional parallelism for the job, default to None and follow template settings. prepared_job: Optional BatchJob with file IDs to add to pod annotations. Raises: @@ -762,7 +763,7 @@ def _batch_job_spec_to_k8s_job( session_id: str, job_spec: BatchJobSpec, job_name: Optional[str] = None, - parallelism: int = 1, + parallelism: Optional[int] = None, prepared_job: Optional[BatchJob] = None, ) -> Dict[str, Any]: """Convert BatchJobSpec to Kubernetes Job manifest using pre-loaded template. @@ -845,10 +846,11 @@ def _batch_job_spec_to_k8s_job( }, "activeDeadlineSeconds": job_spec.completion_window, "suspend": suspend, - "parallelism": parallelism, - "completions": parallelism, }, } + if parallelism is not None: + job_patch["spec"]["parallelism"] = parallelism + job_patch["spec"]["completions"] = parallelism # Use pre-loaded template (deep copy to avoid modifying the original) job_template = merge_yaml_object(self.job_template, job_patch) diff --git a/python/aibrix/aibrix/metadata/setting/k8s_job_template.yaml b/python/aibrix/aibrix/metadata/setting/k8s_job_template.yaml index 33ff63261..b4af36d93 100644 --- a/python/aibrix/aibrix/metadata/setting/k8s_job_template.yaml +++ b/python/aibrix/aibrix/metadata/setting/k8s_job_template.yaml @@ -8,7 +8,11 @@ metadata: annotations: # Template annotations will be merged with BatchJobSpec annotations spec: - suspend: true # !!Important: creates the job in a paused state + suspend: true # Non-customizable. !!Important: creates the job in a paused state, job controller will validate and start the job + parallelism: 1 # Customizable. The number of parallel workers. + completions: 1 # Customizable. Must equal to the parallelism. + backoffLimit: 2 # Customizable, but usually no need to change. + activeDeadlineSeconds: 86400 # Non-customizable. The deadline will be modified for each request according to template: metadata: labels: @@ -16,8 +20,8 @@ spec: spec: serviceAccountName: job-reader-sa automountServiceAccountToken: true - shareProcessNamespace: true # Allow worker to kill llm-engine - restartPolicy: Never + shareProcessNamespace: true # Non-customizable. Allow worker to kill llm-engine. + restartPolicy: Never # Non-customizable. 
Must be "Never." containers: - name: batch-worker image: aibrix/runtime:nightly @@ -85,5 +89,3 @@ spec: periodSeconds: 5 successThreshold: 1 timeoutSeconds: 1 - backoffLimit: 2 - activeDeadlineSeconds: 86400 # 24 hours \ No newline at end of file diff --git a/python/aibrix/pyproject.toml b/python/aibrix/pyproject.toml index 317e84de2..982db885e 100644 --- a/python/aibrix/pyproject.toml +++ b/python/aibrix/pyproject.toml @@ -47,7 +47,8 @@ aibrix_download = 'aibrix.downloader.__main__:main' aibrix_benchmark = "aibrix.gpu_optimizer.optimizer.profiling.benchmark:main" aibrix_gen_profile = 'aibrix.gpu_optimizer.optimizer.profiling.gen_profile:main' aibrix_batch_worker = 'aibrix.batch.worker:main' -aibrix_api_extension = 'aibrix.metadata.app:main' +aibrix_metadata = 'aibrix.metadata.app:main' +aibrix_gen_secrets = 'scripts.generate_secrets:main' [tool.poetry.dependencies] python = ">=3.10,<3.13" diff --git a/python/aibrix/scripts/__init__.py b/python/aibrix/scripts/__init__.py new file mode 100644 index 000000000..26c84ec4b --- /dev/null +++ b/python/aibrix/scripts/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/python/aibrix/scripts/generate_secrets.py b/python/aibrix/scripts/generate_secrets.py index 1e113ce63..a35a148dd 100644 --- a/python/aibrix/scripts/generate_secrets.py +++ b/python/aibrix/scripts/generate_secrets.py @@ -19,6 +19,10 @@ def create_s3_secret_cli(args): """Create an S3 secret from CLI arguments.""" try: + if not args.bucket: + print("❌ Error: Bucket name is required for S3 secret creation.") + sys.exit(1) + generator = SecretGenerator(namespace=args.namespace) secret_name = generator.create_s3_secret( @@ -27,8 +31,7 @@ def create_s3_secret_cli(args): print(f"✅ Successfully created S3 secret: {secret_name}") print(f" Namespace: {args.namespace}") - if args.bucket: - print(f" Bucket: {args.bucket}") + print(f" Bucket: {args.bucket}") except Exception as e: print(f"❌ Failed to create S3 secret: {e}") @@ -38,6 +41,10 @@ def create_s3_secret_cli(args): def create_tos_secret_cli(args): """Create a TOS secret from CLI arguments.""" try: + if not args.bucket: + print("❌ Error: Bucket name is required for TOS secret creation.") + sys.exit(1) + generator = SecretGenerator(namespace=args.namespace) secret_name = generator.create_tos_secret( @@ -46,8 +53,7 @@ def create_tos_secret_cli(args): print(f"✅ Successfully created TOS secret: {secret_name}") print(f" Namespace: {args.namespace}") - if args.bucket: - print(f" Bucket: {args.bucket}") + print(f" Bucket: {args.bucket}") except Exception as e: print(f"❌ Failed to create TOS secret: {e}") diff --git a/python/aibrix/tests/batch/conftest.py b/python/aibrix/tests/batch/conftest.py index 3fb69b5ca..73c08e2c1 100644 --- a/python/aibrix/tests/batch/conftest.py +++ b/python/aibrix/tests/batch/conftest.py @@ -415,7 +415,7 @@ def create_test_app( app = build_app( argparse.Namespace( host=None, - port=8100, + port=8090, 
enable_fastapi_docs=False, disable_batch_api=False, enable_k8s_job=enable_k8s_job, diff --git a/python/aibrix/tests/batch/testdata/job_rbac.yaml b/python/aibrix/tests/batch/testdata/job_rbac.yaml index 6b326e7b2..854bf0b41 100644 --- a/python/aibrix/tests/batch/testdata/job_rbac.yaml +++ b/python/aibrix/tests/batch/testdata/job_rbac.yaml @@ -3,14 +3,14 @@ # with enable_k8s_job=True --- -# Service Account for job reader (as referenced in config/api-extension/rbac.yaml) +# Test ServiceAccount for job reader apiVersion: v1 kind: ServiceAccount metadata: name: unittest-job-reader-sa namespace: default --- -# Role for job reader +# Test Role for job reader apiVersion: rbac.authorization.k8s.io/v1 kind: Role metadata: @@ -21,7 +21,7 @@ rules: resources: ["jobs"] verbs: ["get"] # Get permissions only --- -# RoleBinding for job reader +# Test RoleBinding for job reader apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding metadata: diff --git a/python/aibrix/tests/test_files_api.py b/python/aibrix/tests/test_files_api.py index c61a1cb3c..286be243d 100644 --- a/python/aibrix/tests/test_files_api.py +++ b/python/aibrix/tests/test_files_api.py @@ -26,7 +26,7 @@ def create_test_app(): return build_app( argparse.Namespace( host=None, - port=8100, + port=8090, enable_fastapi_docs=False, disable_batch_api=True, # Disable batch API to avoid async issues in tests disable_file_api=False, # Enable file API for testing From 91fa1f637cae851d680b655c92f2384bfc625484 Mon Sep 17 00:00:00 2001 From: Jingyuan Zhang Date: Wed, 15 Oct 2025 10:24:02 -0700 Subject: [PATCH 04/11] Lint fix Signed-off-by: Jingyuan Zhang --- python/aibrix/aibrix/metadata/app.py | 26 ---------------------- python/aibrix/aibrix/metadata/cache/job.py | 2 +- python/aibrix/scripts/__init__.py | 2 +- 3 files changed, 2 insertions(+), 28 deletions(-) diff --git a/python/aibrix/aibrix/metadata/app.py b/python/aibrix/aibrix/metadata/app.py index 50e6513b6..3e0edde69 100644 --- a/python/aibrix/aibrix/metadata/app.py +++ b/python/aibrix/aibrix/metadata/app.py @@ -82,32 +82,6 @@ async def status_check(request: Request): return JSONResponse(content=status, status_code=200) -@router.get("/status") -async def status_check(request: Request): - """Get detailed status of all components.""" - status: Dict[str, Any] = { - "httpx_client": { - "available": hasattr(request.app.state, "httpx_client_wrapper"), - "status": "initialized" - if hasattr(request.app.state, "httpx_client_wrapper") - else "not_initialized", - }, - "kopf_operator": { - "available": hasattr(request.app.state, "kopf_operator_wrapper"), - }, - "batch_driver": { - "available": hasattr(request.app.state, "batch_driver"), - }, - } - - # Get detailed kopf operator status if available - if hasattr(request.app.state, "kopf_operator_wrapper"): - kopf_status = request.app.state.kopf_operator_wrapper.get_status() - status["kopf_operator"].update(kopf_status) - - return JSONResponse(content=status, status_code=200) - - @asynccontextmanager async def lifespan(app: FastAPI): # Code executed on startup diff --git a/python/aibrix/aibrix/metadata/cache/job.py b/python/aibrix/aibrix/metadata/cache/job.py index dbfd94f89..1ec09fe06 100644 --- a/python/aibrix/aibrix/metadata/cache/job.py +++ b/python/aibrix/aibrix/metadata/cache/job.py @@ -830,7 +830,7 @@ def _batch_job_spec_to_k8s_job( else: suspend = False - job_patch = { + job_patch: Dict[str, Any] = { "metadata": { "name": job_name, # Minimal job-level annotations - most metadata moved to pod diff --git a/python/aibrix/scripts/__init__.py 
b/python/aibrix/scripts/__init__.py index 26c84ec4b..6461ec1ad 100644 --- a/python/aibrix/scripts/__init__.py +++ b/python/aibrix/scripts/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. From d328454b14aa07ba38ffdbf74c28d3948bc53e8e Mon Sep 17 00:00:00 2001 From: Jingyuan Zhang Date: Wed, 15 Oct 2025 14:21:35 -0700 Subject: [PATCH 05/11] Update document for storage configuration. Cleanup files Signed-off-by: Jingyuan Zhang --- config/metadata/kustomization.yaml | 6 ++ config/metadata/metadata.yaml | 54 +----------- config/metadata/s3-env-patch.yaml | 32 +++++++ config/metadata/tos-env-patch.yaml | 37 ++++++++ docs/source/features/batch-api.rst | 102 ++++++++++++++++++++++ python/aibrix/.python-version | 1 - python/aibrix/scripts/generate_secrets.py | 14 +-- python/aibrix/tests/e2e/README.md | 5 +- 8 files changed, 190 insertions(+), 61 deletions(-) create mode 100644 config/metadata/s3-env-patch.yaml create mode 100644 config/metadata/tos-env-patch.yaml delete mode 100644 python/aibrix/.python-version diff --git a/config/metadata/kustomization.yaml b/config/metadata/kustomization.yaml index 35e726cf7..a133581dd 100644 --- a/config/metadata/kustomization.yaml +++ b/config/metadata/kustomization.yaml @@ -8,6 +8,12 @@ configMapGenerator: files: - job_template_patch.yaml +patches: +# Uncomment the following lines to enable S3 as the object store +# - path: s3-env-patch.yaml +# Uncomment the following lines to enable TOS as the object store +# - path: tos-env-patch.yaml + labels: - pairs: app.kubernetes.io/component: aibrix-metadata-service \ No newline at end of file diff --git a/config/metadata/metadata.yaml b/config/metadata/metadata.yaml index 1c3d1a22a..dfbadfc99 100644 --- a/config/metadata/metadata.yaml +++ b/config/metadata/metadata.yaml @@ -37,13 +37,13 @@ rules: # For batch job ServiceAccount management - apiGroups: [""] resources: ["serviceaccounts"] - verbs: ["get", "create", "update", "patch", "delete"] + verbs: ["get", "create", "update", "patch"] - apiGroups: ["rbac.authorization.k8s.io"] # for Role management resources: ["roles"] - verbs: ["get", "create", "update", "patch", "delete"] + verbs: ["get", "create", "update", "patch"] - apiGroups: ["rbac.authorization.k8s.io"] # for RoleBinding management resources: ["rolebindings"] - verbs: ["get", "create", "update", "patch", "delete"] + verbs: ["get", "create", "update", "patch"] # For kopf high availability - apiGroups: ["coordination.k8s.io"] resources: ["leases"] @@ -129,54 +129,6 @@ spec: valueFrom: fieldRef: fieldPath: metadata.namespace - # Object store configuration - # Comment the following lines to disable S3 as the object store - - name: STORAGE_AWS_ACCESS_KEY_ID - valueFrom: - secretKeyRef: - name: aibrix-s3-credentials - key: access-key-id - - name: STORAGE_AWS_SECRET_ACCESS_KEY - valueFrom: - secretKeyRef: - name: aibrix-s3-credentials - key: secret-access-key - - name: STORAGE_AWS_REGION - valueFrom: - secretKeyRef: - name: aibrix-s3-credentials - key: region - - name: STORAGE_AWS_BUCKET - valueFrom: - secretKeyRef: - name: aibrix-s3-credentials - key: bucket-name - # Uncomment the following lines to enable TOS as the object store - # - name: STORAGE_TOS_ACCESS_KEY - # valueFrom: - # secretKeyRef: - # name: aibrix-tos-credentials - 
# key: access-key - # - name: STORAGE_TOS_SECRET_KEY - # valueFrom: - # secretKeyRef: - # name: aibrix-tos-credentials - # key: secret-key - # - name: STORAGE_TOS_ENDPOINT - # valueFrom: - # secretKeyRef: - # name: aibrix-tos-credentials - # key: endpoint - # - name: STORAGE_TOS_REGION - # valueFrom: - # secretKeyRef: - # name: aibrix-tos-credentials - # key: region - # - name: STORAGE_TOS_BUCKET - # valueFrom: - # secretKeyRef: - # name: aibrix-tos-credentials - # key: bucket-name livenessProbe: httpGet: path: /healthz diff --git a/config/metadata/s3-env-patch.yaml b/config/metadata/s3-env-patch.yaml new file mode 100644 index 000000000..4059c2275 --- /dev/null +++ b/config/metadata/s3-env-patch.yaml @@ -0,0 +1,32 @@ +# This patch contains the S3 object store configuration for the metadata service +apiVersion: apps/v1 +kind: Deployment +metadata: + name: metadata-service + namespace: aibrix-system +spec: + template: + spec: + containers: + - name: metadata-service + env: + - name: STORAGE_AWS_ACCESS_KEY_ID + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: access-key-id + - name: STORAGE_AWS_SECRET_ACCESS_KEY + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: secret-access-key + - name: STORAGE_AWS_REGION + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: region + - name: STORAGE_AWS_BUCKET + valueFrom: + secretKeyRef: + name: aibrix-s3-credentials + key: bucket-name \ No newline at end of file diff --git a/config/metadata/tos-env-patch.yaml b/config/metadata/tos-env-patch.yaml new file mode 100644 index 000000000..b3a6ad8ca --- /dev/null +++ b/config/metadata/tos-env-patch.yaml @@ -0,0 +1,37 @@ +# This patch contains the TOS object store configuration for the metadata service +apiVersion: apps/v1 +kind: Deployment +metadata: + name: metadata-service + namespace: aibrix-system +spec: + template: + spec: + containers: + - name: metadata-service + env: + - name: STORAGE_TOS_ACCESS_KEY + valueFrom: + secretKeyRef: + name: aibrix-tos-credentials + key: access-key + - name: STORAGE_TOS_SECRET_KEY + valueFrom: + secretKeyRef: + name: aibrix-tos-credentials + key: secret-key + - name: STORAGE_TOS_ENDPOINT + valueFrom: + secretKeyRef: + name: aibrix-tos-credentials + key: endpoint + - name: STORAGE_TOS_REGION + valueFrom: + secretKeyRef: + name: aibrix-tos-credentials + key: region + - name: STORAGE_TOS_BUCKET + valueFrom: + secretKeyRef: + name: aibrix-tos-credentials + key: bucket-name \ No newline at end of file diff --git a/docs/source/features/batch-api.rst b/docs/source/features/batch-api.rst index ce7f9a1ed..b3a2a87e5 100644 --- a/docs/source/features/batch-api.rst +++ b/docs/source/features/batch-api.rst @@ -145,6 +145,108 @@ Components 4. **Storage Backend**: S3, Redis, or local filesystem for file storage and job state 5. **Files API**: OpenAI-compatible file upload/download endpoints +Deployment +---------- + +Storage Backend Configuration +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The Batch API requires a storage backend for file operations. AIBrix supports multiple storage backends including S3, TOS, and local storage. To enable cloud object storage, you need to configure credentials and enable the appropriate storage patches. + +**Enabling S3 Storage** + +To enable S3 as the storage backend for batch operations: + +1. **Generate S3 Credentials Secret:** + + Use the AIBrix secret generation tool to create the necessary Kubernetes secrets: + + .. 
code-block:: bash + + # Install the AIBrix package in development mode + cd python/aibrix && pip install -e . + + # Generate S3 credentials secret + aibrix_gen_secrets s3 --bucket your-s3-bucket-name --namespace aibrix-system + + # Generate S3 credentials secret for Job Executor + aibrix_gen_secrets s3 --bucket your-s3-bucket-name --namespace default + + This command will: + + - Create a Kubernetes secret named ``aibrix-s3-credentials`` in the ``aibrix-system`` namespace + - Configure the secret with your S3 bucket name and credentials + - Set up the necessary environment variables for the metadata service + +2. **Enable S3 Environment Variables:** + + Uncomment the S3 patch in the metadata service configuration: + + .. code-block:: bash + + # Edit the kustomization file + vim config/metadata/kustomization.yaml + + Find and uncomment the following line: + + .. code-block:: yaml + + patches: + - path: s3-env-patch.yaml # Uncomment this line + + The patch will inject the S3 environment variables into the metadata service deployment. + +3. **Apply the Configuration:** + + Deploy the updated configuration: + + .. code-block:: bash + + kubectl apply -k config/default + +**Enabling TOS Storage** + +For TOS (Tencent Object Storage), follow similar steps: + +1. **Generate TOS Credentials Secret:** + + .. code-block:: bash + + # Install the AIBrix package in development mode + cd python/aibrix && pip install -e . + + # Generate TOS credentials secret + aibrix_gen_secrets tos --bucket your-tos-bucket-name --namespace aibrix-system + + # Generate TOS credentials secret for Job Executor + aibrix_gen_secrets tos --bucket your-tos-bucket-name --namespace default + +2. **Enable TOS Environment Variables:** + + Uncomment the TOS patch in the metadata service configuration: + + .. code-block:: bash + + # Edit the kustomization file + vim config/metadata/kustomization.yaml + + Find and uncomment the following line: + + .. code-block:: yaml + + patches: + - path: tos-env-patch.yaml # Uncomment this line + + The patch will inject the TOS environment variables into the metadata service deployment. + +3. **Apply the Configuration:** + + Deploy the updated configuration: + + .. 
code-block:: bash + + kubectl apply -k config/default + Examples -------- diff --git a/python/aibrix/.python-version b/python/aibrix/.python-version deleted file mode 100644 index 2c0733315..000000000 --- a/python/aibrix/.python-version +++ /dev/null @@ -1 +0,0 @@ -3.11 diff --git a/python/aibrix/scripts/generate_secrets.py b/python/aibrix/scripts/generate_secrets.py index a35a148dd..ef1a39088 100644 --- a/python/aibrix/scripts/generate_secrets.py +++ b/python/aibrix/scripts/generate_secrets.py @@ -111,23 +111,23 @@ def main(): epilog=""" Examples: # Create S3 secret with default name - python generate_secrets.py s3 --bucket my-bucket + python -m scripts.generate_secrets s3 --bucket my-bucket # Create S3 secret with custom name - python generate_secrets.py s3 --bucket my-bucket --name my-s3-creds + python -m scripts.generate_secrets s3 --bucket my-bucket --name my-s3-creds # Create TOS secret (requires TOS_* environment variables) - python generate_secrets.py tos --bucket my-tos-bucket + python -m scripts.generate_secrets tos --bucket my-tos-bucket # Delete a secret - python generate_secrets.py delete my-secret-name + python -m scripts.generate_secrets delete my-secret-name # List all secrets in namespace - python generate_secrets.py list + python -m scripts.generate_secrets list # Use custom namespace (either position works) - python generate_secrets.py --namespace my-namespace s3 --bucket my-bucket - python generate_secrets.py s3 --bucket my-bucket --namespace my-namespace + python -m scripts.generate_secrets --namespace my-namespace s3 --bucket my-bucket + python -m scripts.generate_secrets s3 --bucket my-bucket --namespace my-namespace """, ) diff --git a/python/aibrix/tests/e2e/README.md b/python/aibrix/tests/e2e/README.md index 8ebba1fdc..ccbf76647 100644 --- a/python/aibrix/tests/e2e/README.md +++ b/python/aibrix/tests/e2e/README.md @@ -16,7 +16,8 @@ This directory contains end-to-end tests for Aibrix services that run against re 2. **Generate Credentials**: Ensure object store is acceesible. Using S3 as an example: ```bash - python ../../scripts/generate_secrets.py s3 --bucket + cd /path/to/aibrix/python/aibrix + python -m scripts.generate_secrets s3 --bucket ``` The script will read s3 credentials setup using ```aws configure``` @@ -68,7 +69,7 @@ pytest tests/e2e/test_batch_api.py -v -s ## API Endpoints ### Health Endpoints -- `/healthz` - General service health check +- `/v1/batches` - General service availability check by list all batches. 
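The availability check described above (listing batches via ``/v1/batches``) can be scripted in a few lines. The sketch below is a minimal example; the base URL is a placeholder and should point at the metadata service port (8090) or the port-forwarded gateway address used in your environment.

```python
# Minimal availability probe: the service is considered up if listing
# batches on /v1/batches returns HTTP 200. The base URL is a placeholder.
import httpx


def service_available(base_url: str = "http://localhost:8090") -> bool:
    try:
        resp = httpx.get(f"{base_url}/v1/batches", timeout=10.0)
        return resp.status_code == 200
    except httpx.HTTPError:
        return False


if __name__ == "__main__":
    print("service available:", service_available())
```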
## Configuration From c35aa2bae266abeed14df1174d4451ff4092a702 Mon Sep 17 00:00:00 2001 From: Jingyuan Zhang Date: Wed, 15 Oct 2025 14:25:29 -0700 Subject: [PATCH 06/11] Explicit expose commands Signed-off-by: Jingyuan Zhang --- dist/chart/templates/metadata-service/deployment.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dist/chart/templates/metadata-service/deployment.yaml b/dist/chart/templates/metadata-service/deployment.yaml index 0c1f8c0e2..2d9b3dcd6 100644 --- a/dist/chart/templates/metadata-service/deployment.yaml +++ b/dist/chart/templates/metadata-service/deployment.yaml @@ -31,6 +31,11 @@ spec: - name: metadata-service image: {{ .Values.metadata.service.container.image.repository }}:{{ .Values.metadata.service.container.image.tag }} imagePullPolicy: {{ .Values.metadata.service.container.image.imagePullPolicy | default "IfNotPresent" }} + command: + - 'aibrix_metadata' + - '--host' + - '0.0.0.0' + - '--enable-k8s-job' ports: - containerPort: 8090 resources: {{ toYaml .Values.metadata.service.container.resources | nindent 12 }} From 36b01d26923bdd89d5c1558a009dccaea90e44d0 Mon Sep 17 00:00:00 2001 From: Jingyuan Zhang Date: Wed, 15 Oct 2025 14:49:01 -0700 Subject: [PATCH 07/11] Fix comand line in helm chart. Signed-off-by: Jingyuan Zhang --- dist/chart/templates/metadata-service/deployment.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dist/chart/templates/metadata-service/deployment.yaml b/dist/chart/templates/metadata-service/deployment.yaml index 2d9b3dcd6..e32d5a449 100644 --- a/dist/chart/templates/metadata-service/deployment.yaml +++ b/dist/chart/templates/metadata-service/deployment.yaml @@ -32,10 +32,10 @@ spec: image: {{ .Values.metadata.service.container.image.repository }}:{{ .Values.metadata.service.container.image.tag }} imagePullPolicy: {{ .Values.metadata.service.container.image.imagePullPolicy | default "IfNotPresent" }} command: - - 'aibrix_metadata' - - '--host' - - '0.0.0.0' - - '--enable-k8s-job' + - aibrix_metadata + - --host + - "0.0.0.0" + - --enable-k8s-job ports: - containerPort: 8090 resources: {{ toYaml .Values.metadata.service.container.resources | nindent 12 }} From c59b1cb4b3b0a75b7acb208eadb125b05fb53b0b Mon Sep 17 00:00:00 2001 From: Jingyuan Zhang Date: Wed, 15 Oct 2025 15:07:44 -0700 Subject: [PATCH 08/11] use full command for aibrix_metadata Signed-off-by: Jingyuan Zhang --- dist/chart/templates/metadata-service/deployment.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dist/chart/templates/metadata-service/deployment.yaml b/dist/chart/templates/metadata-service/deployment.yaml index e32d5a449..596804545 100644 --- a/dist/chart/templates/metadata-service/deployment.yaml +++ b/dist/chart/templates/metadata-service/deployment.yaml @@ -32,7 +32,9 @@ spec: image: {{ .Values.metadata.service.container.image.repository }}:{{ .Values.metadata.service.container.image.tag }} imagePullPolicy: {{ .Values.metadata.service.container.image.imagePullPolicy | default "IfNotPresent" }} command: - - aibrix_metadata + - python + - -m + - aibrix.metadata.app - --host - "0.0.0.0" - --enable-k8s-job From d5d247c358ea0275ffaa3693dc17667f8f8493c0 Mon Sep 17 00:00:00 2001 From: Jingyuan Zhang Date: Wed, 15 Oct 2025 22:48:26 -0700 Subject: [PATCH 09/11] Disable k8s-job in metadata service by default. For it requires object store to work. 
Signed-off-by: Jingyuan Zhang --- build/container/Dockerfile.metadata | 2 +- config/metadata/job_template_patch.yaml | 2 +- config/metadata/metadata.yaml | 8 +------- config/metadata/s3-env-patch.yaml | 8 ++++++++ config/metadata/tos-env-patch.yaml | 8 ++++++++ dist/chart/templates/metadata-service/deployment.yaml | 1 - python/aibrix/tests/e2e/test_batch_api.py | 11 +++-------- 7 files changed, 22 insertions(+), 18 deletions(-) diff --git a/build/container/Dockerfile.metadata b/build/container/Dockerfile.metadata index a6dfcf64d..29696634d 100644 --- a/build/container/Dockerfile.metadata +++ b/build/container/Dockerfile.metadata @@ -46,5 +46,5 @@ RUN apt-get update \ EXPOSE 8090 # Set entrypoint for Metadata service -ENTRYPOINT ["aibrix_metadata", "--enable-k8s-job", "--host", "0.0.0.0"] +ENTRYPOINT ["aibrix_metadata", "--host", "0.0.0.0"] diff --git a/config/metadata/job_template_patch.yaml b/config/metadata/job_template_patch.yaml index b1418687b..7e34b726b 100644 --- a/config/metadata/job_template_patch.yaml +++ b/config/metadata/job_template_patch.yaml @@ -13,4 +13,4 @@ spec: - name: batch-worker image: aibrix/runtime:nightly # Customizable, runtime image - name: llm-engine - image: aibrix/vllm-mock:nightly # Customizable, LLM engine image \ No newline at end of file + image: aibrix/vllm-mock:nightly # Customizable, customize your LLM engine and readinessProbe \ No newline at end of file diff --git a/config/metadata/metadata.yaml b/config/metadata/metadata.yaml index dfbadfc99..cf533f06c 100644 --- a/config/metadata/metadata.yaml +++ b/config/metadata/metadata.yaml @@ -96,13 +96,7 @@ spec: - name: metadata-service image: metadata-service:latest imagePullPolicy: IfNotPresent - command: - - aibrix_metadata - - --host - - "0.0.0.0" - - --enable-k8s-job - - --k8s-job-patch - - /app/config/job_template_patch.yaml + # Enable S3 or TOS to enable-k8s-job, default disabled. ports: - containerPort: 8090 volumeMounts: diff --git a/config/metadata/s3-env-patch.yaml b/config/metadata/s3-env-patch.yaml index 4059c2275..379d90bac 100644 --- a/config/metadata/s3-env-patch.yaml +++ b/config/metadata/s3-env-patch.yaml @@ -1,4 +1,5 @@ # This patch contains the S3 object store configuration for the metadata service +# !!Important: Please make sure aibrix-s3-credentials secret exists in both aibrix-system and default namespaces. apiVersion: apps/v1 kind: Deployment metadata: @@ -9,6 +10,13 @@ spec: spec: containers: - name: metadata-service + command: + - aibrix_metadata + - --host + - "0.0.0.0" + - --enable-k8s-job + - --k8s-job-patch + - /app/config/job_template_patch.yaml env: - name: STORAGE_AWS_ACCESS_KEY_ID valueFrom: diff --git a/config/metadata/tos-env-patch.yaml b/config/metadata/tos-env-patch.yaml index b3a6ad8ca..dd735923a 100644 --- a/config/metadata/tos-env-patch.yaml +++ b/config/metadata/tos-env-patch.yaml @@ -1,4 +1,5 @@ # This patch contains the TOS object store configuration for the metadata service +# !!Important: Please make sure aibrix-tos-credentials secret exists in both aibrix-system and default namespaces. 
apiVersion: apps/v1 kind: Deployment metadata: @@ -9,6 +10,13 @@ spec: spec: containers: - name: metadata-service + command: + - aibrix_metadata + - --host + - "0.0.0.0" + - --enable-k8s-job + - --k8s-job-patch + - /app/config/job_template_patch.yaml env: - name: STORAGE_TOS_ACCESS_KEY valueFrom: diff --git a/dist/chart/templates/metadata-service/deployment.yaml b/dist/chart/templates/metadata-service/deployment.yaml index 596804545..93ecc6cd4 100644 --- a/dist/chart/templates/metadata-service/deployment.yaml +++ b/dist/chart/templates/metadata-service/deployment.yaml @@ -37,7 +37,6 @@ spec: - aibrix.metadata.app - --host - "0.0.0.0" - - --enable-k8s-job ports: - containerPort: 8090 resources: {{ toYaml .Values.metadata.service.container.resources | nindent 12 }} diff --git a/python/aibrix/tests/e2e/test_batch_api.py b/python/aibrix/tests/e2e/test_batch_api.py index eabfe8c00..9eeb9ab67 100644 --- a/python/aibrix/tests/e2e/test_batch_api.py +++ b/python/aibrix/tests/e2e/test_batch_api.py @@ -133,7 +133,8 @@ async def check_service_health(base_url: str) -> bool: async with httpx.AsyncClient(timeout=10.0) as client: # Check general health endpoint health_response = await client.get(f"{base_url}/v1/batches") - return health_response.status_code == 200 + assert health_response.status_code == 200, f"Health check response: {health_response}" + return True except Exception as e: print(f"Health check failed: {e}") return False @@ -147,13 +148,7 @@ def service_health(): print(f"🔍 Checking service health at {base_url}...") # Run the async health check in a sync context - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - is_healthy = loop.run_until_complete(check_service_health(base_url)) - finally: - loop.close() - + is_healthy = asyncio.run(check_service_health(base_url)) if not is_healthy: pytest.skip(f"Service at {base_url} is not available or healthy") From 9711d674b6e5728052c0fa1f775ee1e4b627e274 Mon Sep 17 00:00:00 2001 From: Jingyuan Zhang Date: Wed, 15 Oct 2025 23:07:53 -0700 Subject: [PATCH 10/11] Lint fix Signed-off-by: Jingyuan Zhang --- python/aibrix/tests/e2e/test_batch_api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/aibrix/tests/e2e/test_batch_api.py b/python/aibrix/tests/e2e/test_batch_api.py index 9eeb9ab67..5d5bbbd48 100644 --- a/python/aibrix/tests/e2e/test_batch_api.py +++ b/python/aibrix/tests/e2e/test_batch_api.py @@ -133,7 +133,9 @@ async def check_service_health(base_url: str) -> bool: async with httpx.AsyncClient(timeout=10.0) as client: # Check general health endpoint health_response = await client.get(f"{base_url}/v1/batches") - assert health_response.status_code == 200, f"Health check response: {health_response}" + assert ( + health_response.status_code == 200 + ), f"Health check response: {health_response}" return True except Exception as e: print(f"Health check failed: {e}") From 23bfb48d138d4488346ffe16890b9a669603b4e1 Mon Sep 17 00:00:00 2001 From: Jingyuan Zhang Date: Thu, 16 Oct 2025 17:07:46 -0700 Subject: [PATCH 11/11] Metadata will not create serviceaccount for job anymore, and new config/job is create for user to deploy job rbac. 
Signed-off-by: Jingyuan Zhang --- .../setting => config/job}/k8s_job_rbac.yaml | 0 config/job/kustomization.yaml | 4 + config/metadata/job_template_patch.yaml | 12 +- config/metadata/metadata.yaml | 10 -- .../templates/metadata-service/rbac.yaml | 10 -- docs/source/features/batch-api.rst | 90 +++++++------- python/aibrix/aibrix/metadata/cache/job.py | 111 ------------------ 7 files changed, 60 insertions(+), 177 deletions(-) rename {python/aibrix/aibrix/metadata/setting => config/job}/k8s_job_rbac.yaml (100%) create mode 100644 config/job/kustomization.yaml diff --git a/python/aibrix/aibrix/metadata/setting/k8s_job_rbac.yaml b/config/job/k8s_job_rbac.yaml similarity index 100% rename from python/aibrix/aibrix/metadata/setting/k8s_job_rbac.yaml rename to config/job/k8s_job_rbac.yaml diff --git a/config/job/kustomization.yaml b/config/job/kustomization.yaml new file mode 100644 index 000000000..cf3b5d382 --- /dev/null +++ b/config/job/kustomization.yaml @@ -0,0 +1,4 @@ +kind: Kustomization + +resources: +- k8s_job_rbac.yaml \ No newline at end of file diff --git a/config/metadata/job_template_patch.yaml b/config/metadata/job_template_patch.yaml index 7e34b726b..a454e163a 100644 --- a/config/metadata/job_template_patch.yaml +++ b/config/metadata/job_template_patch.yaml @@ -11,6 +11,14 @@ spec: spec: containers: - name: batch-worker - image: aibrix/runtime:nightly # Customizable, runtime image + image: aibrix/runtime:nightly # Customizable, batch job worker image - name: llm-engine - image: aibrix/vllm-mock:nightly # Customizable, customize your LLM engine and readinessProbe \ No newline at end of file + image: aibrix/vllm-mock:nightly # Customizable, customize your LLM engine image + # command: ["/bin/sh", "-c"] # Customization is not recommended. Know what you are doing. + args: # Customizable in the format of "WORKER_VICTIM=1 [your command] || true" + - | + # Run llm engine. + # 'WORKER_VICTIM=1' helps the batch-worker to identify llm-engine process. + # '|| true' at the end ensures the container llm-engine never fails. 
+ WORKER_VICTIM=1 python app.py || true + readinessProbe: # Customizable, customize your readinessProbe \ No newline at end of file diff --git a/config/metadata/metadata.yaml b/config/metadata/metadata.yaml index cf533f06c..968445fb6 100644 --- a/config/metadata/metadata.yaml +++ b/config/metadata/metadata.yaml @@ -34,16 +34,6 @@ rules: - apiGroups: ["batch"] resources: ["jobs"] verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] - # For batch job ServiceAccount management - - apiGroups: [""] - resources: ["serviceaccounts"] - verbs: ["get", "create", "update", "patch"] - - apiGroups: ["rbac.authorization.k8s.io"] # for Role management - resources: ["roles"] - verbs: ["get", "create", "update", "patch"] - - apiGroups: ["rbac.authorization.k8s.io"] # for RoleBinding management - resources: ["rolebindings"] - verbs: ["get", "create", "update", "patch"] # For kopf high availability - apiGroups: ["coordination.k8s.io"] resources: ["leases"] diff --git a/dist/chart/templates/metadata-service/rbac.yaml b/dist/chart/templates/metadata-service/rbac.yaml index ea6f33283..705b64cc7 100644 --- a/dist/chart/templates/metadata-service/rbac.yaml +++ b/dist/chart/templates/metadata-service/rbac.yaml @@ -28,16 +28,6 @@ rules: - apiGroups: ["batch"] resources: ["jobs"] verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] - # For batch job ServiceAccount management - - apiGroups: [""] - resources: ["serviceaccounts"] - verbs: ["get", "create", "update", "patch", "delete"] - - apiGroups: ["rbac.authorization.k8s.io"] # for Role management - resources: ["roles"] - verbs: ["get", "create", "update", "patch", "delete"] - - apiGroups: ["rbac.authorization.k8s.io"] # for RoleBinding management - resources: ["rolebindings"] - verbs: ["get", "create", "update", "patch", "delete"] # For kopf high availability - apiGroups: ["coordination.k8s.io"] resources: ["leases"] diff --git a/docs/source/features/batch-api.rst b/docs/source/features/batch-api.rst index b3a2a87e5..c280f2fad 100644 --- a/docs/source/features/batch-api.rst +++ b/docs/source/features/batch-api.rst @@ -159,50 +159,51 @@ To enable S3 as the storage backend for batch operations: 1. **Generate S3 Credentials Secret:** - Use the AIBrix secret generation tool to create the necessary Kubernetes secrets: +Use the AIBrix secret generation tool to create the necessary Kubernetes secrets: - .. code-block:: bash +.. code-block:: bash - # Install the AIBrix package in development mode - cd python/aibrix && pip install -e . + # Install the AIBrix package in development mode + cd python/aibrix && pip install -e . 
- # Generate S3 credentials secret - aibrix_gen_secrets s3 --bucket your-s3-bucket-name --namespace aibrix-system + # Generate S3 credentials secret + aibrix_gen_secrets s3 --bucket your-s3-bucket-name --namespace aibrix-system - # Generate S3 credentials secret for Job Executor - aibrix_gen_secrets s3 --bucket your-s3-bucket-name --namespace default + # Generate S3 credentials secret for Job Executor + aibrix_gen_secrets s3 --bucket your-s3-bucket-name --namespace default - This command will: - - - Create a Kubernetes secret named ``aibrix-s3-credentials`` in the ``aibrix-system`` namespace - - Configure the secret with your S3 bucket name and credentials - - Set up the necessary environment variables for the metadata service +This command will: + +- Create a Kubernetes secret named ``aibrix-s3-credentials`` in the ``aibrix-system`` namespace +- Configure the secret with your S3 bucket name and credentials +- Set up the necessary environment variables for the metadata service 2. **Enable S3 Environment Variables:** - Uncomment the S3 patch in the metadata service configuration: +Uncomment the S3 patch in the metadata service configuration: - .. code-block:: bash +.. code-block:: bash - # Edit the kustomization file - vim config/metadata/kustomization.yaml + # Edit the kustomization file + vim config/metadata/kustomization.yaml - Find and uncomment the following line: +Find and uncomment the following line: - .. code-block:: yaml +.. code-block:: yaml - patches: - - path: s3-env-patch.yaml # Uncomment this line + patches: + - path: s3-env-patch.yaml # Uncomment this line - The patch will inject the S3 environment variables into the metadata service deployment. +The patch will inject the S3 environment variables into the metadata service deployment. 3. **Apply the Configuration:** - Deploy the updated configuration: +Deploy the job rbac andupdated configuration: - .. code-block:: bash +.. code-block:: bash - kubectl apply -k config/default + kubectl apply -k config/job + kubectl apply -k config/default **Enabling TOS Storage** @@ -210,42 +211,43 @@ For TOS (Tencent Object Storage), follow similar steps: 1. **Generate TOS Credentials Secret:** - .. code-block:: bash +.. code-block:: bash - # Install the AIBrix package in development mode - cd python/aibrix && pip install -e . + # Install the AIBrix package in development mode + cd python/aibrix && pip install -e . - # Generate TOS credentials secret - aibrix_gen_secrets tos --bucket your-tos-bucket-name --namespace aibrix-system + # Generate TOS credentials secret + aibrix_gen_secrets tos --bucket your-tos-bucket-name --namespace aibrix-system - # Generate TOS credentials secret for Job Executor - aibrix_gen_secrets tos --bucket your-tos-bucket-name --namespace default + # Generate TOS credentials secret for Job Executor + aibrix_gen_secrets tos --bucket your-tos-bucket-name --namespace default 2. **Enable TOS Environment Variables:** - Uncomment the TOS patch in the metadata service configuration: +Uncomment the TOS patch in the metadata service configuration: - .. code-block:: bash +.. code-block:: bash - # Edit the kustomization file - vim config/metadata/kustomization.yaml + # Edit the kustomization file + vim config/metadata/kustomization.yaml - Find and uncomment the following line: +Find and uncomment the following line: - .. code-block:: yaml +.. 
code-block:: yaml - patches: - - path: tos-env-patch.yaml # Uncomment this line + patches: + - path: tos-env-patch.yaml # Uncomment this line - The patch will inject the TOS environment variables into the metadata service deployment. +The patch will inject the TOS environment variables into the metadata service deployment. 3. **Apply the Configuration:** - Deploy the updated configuration: +Deploy the job rbac and updated configuration: - .. code-block:: bash +.. code-block:: bash - kubectl apply -k config/default + kubectl apply -k config/job + kubectl apply -k config/default Examples -------- diff --git a/python/aibrix/aibrix/metadata/cache/job.py b/python/aibrix/aibrix/metadata/cache/job.py index 1ec09fe06..a874f2fc7 100644 --- a/python/aibrix/aibrix/metadata/cache/job.py +++ b/python/aibrix/aibrix/metadata/cache/job.py @@ -168,117 +168,6 @@ def __init__(self, template_patch_path: Optional[Path] = None) -> None: self.batch_v1_api = client.BatchV1Api() self.core_v1_api = client.CoreV1Api() - self.rbac_v1_api = client.RbacAuthorizationV1Api() - - # Apply RBAC resources for job execution - self._apply_job_rbac(template_dir) - - def _apply_job_rbac(self, template_dir: Path) -> None: - """Apply RBAC resources for job execution from k8s_job_rbac.yaml.""" - try: - rbac_path = template_dir / "k8s_job_rbac.yaml" - with open(rbac_path, "r") as f: - rbac_docs = list(yaml.safe_load_all(f)) - - for doc in rbac_docs: - if not doc: # Skip empty documents - continue - - kind = doc.get("kind") - metadata = doc.get("metadata", {}) - name = metadata.get("name") - namespace = metadata.get("namespace", "default") - - try: - if kind == "ServiceAccount": - # Try to create, if exists then update - try: - self.core_v1_api.create_namespaced_service_account( - namespace=namespace, body=doc - ) - logger.info( - f"Created ServiceAccount: {doc['metadata']['name']}" - ) - except ApiException as e: - if e.status == 409: # Already exists - self.core_v1_api.patch_namespaced_service_account( - name=doc["metadata"]["name"], - namespace=namespace, - body=doc, - ) - logger.info( - f"Updated ServiceAccount: {doc['metadata']['name']}" - ) - else: - raise - - elif kind == "Role": - try: - self.rbac_v1_api.create_namespaced_role( - namespace=namespace, body=doc - ) - logger.info(f"Created Role: {doc['metadata']['name']}") - except ApiException as e: - if e.status == 409: # Already exists - self.rbac_v1_api.patch_namespaced_role( - name=doc["metadata"]["name"], - namespace=namespace, - body=doc, - ) - logger.info(f"Updated Role: {doc['metadata']['name']}") - else: - raise - - elif kind == "RoleBinding": - try: - self.rbac_v1_api.create_namespaced_role_binding( - namespace=namespace, body=doc - ) - logger.info( - f"Created RoleBinding: {doc['metadata']['name']}" - ) - except ApiException as e: - if e.status == 409: # Already exists - self.rbac_v1_api.patch_namespaced_role_binding( - name=doc["metadata"]["name"], - namespace=namespace, - body=doc, - ) - logger.info( - f"Updated RoleBinding: {doc['metadata']['name']}" - ) - else: - raise - else: - logger.warning(f"Unsupported RBAC resource kind: {kind}") - - except ApiException as e: - logger.error( - f"Failed to apply {kind} {name}: {e.status} {e.reason}", - error=str(e), - kind=kind, - name=name, - namespace=namespace, - ) # type: ignore[call-arg] - # Don't raise here to allow other resources to be applied - - except FileNotFoundError: - logger.warning( - "RBAC template not found, skipping RBAC setup", - template_path=str(rbac_path), - ) # type: ignore[call-arg] - except 
yaml.YAMLError as e: - logger.error( - "Failed to parse RBAC template", - error=str(e), - template_path=str(rbac_path), - ) # type: ignore[call-arg] - except Exception as e: - logger.error( - "Unexpected error applying RBAC resources", - error=str(e), - template_path=str(rbac_path), - ) # type: ignore[call-arg] def is_scheduler_enabled(self) -> bool: """Check if JobEntityManager has own scheduler enabled."""
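For reference, the create-or-update idiom that the removed `_apply_job_rbac` helper relied on (create the resource, fall back to a patch on HTTP 409) looks roughly like the sketch below, reduced to the ServiceAccount case and written against the official `kubernetes` Python client. It is a simplified illustration only; with this patch the RBAC objects are instead applied declaratively with `kubectl apply -k config/job`, and the inline manifest here is a placeholder modeled on the `job-reader-sa` account referenced by the job template.

```python
# Simplified create-or-update ("apply") idiom for a ServiceAccount, mirroring
# the pattern used by the removed _apply_job_rbac helper. The manifest below
# is a placeholder; real deployments apply config/job/k8s_job_rbac.yaml.
from kubernetes import client, config
from kubernetes.client.rest import ApiException


def ensure_service_account(namespace: str, body: dict) -> None:
    core = client.CoreV1Api()
    try:
        core.create_namespaced_service_account(namespace=namespace, body=body)
    except ApiException as e:
        if e.status == 409:  # already exists -> fall back to a patch
            core.patch_namespaced_service_account(
                name=body["metadata"]["name"], namespace=namespace, body=body
            )
        else:
            raise


if __name__ == "__main__":
    config.load_kube_config()  # or config.load_incluster_config() inside a pod
    ensure_service_account(
        "default",
        {
            "apiVersion": "v1",
            "kind": "ServiceAccount",
            "metadata": {"name": "job-reader-sa", "namespace": "default"},
        },
    )
```

Moving this logic out of the metadata service is what allows the ClusterRole in this patch to drop its ServiceAccount, Role, and RoleBinding write permissions.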