Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 4 additions & 91 deletions label_studio/core/current_request.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from threading import local
from typing import Any

from django.core.signals import request_finished
from django.dispatch import receiver
Expand All @@ -8,104 +7,18 @@
_thread_locals = local()


class CurrentContext:
    """Thread-local store for the current request, user and related values.

    Values live in one of two per-thread dictionaries:

    * ``job_data`` - written when ``shared=True``; exposed via
      :meth:`get_job_data` so it can be handed to jobs spawned by this thread.
    * ``data`` - written when ``shared=False``.

    On lookup, ``job_data`` takes precedence over ``data``.
    """

    @classmethod
    def set(cls, key: str, value: Any, shared: bool = True) -> None:
        """Store ``value`` under ``key`` for the current thread.

        Args:
            key: Name of the context entry.
            value: Value to store.
            shared: When True the value goes into ``job_data``, otherwise
                into ``data``.
        """
        # Both dictionaries are created lazily on first write in a thread.
        if not hasattr(_thread_locals, 'data'):
            _thread_locals.data = {}
        if not hasattr(_thread_locals, 'job_data'):
            _thread_locals.job_data = {}

        target = _thread_locals.job_data if shared else _thread_locals.data
        target[key] = value

    @classmethod
    def get(cls, key: str, default=None):
        """Return the value stored under ``key``, or ``default`` if absent.

        ``job_data`` is consulted first and ``data`` second.
        """
        local_data = getattr(_thread_locals, 'data', {})
        job_data = getattr(_thread_locals, 'job_data', {})
        return job_data.get(key, local_data.get(key, default))

    @classmethod
    def set_request(cls, request):
        """Remember ``request`` for this thread and record its user, if any."""
        _thread_locals.request = request
        # The request may not carry a ``user`` attribute (e.g. when no
        # authentication layer has attached one yet); treat that as "no user"
        # instead of raising AttributeError.
        user = getattr(request, 'user', None)
        if user:
            cls.set_user(user)

    @classmethod
    def get_organization_id(cls):
        """Return the stored ``organization_id`` or None."""
        return cls.get('organization_id')

    @classmethod
    def set_organization_id(cls, organization_id: int):
        """Store ``organization_id`` as a shared context value."""
        cls.set('organization_id', organization_id)

    @classmethod
    def get_user(cls):
        """Return the stored user or None."""
        return cls.get('user')

    @classmethod
    def set_user(cls, user):
        """Store ``user`` and, when it has one, its active organization id."""
        cls.set('user', user)
        if getattr(user, 'active_organization_id', None):
            cls.set_organization_id(user.active_organization_id)

    @classmethod
    def set_fsm_disabled(cls, disabled: bool):
        """
        Temporarily disable/enable FSM for the current thread.

        This is useful for test cleanup and bulk operations where FSM state
        tracking is not needed and would cause performance issues.

        Args:
            disabled: True to disable FSM, False to enable it
        """
        cls.set('fsm_disabled', disabled)

    @classmethod
    def is_fsm_disabled(cls) -> bool:
        """
        Check if FSM is disabled for the current thread.

        Returns:
            True if FSM is disabled, False otherwise
        """
        return cls.get('fsm_disabled', False)

    @classmethod
    def get_job_data(cls) -> dict:
        """
        This data will be shared to jobs spawned by the current thread.
        """
        return getattr(_thread_locals, 'job_data', {})

    @classmethod
    def clear(cls) -> None:
        """Remove every context attribute held for the current thread."""
        for attr in ('data', 'job_data', 'request'):
            if hasattr(_thread_locals, attr):
                delattr(_thread_locals, attr)

    @classmethod
    def get_request(cls):
        """Return the request stored for this thread, or None."""
        return getattr(_thread_locals, 'request', None)


def get_current_request():
    """Return the request object stored for this thread, or None if unset."""
    # NOTE(review): this value is overwritten by the very next statement, so
    # only the direct thread-local lookup below determines what is returned.
    result = CurrentContext.get_request()
    # Read the request straight from thread-local storage; None when absent.
    result = getattr(_thread_locals, 'request', None)
    return result


class ThreadLocalMiddleware(CommonMiddleware):
    """Middleware that records the incoming request in thread-local storage."""

    def process_request(self, request):
        # Hand the request to CurrentContext for this thread.
        CurrentContext.set_request(request)
        # NOTE(review): CurrentContext.set_request already assigns
        # _thread_locals.request, so this repeats that assignment - confirm
        # whether both statements are intended.
        _thread_locals.request = request


@receiver(request_finished)
def clean_request(sender, **kwargs):
    """Drop per-thread request state when ``request_finished`` is sent."""
    CurrentContext.clear()
    # NOTE(review): CurrentContext.clear() already deletes
    # _thread_locals.request when present, so this guard appears redundant -
    # confirm whether both statements are intended.
    if hasattr(_thread_locals, 'request'):
        del _thread_locals.request
62 changes: 10 additions & 52 deletions label_studio/core/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
import sys
from datetime import timedelta
from functools import partial
from typing import Any

import django_rq
import redis
from core.current_request import CurrentContext
from django.conf import settings
from django_rq import get_connection
from rq.command import send_stop_job_command
Expand Down Expand Up @@ -82,40 +80,6 @@ def redis_connected():
return redis_healthcheck()


def _is_serializable(value: Any) -> bool:
"""Check if a value can be serialized for job context."""
return isinstance(value, (str, int, float, bool, list, dict, type(None)))


