Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions label_studio/core/current_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def set_user(cls, user):
if getattr(user, 'active_organization_id', None):
cls.set_organization_id(user.active_organization_id)

# PERFORMANCE: Cache FSM enabled state at request level when user is set
# This allows all downstream code to check a simple boolean property
# instead of repeatedly calling feature flag checks and possibly having to resolve the user, org and other related objects
cls._cache_fsm_enabled_state(user)

@classmethod
def set_fsm_disabled(cls, disabled: bool):
"""
Expand All @@ -72,6 +77,50 @@ def is_fsm_disabled(cls) -> bool:
"""
return cls.get('fsm_disabled', False)

@classmethod
def _cache_fsm_enabled_state(cls, user):
"""
Cache the FSM enabled state for this request/thread.

PERFORMANCE: This is called once when the user is first set (typically in middleware).
It checks the feature flag once and caches the result, so all downstream code
can check a simple boolean property instead of repeatedly calling feature flag checks.

This eliminates thousands of feature flag lookups per request.

Args:
user: The user to check FSM feature flag for
"""
try:
from core.feature_flags import flag_set

# Only import when needed to avoid circular imports

# Check feature flag once and cache the result
fsm_enabled = flag_set('fflag_feat_fit_568_finite_state_management', user=user) if user else False
cls.set('fsm_enabled_cached', fsm_enabled)
except Exception:
# If feature flag check fails, assume disabled to be safe
cls.set('fsm_enabled_cached', False)

@classmethod
def is_fsm_enabled(cls) -> bool:
    """
    Report whether FSM processing is active for the current request/thread.

    PERFORMANCE: No feature flag lookup happens here; the answer comes from the
    value stored under ``'fsm_enabled_cached'`` (written by
    ``_cache_fsm_enabled_state``).

    Returns:
        False when FSM was manually disabled via set_fsm_disabled() (tests, bulk
        operations); otherwise the cached feature flag state, with False as the
        default passed to ``cls.get`` for the case where nothing was cached.
    """
    # The manual override always wins over the cached feature flag state
    if not cls.is_fsm_disabled():
        return cls.get('fsm_enabled_cached', False)
    return False

@classmethod
def get_job_data(cls) -> dict:
"""
Expand Down
137 changes: 37 additions & 100 deletions label_studio/fsm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,58 +66,27 @@ def __init__(self, *args, **kwargs):
@classmethod
def from_db(cls, db, field_names, values):
"""
Override from_db to capture original values when loading from database.
Override from_db to store raw DB values for lazy capture.

Django calls this method instead of __init__ when loading models from the database.
We need to capture the original field values here for change detection.

PERFORMANCE: We store the raw field values here without processing them.
This avoids accessing any ForeignKey fields (which would trigger queries).
The values are kept in _original_values as-is and are only compared against the
current field values in save(), when change detection is actually needed.
"""
instance = super().from_db(db, field_names, values)
# Initialize as empty dict for safe access
instance._original_values = {}
# Capture original values immediately after loading from DB
# This ensures we have the baseline for change detection on the first save
instance._capture_original_values()
return instance

def _capture_original_values(self):
"""
Capture current field values for change detection.

This allows us to detect which fields changed during save operations,
which is crucial for determining appropriate FSM transitions.

For ForeignKey fields, we store the PK instead of the object to avoid
circular references and recursion issues.

Deferred fields (not yet loaded from DB) are skipped to prevent infinite
recursion when accessing them would trigger refresh_from_db().

This is called after each save to refresh the baseline for the next save.
"""
self._original_values = {}

# Get deferred fields to avoid triggering recursive database loads
# Deferred fields haven't been loaded yet, so they can't have changed
deferred_fields = self.get_deferred_fields()

for field in self._meta.fields:
# Skip deferred fields to prevent recursion via refresh_from_db()
if field.attname in deferred_fields:
continue
instance._original_values = dict(zip(field_names, values))

value = getattr(self, field.name, None)
# For ForeignKey fields, store PK to avoid circular references
if field.is_relation and field.many_to_one and value is not None:
self._original_values[field.name] = value.pk if hasattr(value, 'pk') else value
else:
self._original_values[field.name] = value
return instance

def __reduce_ex__(self, protocol):
"""
Override serialization to exclude internal FSM tracking fields.

Django's serialization uses pickle which calls __reduce_ex__.
We exclude _original_values since it's only needed for runtime
We exclude FSM tracking fields since they're only needed for runtime
change detection, not for serialization/restoration.
"""
# Get the default reduction
Expand Down Expand Up @@ -160,19 +129,16 @@ def _get_changed_fields(self) -> Dict[str, tuple]:
# Only check fields that were captured in _original_values
# Fields that were deferred during capture won't be in _original_values
# and should be considered unchanged
if field.name not in self._original_values:
if field.attname not in self._original_values:
continue
if field.is_relation and field.many_to_many:
continue

old_value = self._original_values[field.name]
new_value = getattr(self, field.name, None)
old_value = self._original_values[field.attname]
new_value = getattr(self, field.attname, None)

# For ForeignKey fields, old_value is stored as PK, so compare PK to PK
if field.is_relation and field.many_to_one:
new_pk = new_value.pk if new_value and hasattr(new_value, 'pk') else new_value
if old_value != new_pk:
changed[field.name] = (old_value, new_value)
elif old_value != new_value:
changed[field.name] = (old_value, new_value)
if old_value != new_value:
changed[field.attname] = (old_value, new_value)
return changed

def _determine_fsm_transitions(self, is_creating: bool = None, changed_fields: dict = None) -> list:
Expand Down Expand Up @@ -353,56 +319,26 @@ def _should_execute_fsm(self) -> bool:
Check if FSM processing should be executed.

Returns False if:
- Feature flag is disabled
- User context is unavailable (tests must set CurrentContext explicitly)
- Feature flag is disabled (cached at request level)
- Manually disabled via set_fsm_disabled() (for tests/bulk operations)
- Explicitly skipped via instance attribute

Returns:
True if FSM should execute, False otherwise

Note:
CurrentContext is available in web requests and background jobs.
In tests, it must be set explicitly for the user/organization.
PERFORMANCE: Uses cached FSM enabled state from CurrentContext that was set
once per request when user was initialized. This is a simple boolean check
instead of repeated feature flag lookups and user authentication checks.
"""
# Check for instance-level skip flag
if getattr(self, '_skip_fsm', False):
return False

