
Commit 351bae9

feat(activity-logs): Incremental cache refresh task (#38801)
1 parent c07955a

4 files changed: +226 -54 lines changed


posthog/api/advanced_activity_logs/field_discovery.py

Lines changed: 104 additions & 39 deletions
@@ -1,10 +1,12 @@
 import gc
 import json
 import dataclasses
+from datetime import timedelta
 from typing import Any, TypedDict
 
 from django.db import connection
 from django.db.models import QuerySet
+from django.utils import timezone
 
 from posthog.models.activity_logging.activity_log import ActivityLog, Change
 from posthog.models.utils import UUIDT
@@ -96,8 +98,49 @@ def _analyze_detail_fields_memory(self) -> DetailFieldsResult:
     def _get_org_record_count(self) -> int:
         return ActivityLog.objects.filter(organization_id=self.organization_id).count()
 
-    def process_batch_for_large_org(self, offset: int, limit: int) -> None:
-        batch_fields = self._process_batch_memory(offset, limit, use_sampling=True)
+    def get_activity_logs_queryset(self, hours_back: int | None = None) -> QuerySet:
+        """Get the base queryset for activity logs, optionally filtered by time."""
+        queryset = ActivityLog.objects.filter(organization_id=self.organization_id, detail__isnull=False)
+
+        if hours_back is not None:
+            cutoff_time = timezone.now() - timedelta(hours=hours_back)
+            queryset = queryset.filter(created_at__gte=cutoff_time)
+
+        return queryset
+
+    def get_sampled_records(self, limit: int, offset: int = 0) -> list[dict]:
+        """Get sampled records using SQL TABLESAMPLE for large datasets."""
+        query = f"""
+            SELECT scope, detail
+            FROM posthog_activitylog TABLESAMPLE SYSTEM ({SAMPLING_PERCENTAGE})
+            WHERE organization_id = %s
+            AND detail IS NOT NULL
+            ORDER BY created_at DESC
+            LIMIT %s OFFSET %s
+        """
+
+        with connection.cursor() as cursor:
+            cursor.execute(query, [str(self.organization_id), limit, offset])
+            records = []
+            for row in cursor.fetchall():
+                scope, detail = row
+                if isinstance(detail, str):
+                    try:
+                        detail = json.loads(detail)
+                    except (json.JSONDecodeError, TypeError):
+                        detail = None
+                records.append({"scope": scope, "detail": detail})
+            return records
+
+    def process_batch_for_large_org(self, records: list[dict], hours_back: int | None = None) -> None:
+        """Process a batch of records for large organizations.
+
+        Args:
+            records: List of activity log records to process
+            hours_back: If provided, used to get appropriate static filters for the time range
+        """
+        # Process the provided records
+        batch_fields = self._extract_fields_from_records(records)
         batch_converted = self._convert_to_discovery_format(batch_fields)
 
         existing_cache = get_cached_fields(str(self.organization_id))
@@ -108,11 +151,21 @@ def process_batch_for_large_org(self, offset: int, limit: int) -> None:
         current_detail_fields = {}
         self._merge_fields_into_result(current_detail_fields, batch_converted)
 
-        static_filters = (
-            existing_cache.get("static_filters")
-            if existing_cache
-            else self._get_static_filters(self._get_base_queryset())
-        )
+        # Get static filters for the appropriate time range
+        if hours_back is not None:
+            recent_queryset = self.get_activity_logs_queryset(hours_back=hours_back)
+            new_static_filters = self._get_static_filters(recent_queryset)
+
+            # Merge with existing static filters
+            if existing_cache and "static_filters" in existing_cache:
+                static_filters = self._merge_static_filters(existing_cache["static_filters"], new_static_filters)
+            else:
+                static_filters = new_static_filters
+        else:
+            if existing_cache and existing_cache.get("static_filters"):
+                static_filters = existing_cache["static_filters"]
+            else:
+                static_filters = self._get_static_filters(self._get_base_queryset())
 
         cache_data = {
             "static_filters": static_filters,
@@ -181,38 +234,8 @@ def _discover_fields_memory(
 
         return all_fields
 
-    def _process_batch_memory(
-        self, offset: int, limit: int, use_sampling: bool = True
-    ) -> dict[str, set[tuple[str, str]]]:
-        if use_sampling:
-            query = f"""
-                SELECT scope, detail
-                FROM posthog_activitylog TABLESAMPLE SYSTEM ({SAMPLING_PERCENTAGE})
-                WHERE organization_id = %s
-                AND detail IS NOT NULL
-                ORDER BY created_at DESC
-                LIMIT %s OFFSET %s
-            """
-
-            with connection.cursor() as cursor:
-                cursor.execute(query, [str(self.organization_id), limit, offset])
-                records = []
-                for row in cursor.fetchall():
-                    scope, detail = row
-                    if isinstance(detail, str):
-                        try:
-                            detail = json.loads(detail)
-                        except (json.JSONDecodeError, TypeError):
-                            detail = None
-                    records.append({"scope": scope, "detail": detail})
-        else:
-            records = [
-                {"scope": record["scope"], "detail": record["detail"]}
-                for record in ActivityLog.objects.filter(
-                    organization_id=self.organization_id, detail__isnull=False
-                ).values("scope", "detail")[offset : offset + limit]
-            ]
-
+    def _extract_fields_from_records(self, records: list[dict]) -> dict[str, set[tuple[str, str]]]:
+        """Extract field information from a list of activity log records."""
         batch_fields: dict[str, set[tuple[str, str]]] = {}
 
         for record in records:
@@ -231,6 +254,20 @@ def _process_batch_memory(
 
         return batch_fields
 
+    def _process_batch_memory(
+        self, offset: int, limit: int, use_sampling: bool = True
+    ) -> dict[str, set[tuple[str, str]]]:
+        """Legacy method for backward compatibility."""
+        if use_sampling:
+            records = self.get_sampled_records(limit, offset)
+        else:
+            records = [
+                {"scope": record["scope"], "detail": record["detail"]}
+                for record in self.get_activity_logs_queryset().values("scope", "detail")[offset : offset + limit]
+            ]
+
+        return self._extract_fields_from_records(records)
+
     def _extract_json_paths(self, obj: Any, prefix: str = "") -> set[tuple[str, str]]:
         paths = set()
 
@@ -304,3 +341,31 @@ def _convert_to_discovery_format(self, fields: dict[str, set[tuple[str, str]]])
             result.append((scope, field_path, sorted(types)))
 
         return result
+
+    def _merge_static_filters(self, existing: dict, new: dict) -> dict:
+        """Merge static filters additively"""
+        merged = {
+            "users": existing.get("users", []),
+            "scopes": existing.get("scopes", []),
+            "activities": existing.get("activities", []),
+        }
+
+        # Merge users (by uuid)
+        existing_user_ids = {u["value"] for u in merged["users"]}
+        for user in new.get("users", []):
+            if user["value"] not in existing_user_ids:
+                merged["users"].append(user)
+
+        # Merge scopes
+        existing_scopes = {s["value"] for s in merged["scopes"]}
+        for scope in new.get("scopes", []):
+            if scope["value"] not in existing_scopes:
+                merged["scopes"].append(scope)
+
+        # Merge activities
+        existing_activities = {a["value"] for a in merged["activities"]}
+        for activity in new.get("activities", []):
+            if activity["value"] not in existing_activities:
+                merged["activities"].append(activity)
+
+        return merged

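For orientation, here is a minimal usage sketch (not part of the commit) of how the refactored discovery API composes: get_activity_logs_queryset feeds the incremental path, get_sampled_records feeds the full rebuild, and both hand batches to process_batch_for_large_org. The org_id variable and the batch size of 500 are placeholder assumptions:

    from posthog.api.advanced_activity_logs.field_discovery import AdvancedActivityLogFieldDiscovery

    discovery = AdvancedActivityLogFieldDiscovery(org_id)  # org_id: hypothetical organization UUID

    # Incremental path: full coverage of a recent window, merged into the existing cache.
    recent = discovery.get_activity_logs_queryset(hours_back=14)
    for offset in range(0, recent.count(), 500):
        records = [
            {"scope": r["scope"], "detail": r["detail"]}
            for r in recent.values("scope", "detail")[offset : offset + 500]
        ]
        if records:
            discovery.process_batch_for_large_org(records, hours_back=14)

    # Full-rebuild path: TABLESAMPLE-based sampling, no time filter.
    sampled = discovery.get_sampled_records(limit=500, offset=0)
    discovery.process_batch_for_large_org(sampled)
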
posthog/api/advanced_activity_logs/fields_cache.py

Lines changed: 11 additions & 0 deletions
@@ -40,3 +40,14 @@ def cache_fields(organization_id: str, fields_data: dict, record_count: int) ->
         client.setex(key, CACHE_TTL_SECONDS, json_data)
     except Exception as e:
         capture_exception(e)
+
+
+def delete_cached_fields(organization_id: str) -> bool:
+    """Delete cached fields for an organization"""
+    try:
+        client = get_client()
+        key = _get_cache_key(organization_id)
+        return bool(client.delete(key))
+    except Exception as e:
+        capture_exception(e)
+        return False

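A brief, hedged usage note: delete_cached_fields reports whether a key was actually removed, which is what lets the flush path log the outcome before rebuilding (see the task changes below). Sketch only; org_id is a placeholder:

    from posthog.api.advanced_activity_logs.fields_cache import delete_cached_fields

    # True if the org's cache key existed and was deleted; False otherwise or on Redis errors.
    deleted = delete_cached_fields(str(org_id))  # org_id is a hypothetical organization UUID
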
posthog/management/commands/refresh_activity_log_fields_cache.py

Lines changed: 41 additions & 5 deletions
@@ -1,6 +1,6 @@
 from django.core.management.base import BaseCommand
 
-from posthog.api.advanced_activity_logs.field_discovery import SMALL_ORG_THRESHOLD
+from posthog.api.advanced_activity_logs.constants import SMALL_ORG_THRESHOLD
 from posthog.tasks.tasks import refresh_activity_log_fields_cache
 
 
@@ -9,10 +9,24 @@ class Command(BaseCommand):
 
     def add_arguments(self, parser):
         parser.add_argument("--dry-run", action="store_true", help="Show what would be processed without running")
+        parser.add_argument(
+            "--flush",
+            action="store_true",
+            help="Delete existing cache and rebuild from scratch (uses 10% sampling for full rebuild)",
+        )
+        parser.add_argument(
+            "--hours-back",
+            type=int,
+            default=14,
+            help="Number of hours to look back when not using --flush (default: 14 = 12h + 2h buffer)",
+        )
 
     def handle(self, *args, **options):
         if options["dry_run"]:
+            from datetime import timedelta
+
             from django.db.models import Count
+            from django.utils import timezone
 
             from posthog.models import Organization
             from posthog.models.activity_logging.activity_log import ActivityLog
@@ -35,9 +49,31 @@ def handle(self, *args, **options):
                 org.activity_count = activity_counts.get(org.id, 0)
 
             self.stdout.write(f"Would process {len(large_orgs)} organizations:")
-            for org in large_orgs:
-                self.stdout.write(f" - {org.name} (id={org.id}) - {org.activity_count:,} records")
+
+            if options["flush"]:
+                self.stdout.write("Mode: FLUSH - Delete existing cache and rebuild from scratch with 10% sampling")
+                for org in large_orgs:
+                    self.stdout.write(f" - {org.name} (id={org.id}) - {org.activity_count:,} total records")
+            else:
+                cutoff = timezone.now() - timedelta(hours=options["hours_back"])
+                self.stdout.write(f"Mode: INCREMENTAL - Process last {options['hours_back']} hours with 100% coverage")
+                self.stdout.write(f"Cutoff time: {cutoff}")
+
+                for org in large_orgs:
+                    recent_count = ActivityLog.objects.filter(
+                        organization_id=org.id, created_at__gte=cutoff, detail__isnull=False
+                    ).count()
+                    self.stdout.write(
+                        f" - {org.name} (id={org.id}) - {recent_count:,} records from last {options['hours_back']}h"
+                    )
         else:
-            self.stdout.write("Starting activity log fields cache refresh...")
-            refresh_activity_log_fields_cache()
+            mode = (
+                "FLUSH mode"
+                if options["flush"]
+                else f"INCREMENTAL mode (last {options['hours_back']}h with 100% coverage)"
+            )
+            self.stdout.write(f"Starting activity log fields cache refresh in {mode}...")
+
+            refresh_activity_log_fields_cache(flush=options["flush"], hours_back=options["hours_back"])
+
         self.stdout.write("Cache refresh completed.")

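As a rough sketch of how the new flags can be exercised (assumed usage, not shown in the commit), the command can be driven programmatically with Django's call_command:

    from django.core.management import call_command

    # Preview which large organizations would be processed, without touching the cache.
    call_command("refresh_activity_log_fields_cache", "--dry-run")

    # Incremental refresh covering the last 24 hours instead of the default 14.
    call_command("refresh_activity_log_fields_cache", "--hours-back", "24")

    # Delete the existing cache and rebuild it from scratch with sampling.
    call_command("refresh_activity_log_fields_cache", "--flush")
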
posthog/tasks/tasks.py

Lines changed: 70 additions & 10 deletions
@@ -926,17 +926,64 @@ def background_delete_model_task(
 
 
 @shared_task(ignore_result=True, time_limit=7200)
-def refresh_activity_log_fields_cache() -> None:
-    """Refresh fields cache for large organizations every 12 hours"""
+def refresh_activity_log_fields_cache(flush: bool = False, hours_back: int = 14) -> None:
+    """
+    Refresh fields cache for large organizations.
+
+    Args:
+        flush: If True, delete existing cache and rebuild from scratch
+        hours_back: Number of hours to look back (default: 14 = 12h schedule + 2h buffer)
+    """
+
+    from uuid import UUID
+
     from django.db.models import Count
 
     from posthog.api.advanced_activity_logs.constants import BATCH_SIZE, SAMPLING_PERCENTAGE, SMALL_ORG_THRESHOLD
     from posthog.api.advanced_activity_logs.field_discovery import AdvancedActivityLogFieldDiscovery
+    from posthog.api.advanced_activity_logs.fields_cache import delete_cached_fields
     from posthog.exceptions_capture import capture_exception
     from posthog.models import Organization
     from posthog.models.activity_logging.activity_log import ActivityLog
 
-    logger.info("[refresh_activity_log_fields_cache] running task")
+    def _process_org_with_flush(discovery: AdvancedActivityLogFieldDiscovery, org_id: UUID) -> None:
+        """Rebuild cache from scratch with sampling."""
+        deleted = delete_cached_fields(str(org_id))
+        logger.info(f"Flushed cache for org {org_id}: {deleted}")
+
+        record_count = discovery._get_org_record_count()
+        estimated_sampled_records = int(record_count * (SAMPLING_PERCENTAGE / 100))
+        total_batches = (estimated_sampled_records + BATCH_SIZE - 1) // BATCH_SIZE
+
+        logger.info(
+            f"Rebuilding cache for org {org_id} from scratch: "
+            f"{record_count} total records, sampling {estimated_sampled_records} records"
+        )
+
+        for batch_num in range(total_batches):
+            offset = batch_num * BATCH_SIZE
+            records = discovery.get_sampled_records(limit=BATCH_SIZE, offset=offset)
+            discovery.process_batch_for_large_org(records)
+
+    def _process_org_incremental(discovery: AdvancedActivityLogFieldDiscovery, org_id: UUID, hours_back: int) -> int:
+        """Process recent records with 100% coverage."""
+        recent_queryset = discovery.get_activity_logs_queryset(hours_back=hours_back)
+        recent_count = recent_queryset.count()
+
+        logger.info(f"Processing {recent_count} records from last {hours_back}h for org {org_id} (100% coverage)")
+
+        for batch_num in range(0, recent_count, BATCH_SIZE):
+            records = [
+                {"scope": record["scope"], "detail": record["detail"]}
+                for record in recent_queryset.values("scope", "detail")[batch_num : batch_num + BATCH_SIZE]
+            ]
+            if records:
+                discovery.process_batch_for_large_org(records, hours_back=hours_back)
+
+        return recent_count
+
+    mode = "FLUSH" if flush else f"INCREMENTAL (last {hours_back}h, 100% coverage)"
+    logger.info(f"[refresh_activity_log_fields_cache] running task in {mode} mode")
 
     large_org_data = (
         ActivityLog.objects.values("organization_id")
@@ -951,24 +998,37 @@ def refresh_activity_log_fields_cache() -> None:
     org_count = len(large_orgs)
     logger.info(f"[refresh_activity_log_fields_cache] processing {org_count} large organizations")
 
+    processed_orgs = 0
+    total_recent_records = 0
+
     for org in large_orgs:
         try:
             discovery = AdvancedActivityLogFieldDiscovery(org.id)
-            record_count = discovery._get_org_record_count()
 
-            estimated_sampled_records = int(record_count * (SAMPLING_PERCENTAGE / 100))
-            total_batches = (estimated_sampled_records + BATCH_SIZE - 1) // BATCH_SIZE
+            if flush:
+                _process_org_with_flush(discovery, org.id)
+            else:
+                recent_count = _process_org_incremental(discovery, org.id, hours_back)
+                total_recent_records += recent_count
 
-            for batch_num in range(total_batches):
-                offset = batch_num * BATCH_SIZE
-                discovery.process_batch_for_large_org(offset, BATCH_SIZE)
+            processed_orgs += 1
 
         except Exception as e:
             logger.exception(
                 "Failed to refresh activity log fields cache for org",
                 org_id=org.id,
+                mode=mode,
                 error=e,
             )
             capture_exception(e)
 
-    logger.info(f"[refresh_activity_log_fields_cache] completed for {org_count} organizations")
+    if not flush:
+        logger.info(
+            f"[refresh_activity_log_fields_cache] completed for {processed_orgs}/{org_count} organizations "
+            f"in {mode} mode. Total recent records processed: {total_recent_records}"
+        )
+    else:
+        logger.info(
+            f"[refresh_activity_log_fields_cache] completed flush and rebuild for "
+            f"{processed_orgs}/{org_count} organizations"
+        )

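Because refresh_activity_log_fields_cache is a Celery @shared_task, it can also be enqueued directly; a minimal sketch (the 12-hour beat schedule itself is not part of this diff):

    from posthog.tasks.tasks import refresh_activity_log_fields_cache

    # Default incremental run: last 14 hours (12h schedule + 2h buffer) at 100% coverage.
    refresh_activity_log_fields_cache.delay()

    # One-off full rebuild: flush the cached fields and re-sample the whole history.
    refresh_activity_log_fields_cache.delay(flush=True)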