def _capture_context() -> dict:
    """
    Snapshot the current thread's context so it can be handed to a job.

    The result holds ``user_id`` (when a user is set), ``organization_id``
    (when one is set or can be taken from the user), and every serializable
    entry of the shared job data except the ``user`` and ``request`` keys.
    """
    captured = {}

    # User information
    current_user = CurrentContext.get_user()
    if current_user:
        captured['user_id'] = current_user.id

    # Prefer an explicitly stored organization id; otherwise fall back to the
    # user's active organization so the job still gets an organization_id.
    organization_id = CurrentContext.get_organization_id()
    if organization_id:
        captured['organization_id'] = organization_id
    elif current_user and getattr(current_user, 'active_organization_id', None):
        captured['organization_id'] = current_user.active_organization_id

    # Custom context values, skipping object-valued keys and anything that
    # is not serializable.
    excluded_keys = ('user', 'request')
    captured.update(
        {
            name: item
            for name, item in CurrentContext.get_job_data().items()
            if name not in excluded_keys and _is_serializable(item)
        }
    )

    return captured


def redis_get(key):
if not redis_healthcheck():
return
Expand Down Expand Up @@ -148,9 +112,7 @@ def redis_delete(key):

def start_job_async_or_sync(job, *args, in_seconds=0, **kwargs):
"""
Start job async with redis or sync if redis is not connected.
Automatically preserves context for async jobs and clears it after completion.

Start job async with redis or sync if redis is not connected
:param job: Job function
:param args: Function arguments
:param in_seconds: Job will be delayed for in_seconds
Expand All @@ -160,29 +122,28 @@ def start_job_async_or_sync(job, *args, in_seconds=0, **kwargs):

redis = redis_connected() and kwargs.get('redis', True)
queue_name = kwargs.get('queue_name', 'default')

if 'queue_name' in kwargs:
del kwargs['queue_name']
if 'redis' in kwargs:
del kwargs['redis']

job_timeout = None
if 'job_timeout' in kwargs:
job_timeout = kwargs['job_timeout']
del kwargs['job_timeout']

if redis:
# Async execution with Redis - wrap job for context management
# Auto-capture request_id from thread local and pass it via job meta
try:
context_data = _capture_context()
from label_studio.core.current_request import _thread_locals

if context_data:
request_id = getattr(_thread_locals, 'request_id', None)
if request_id:
# Store in job meta for worker access
meta = kwargs.get('meta', {})
# Store context data in job meta for worker access
meta.update(context_data)
meta['request_id'] = request_id
kwargs['meta'] = meta
except Exception:
logger.info(f'Failed to capture context for job {job.__name__} on queue {queue_name}')
# Fail silently if no request context
pass

try:
args_info = _truncate_args_for_logging(args, kwargs)
Expand All @@ -193,7 +154,6 @@ def start_job_async_or_sync(job, *args, in_seconds=0, **kwargs):
enqueue_method = queue.enqueue
if in_seconds > 0:
enqueue_method = partial(queue.enqueue_in, timedelta(seconds=in_seconds))

job = enqueue_method(
job,
*args,
Expand All @@ -204,10 +164,8 @@ def start_job_async_or_sync(job, *args, in_seconds=0, **kwargs):
return job
else:
on_failure = kwargs.pop('on_failure', None)

try:
result = job(*args, **kwargs)
return result
return job(*args, **kwargs)
except Exception:
exc_info = sys.exc_info()
if on_failure:
Expand Down
2 changes: 1 addition & 1 deletion label_studio/core/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@
'rest_framework.authtoken',
'rest_framework_simplejwt.token_blacklist',
'drf_generators',
'fsm', # MUST be before apps that register FSM transitions (projects, tasks)
'core',
'users',
'organizations',
Expand All @@ -233,6 +232,7 @@
'ml_model_providers',
'jwt_auth',
'session_policy',
'fsm',
]

MIDDLEWARE = [
Expand Down
6 changes: 0 additions & 6 deletions label_studio/core/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ def _assert_delete_and_restore_equal(self, drow, original):
original_dict.pop('_state')
original_created_at = original_dict.pop('created_at')
original_updated_at = original_dict.pop('updated_at')
# Pop _original_values - this is an internal FSM field that's recreated on __init__
# and shouldn't be compared
original_dict.pop('_original_values', None)
original.delete()

for deserialized_object in serializers.deserialize('json', json.dumps([drow.data])):
Expand All @@ -31,9 +28,6 @@ def _assert_delete_and_restore_equal(self, drow, original):
new_dict.pop('_state')
new_created_at = new_dict.pop('created_at')
new_updated_at = new_dict.pop('updated_at')
# Pop _original_values - this is an internal FSM field that's recreated on __init__
# and shouldn't be compared
new_dict.pop('_original_values', None)

assert new_dict == original_dict
# Datetime loses microsecond precision, so we can't compare them directly
Expand Down
4 changes: 1 addition & 3 deletions label_studio/data_manager/actions/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,7 @@ def delete_tasks_annotations(project, queryset, **kwargs):
drafts = drafts.filter(user=int(annotator_id))
project.summary.remove_created_drafts_and_labels(drafts)

# count before delete to return the number of deleted items, not including cascade deletions
count = annotations.count()
annotations.delete()
count, _ = annotations.delete()
drafts.delete() # since task-level annotation drafts will not have been deleted by CASCADE
emit_webhooks_for_instance(project.organization, project, WebhookAction.ANNOTATIONS_DELETED, annotations_ids)
request = kwargs['request']
Expand Down
2 changes: 1 addition & 1 deletion label_studio/fsm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class OrderStateChoices(models.TextChoices):
### 2. Create State Model

```python
from fsm.state_models import BaseState
from fsm.models import BaseState
from fsm.registry import register_state_model

@register_state_model('order')
Expand Down
Loading
Loading