Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion app/ldap_protocol/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE
"""

import asyncio
import functools
import hashlib
import random
Expand All @@ -138,19 +139,23 @@
import time
from calendar import timegm
from datetime import datetime
from functools import wraps
from hashlib import blake2b
from operator import attrgetter
from typing import Callable
from typing import Any, Callable, Iterable
from zoneinfo import ZoneInfo

from loguru import logger
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm.attributes import instance_state
from sqlalchemy.sql.compiler import DDLCompiler
from sqlalchemy.sql.expression import ClauseElement, Executable, Visitable

from entities import Directory

DEFAULT_CACHE_TIME = 5 * 60 # 5 minutes


def validate_entry(entry: str) -> bool:
"""Validate entry str.
Expand Down Expand Up @@ -402,3 +407,72 @@ async def explain_query(
for row in await session.execute(explain(query, analyze=True))
),
)


def has_expired_sqla_objs(obj: Any, max_depth: int = 3) -> bool:
def _check(value: Any) -> bool:
try:
state = instance_state(value)
return bool(state.expired_attributes)
except AttributeError:
return False

def _walk(value: Any, depth: int = 0) -> bool:
if depth > max_depth:
return False

if _check(value):
return True

if isinstance(value, str | bytes | bytearray):
return False

if isinstance(value, dict):
return any(_walk(v, depth + 1) for v in value.values())

if isinstance(value, Iterable):
return any(_walk(v, depth + 1) for v in value)

return False

return _walk(obj)


def async_lru_cache(ttl: int | None = DEFAULT_CACHE_TIME) -> Callable:
cache: dict = {}
locks: dict = {}

def _is_value_expired(
value: Any,
now: float,
expires_at: float | None,
) -> bool:
return bool(
expires_at and expires_at < now or has_expired_sqla_objs(value),
)

def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args: tuple, **kwargs: dict) -> Any:
key = (args, tuple(sorted(kwargs.items())))
now = time.monotonic()
if key not in locks:
locks[key] = asyncio.Lock()

async with locks[key]:
if key in cache:
value, expires_at = cache[key]
if not _is_value_expired(value, now, expires_at):
return value
else:
del cache[key]

result = await func(*args, **kwargs)
expires_at = now + ttl if ttl else None
cache[key] = (result, expires_at)
del locks[key]
return result

return wrapper

return decorator
2 changes: 2 additions & 0 deletions app/ldap_protocol/utils/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from .const import EMAIL_RE, GRANT_DN_STRING
from .helpers import (
async_lru_cache,
create_integer_hash,
create_object_sid,
dn_is_base_directory,
Expand All @@ -35,6 +36,7 @@
)


@async_lru_cache()
async def get_base_directories(session: AsyncSession) -> list[Directory]:
"""Get base domain directories."""
result = await session.execute(
Expand Down