# Use the centralized FSM enabled check from utils
# This handles feature flag and thread-local overrides
try:
from core.current_request import CurrentContext
from fsm.utils import is_fsm_enabled
# Fast path: Check cached FSM enabled state
# This was set once per request in CurrentContext.set_user()
from core.current_request import CurrentContext

# Get user from CurrentContext - don't fall back to AnonymousUser
# If no user in context (e.g., tests without explicit setup), return False
try:
user = CurrentContext.get_user()
user_type = type(user).__name__ if user else None
user_authenticated = getattr(user, 'is_authenticated', None) if user else None
logger.info(
f'FSM check for {self.__class__.__name__}(id={getattr(self, "pk", None)}): '
f'user_type={user_type}, authenticated={user_authenticated}'
)
if user is None:
logger.info(f'FSM check: User is None, skipping FSM for {self.__class__.__name__}')
return False
# Check if user is authenticated (not AnonymousUser)
if not user.is_authenticated:
logger.info(
f'FSM check: User {user_type} not authenticated, skipping FSM for {self.__class__.__name__}'
)
return False
except Exception:
# CurrentContext not available or no user set
# This is expected in tests that don't set up context
logger.info(f'FSM check: Exception getting user, skipping FSM for {self.__class__.__name__}')
return False

return is_fsm_enabled(user=user)
except Exception as e:
logger.debug(f'FSM check failed: {e}')
return False
return CurrentContext.is_fsm_enabled()

def save(self, *args, **kwargs):
"""
Expand All @@ -423,25 +359,29 @@ def save(self, *args, **kwargs):
Returns:
Whatever super().save() returns
"""
# Check for explicit FSM skip flag
skip_fsm = kwargs.pop('skip_fsm', False)

# Also check CurrentContext for skip_fsm flag (for context manager usage)
if not skip_fsm:
from core.current_request import CurrentContext
from core.current_request import CurrentContext

skip_fsm = CurrentContext.get('skip_fsm', False)
# Check for explicit FSM skip flag
skip_fsm = kwargs.pop('skip_fsm', CurrentContext.is_fsm_disabled())

# Check if this is a creation vs update
is_creating = self._state.adding

# Capture changed fields before save (only for updates)
# Note: _original_values should already be populated by from_db() or previous save()
changed_fields = {} if is_creating else self._get_changed_fields()

# Perform the actual save
result = super().save(*args, **kwargs)

# After successful save, update _original_values to current values
# This ensures subsequent saves can detect changes correctly
# Store attname values (raw PK for ForeignKey fields) to match from_db() format
self._original_values = {}
for field in self._meta.fields:
if field.is_relation and field.many_to_many:
continue
self._original_values[field.attname] = getattr(self, field.attname, None)

# After successful save, trigger FSM transitions if enabled and not skipped
should_execute = not skip_fsm and self._should_execute_fsm()

Expand Down Expand Up @@ -500,9 +440,6 @@ def save(self, *args, **kwargs):
exc_info=True,
)

# Update original values after save for next time
self._capture_original_values()

return result

def _execute_fsm_transition(self, transition_name: str, is_creating: bool, changed_fields: Dict[str, tuple]):
Expand Down
19 changes: 10 additions & 9 deletions label_studio/fsm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import uuid_utils
from core.current_request import CurrentContext
from core.feature_flags import flag_set

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -257,22 +256,24 @@ def is_fsm_enabled(user=None) -> bool:
"""
Check if FSM is enabled via feature flags and thread-local override.

PERFORMANCE: This function now checks the cached FSM state that was set
when the user was first initialized in CurrentContext. This avoids repeated
feature flag lookups throughout the request.

The check order is:
1. Check thread-local override (for test cleanup, bulk operations)
2. Check feature flag
2. Check cached feature flag state (set once per request)
3. If no state was cached (e.g. no user was set in CurrentContext), FSM is treated as disabled; there is no direct feature flag fallback

Args:
user: User for feature flag evaluation (optional)
user: Accepted for backward compatibility; not used, since the cached state from CurrentContext is returned

Returns:
True if FSM should be active
"""
# Check thread-local override first
if CurrentContext.is_fsm_disabled():
return False

# Then check feature flag
return flag_set('fflag_feat_fit_568_finite_state_management', user=user)
# Fast path: Check cached state from CurrentContext
# This is set once per request when user is initialized
return CurrentContext.is_fsm_enabled()


def get_current_state_safe(entity, user=None) -> Optional[str]:
Expand Down
Loading