diff --git a/ami/jobs/migrations/0019_alter_job_job_type_key.py b/ami/jobs/migrations/0019_alter_job_job_type_key.py
new file mode 100644
index 000000000..e776d9cf6
--- /dev/null
+++ b/ami/jobs/migrations/0019_alter_job_job_type_key.py
@@ -0,0 +1,30 @@
+# Generated by Django 4.2.10 on 2025-06-19 08:33
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("jobs", "0018_alter_job_job_type_key"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="job",
+            name="job_type_key",
+            field=models.CharField(
+                choices=[
+                    ("ml", "ML pipeline"),
+                    ("populate_captures_collection", "Populate captures collection"),
+                    ("data_storage_sync", "Data storage sync"),
+                    ("unknown", "Unknown"),
+                    ("data_export", "Data Export"),
+                    ("detection_clustering", "Detection Feature Clustering"),
+                    ("tracking", "Occurrence Tracking"),
+                ],
+                default="unknown",
+                max_length=255,
+                verbose_name="Job Type",
+            ),
+        ),
+    ]
diff --git a/ami/jobs/models.py b/ami/jobs/models.py
index 4af39745b..9ea7291f4 100644
--- a/ami/jobs/models.py
+++ b/ami/jobs/models.py
@@ -18,6 +18,7 @@ from ami.jobs.tasks import run_job
 from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection
 from ami.ml.models import Pipeline
+from ami.ml.tracking import perform_tracking
 from ami.utils.schemas import OrderedEnum
 
 logger = logging.getLogger(__name__)
 
@@ -673,6 +674,28 @@ def run(cls, job: "Job"):
         job.save()
 
 
+class TrackingJob(JobType):
+    """Job type that runs occurrence tracking (``perform_tracking``) for a job."""
+
+    name = "Occurrence Tracking"
+    key = "tracking"
+
+    @classmethod
+    def run(cls, job: "Job"):
+        """Mark the job as started, run ``perform_tracking`` on it, then mark it as successful and save."""
+        job.logger.info("Starting tracking job")
+        job.update_status(JobState.STARTED)
+        job.started_at = datetime.datetime.now()
+        job.finished_at = None
+
+        perform_tracking(job)
+
+        job.update_status(JobState.SUCCESS)
+        job.logger.info("Tracking job finished successfully.")
+        job.finished_at = datetime.datetime.now()
+        job.save()
+
+
 class UnknownJobType(JobType):
     name = "Unknown"
     key = "unknown"
@@ -689,6 +712,7 @@ def run(cls, job: "Job"):
     UnknownJobType,
     DataExportJob,
     DetectionClusteringJob,
+    TrackingJob,
 ]
 
 
diff --git a/ami/main/admin.py b/ami/main/admin.py
index 684ec9b54..2c5cb5ef0 100644
--- a/ami/main/admin.py
+++ b/ami/main/admin.py
@@ -675,7 +675,48 @@ def cluster_detections(self, request: HttpRequest, queryset: QuerySet[SourceImag
 
         self.message_user(request, f"Clustered {queryset.count()} collection(s).")
 
-    actions = [populate_collection, populate_collection_async, cluster_detections, create_clustering_job]
+    @admin.action(description="Create tracking job (but don't run it)")
+    def create_tracking_job(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
+        """Create one tracking job per selected collection, pre-filled with the default params, without enqueuing."""
+        import dataclasses
+
+        from ami.jobs.models import Job, TrackingJob
+        from ami.ml.tracking import DEFAULT_TRACKING_PARAMS
+
+        for collection in queryset:
+            job = Job.objects.create(
+                name=f"Tracking for collection {collection.pk}",
+                project=collection.project,
+                source_image_collection=collection,
+                job_type_key=TrackingJob.key,
+                # asdict() returns a copy, so the stored params never alias the shared module-level defaults.
+                params=dataclasses.asdict(DEFAULT_TRACKING_PARAMS),
+            )
+            self.message_user(request, f"Tracking job #{job.pk} created for collection #{collection.pk}")
+
+    @admin.action(description="Run tracking job")
+    def run_tracking_job(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
+        """Create and immediately enqueue one tracking job per selected collection (no explicit params are set)."""
+        from ami.jobs.models import Job, TrackingJob
+
+        for collection in queryset:
+            job = Job.objects.create(
+                name=f"Tracking for collection {collection.pk}",
+                project=collection.project,
+                source_image_collection=collection,
+                job_type_key=TrackingJob.key,
+            )
+            job.enqueue()
+            self.message_user(request, f"Tracking job #{job.pk} started for collection #{collection.pk}")
+
+    actions = [
+        populate_collection,
+        populate_collection_async,
+        cluster_detections,
+        create_clustering_job,
+        create_tracking_job,
+        run_tracking_job,
+    ]
 
     # Hide images many-to-many field from form. This would list all source images in the database.
     exclude = ("images",)
diff --git a/ami/main/migrations/0069_detection_next_detection.py b/ami/main/migrations/0069_detection_next_detection.py
new file mode 100644
index 000000000..949b3e539
--- /dev/null
+++ b/ami/main/migrations/0069_detection_next_detection.py
@@ -0,0 +1,25 @@
+# Generated by Django 4.2.10 on 2025-06-10 10:19
+
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("main", "0068_allow_taxa_without_project"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="detection",
+            name="next_detection",
+            field=models.OneToOneField(
+                blank=True,
+                help_text="The detection that follows this one in the tracking sequence.",
+                null=True,
+                on_delete=django.db.models.deletion.SET_NULL,
+                related_name="previous_detection",
+                to="main.detection",
+            ),
+        ),
+    ]
diff --git a/ami/main/models.py b/ami/main/models.py
index ce5fbadfa..9f3e3a91b 100644
--- a/ami/main/models.py
+++ b/ami/main/models.py
@@ -2344,6 +2344,14 @@ class Detection(BaseModel):
     classifications: models.QuerySet["Classification"]
     source_image_id: int
     detection_algorithm_id: int
+    next_detection = models.OneToOneField(
+        "self",
+        on_delete=models.SET_NULL,
+        null=True,
+        blank=True,
+        related_name="previous_detection",
+        help_text="The detection that follows this one in the tracking sequence.",
+    )
 
     def get_bbox(self):
         if self.bbox:
@@ -2388,6 +2396,8 @@ def height(self) -> int | None:
         if self.bbox and len(self.bbox) == 4:
             return self.bbox[3] - self.bbox[1]
 
+    occurrence_id: int | None  # annotation only (no class-level value), like source_image_id above
+
     class Meta:
         ordering = [
             "frame_num",
diff --git a/ami/main/tests/test_tracking.py b/ami/main/tests/test_tracking.py
new file mode 100644
index 000000000..483d571ee
--- /dev/null
+++ b/ami/main/tests/test_tracking.py
@@ -0,0 +1,112 @@
+import logging
+from collections import defaultdict
+
+import numpy as np
+from django.test import TestCase
+from django.utils import timezone
+
+from ami.main.models import Classification, Detection, Occurrence, Project
+from ami.ml.models import Algorithm
+from ami.ml.tracking import assign_occurrences_by_tracking_images
+
+logger = logging.getLogger(__name__)
+
+
+class TestTracking(TestCase):
+    """Check that tracking rebuilds the occurrence groupings that existed before they were wiped."""
+
+    def setUp(self) -> None:
+        # NOTE(review): relies on a project with events, captures and occurrences already existing in
+        # the test database; nothing in this file creates them -- confirm where the fixture comes from.
+        self.project = Project.objects.first()
+        self.event = self.project.events.first()
+        self.source_images = list(self.event.captures.order_by("timestamp"))
+        self.algorithm = self.assign_mock_features_to_occurrence_detections(self.event)
+        # Save ground truth occurrence groupings
+        self.ground_truth_groups = defaultdict(set)
+        for occ in Occurrence.objects.filter(event=self.event):
+            det_ids = Detection.objects.filter(occurrence=occ).values_list("id", flat=True)
+            for det_id in det_ids:
+                self.ground_truth_groups[occ.pk].add(det_id)
+
+        # Clear existing detection chains; occurrences are cleared inside the test itself
+        Detection.objects.filter(source_image__event=self.event).update(next_detection=None)
+
+    def assign_mock_features_to_occurrence_detections(self, event, algorithm_name="MockTrackingAlgorithm"):
+        """
+        Give every detection of every occurrence in ``event`` a mock 2048-d feature vector.
+
+        Detections of the same occurrence share a base vector plus small noise. Returns the algorithm the
+        mock classifications are attached to, so callers can pass it on to the tracking functions.
+        """
+        algorithm, _ = Algorithm.objects.get_or_create(name=algorithm_name)
+        # Seeded generator so the mock features (and therefore the test) are reproducible between runs.
+        rng = np.random.default_rng(0)
+
+        for occurrence in event.occurrences.all():
+            base_vector = rng.random(2048)  # Base feature for this occurrence group
+
+            for det in occurrence.detections.all():
+                feature_vector = base_vector + rng.normal(0, 0.001, size=2048)  # Add slight variation
+                Classification.objects.update_or_create(
+                    detection=det,
+                    algorithm=algorithm,
+                    defaults={
+                        "timestamp": timezone.now(),
+                        "features_2048": feature_vector.tolist(),
+                        "terminal": True,
+                        "score": 1.0,
+                    },
+                )
+
+        return algorithm
+
+    def test_tracking_exactly_reproduces_occurrences(self):
+        # Clear previous detection chains and occurrences
+        for det in Detection.objects.filter(source_image__event=self.event):
+            det.occurrence = None
+            det.next_detection = None
+            det.save()
+
+        Occurrence.objects.filter(event=self.event).delete()
+
+        # Run the tracking algorithm to regenerate occurrences
+        assign_occurrences_by_tracking_images(self.event, logger, algorithm=self.algorithm)
+
+        # Capture new tracking-generated occurrence groups
+        new_groups = {
+            occ.pk: set(Detection.objects.filter(occurrence=occ).values_list("id", flat=True))
+            for occ in Occurrence.objects.filter(event=self.event)
+        }
+
+        # Assert that the number of new groups equals the number of ground truth groups
+        self.assertEqual(
+            len(new_groups),
+            len(self.ground_truth_groups),
+            f"Expected {len(self.ground_truth_groups)} groups, but got {len(new_groups)}",
+        )
+
+        # Assert each new group exactly matches one of the original ground truth groups
+        unmatched_groups = [
+            new_set for new_set in new_groups.values() if new_set not in self.ground_truth_groups.values()
+        ]
+
+        self.assertEqual(
+            len(unmatched_groups),
+            0,
+            f"{len(unmatched_groups)} of the new groups do not exactly match any ground truth group",
+        )
+        logger.info(
+            f"All {len(new_groups)} new groups match the ground truth groups exactly.",
+        )
+        logger.info(f"new groups: {new_groups}")
+        # Assert that each ground truth group is present in the new tracking results
+        for gt_set in self.ground_truth_groups.values():
+            logger.info(
+                f"Checking ground truth group: {gt_set}",
+            )
+            self.assertIn(
+                gt_set,
+                new_groups.values(),
+                f"Ground truth group {gt_set} not found in new tracking results",
+            )
diff --git a/ami/ml/tracking.py b/ami/ml/tracking.py
new file mode 100644
index 000000000..075c1825d
--- /dev/null
+++ b/ami/ml/tracking.py
@@ -0,0 +1,442 @@
+import dataclasses
+import math
+import typing
+from collections.abc import Iterable
+
+import numpy as np
+from django.db.models import Count
+
+from ami.main.models import Classification, Detection, Event, Occurrence, SourceImage
+from ami.ml.models import Algorithm
+
+if typing.TYPE_CHECKING:
+    from ami.jobs.models import Job
+
+
+@dataclasses.dataclass
+class TrackingParams:
+    """
+    Parameters for the tracking job.
+    """
+
+    # Exclusive upper bound on ``total_cost`` for two detections to be linked.
+    cost_threshold: float = 0.2
+    # Skip events in which any occurrence already has an identification.
+    skip_if_human_identifications: bool = True
+    # Only track events whose captures all have a detection with features from the algorithm in use.
+    require_completely_processed_session: bool = False
+    # NOTE(review): never read in this module; the algorithm is always the most common one per event.
+    feature_extraction_algorithm_id: int | None = None
+
+
+DEFAULT_TRACKING_PARAMS = TrackingParams()
+
+
+def cosine_similarity(v1: Iterable[float], v2: Iterable[float]) -> float:
+    """
+    Return the cosine similarity of two vectors, clipped to the range [0, 1].
+
+    Returns 0.0 when either vector has zero length, where the similarity is undefined.
+    """
+    a = np.asarray(v1, dtype=float)
+    b = np.asarray(v2, dtype=float)
+    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
+    if norm_product == 0:
+        # Dividing by zero would produce NaN, and a NaN cost silently fails every threshold comparison.
+        return 0.0
+    sim = np.dot(a, b) / norm_product
+    return float(np.clip(sim, 0.0, 1.0))
+
+
+def iou(bb1, bb2):
+    """
+    Return the intersection-over-union of two boxes given as (x1, y1, x2, y2).
+
+    The +1 terms count both edge coordinates as part of the box.
+    """
+    xA = max(bb1[0], bb2[0])
+    yA = max(bb1[1], bb2[1])
+    xB = min(bb1[2], bb2[2])
+    yB = min(bb1[3], bb2[3])
+    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
+    boxAArea = (bb1[2] - bb1[0] + 1) * (bb1[3] - bb1[1] + 1)
+    boxBArea = (bb2[2] - bb2[0] + 1) * (bb2[3] - bb2[1] + 1)
+    unionArea = boxAArea + boxBArea - interArea
+    return interArea / unionArea if unionArea > 0 else 0
+
+
+def box_ratio(bb1, bb2):
+    """Return the ratio of the smaller box area to the larger one (1.0 for equally sized boxes)."""
+    area1 = (bb1[2] - bb1[0] + 1) * (bb1[3] - bb1[1] + 1)
+    area2 = (bb2[2] - bb2[0] + 1) * (bb2[3] - bb2[1] + 1)
+    return min(area1, area2) / max(area1, area2)
+
+
+def distance_ratio(bb1, bb2, img_diag):
+    """Return the distance between the two box centres as a fraction of the image diagonal."""
+    cx1 = (bb1[0] + bb1[2]) / 2
+    cy1 = (bb1[1] + bb1[3]) / 2
+    cx2 = (bb2[0] + bb2[2]) / 2
+    cy2 = (bb2[1] + bb2[3]) / 2
+    dist = math.sqrt((cx2 - cx1) ** 2 + (cy2 - cy1) ** 2)
+    return dist / img_diag if img_diag > 0 else 1.0
+
+
+def image_diagonal(width: int, height: int) -> int:
+    """Return the image diagonal in pixels, rounded up to a whole number."""
+    return math.ceil(math.sqrt(width**2 + height**2))
+
+
+def total_cost(f1, f2, bb1, bb2, diag):
+    """
+    Return the matching cost of two detections; a lower cost is a better match.
+
+    Sum of four terms: feature dissimilarity, 1 - IoU, 1 - area ratio and relative centre distance.
+    """
+    return (
+        (1 - cosine_similarity(f1, f2))
+        + (1 - iou(bb1, bb2))
+        + (1 - box_ratio(bb1, bb2))
+        + distance_ratio(bb1, bb2, diag)
+    )
+
+
+def get_most_common_algorithm_for_event(event):
+    """
+    Returns the most common Algorithm object (used in classifications with features_2048) for the given event.
+    """
+    most_common = (
+        Classification.objects.filter(
+            detection__source_image__event=event,
+            features_2048__isnull=False,
+        )
+        .values("algorithm_id")
+        .annotate(count=Count("id"))
+        .order_by("-count")
+        .first()
+    )
+
+    if most_common:
+        return Algorithm.objects.get(id=most_common["algorithm_id"])
+
+    return None
+
+
+def event_fully_processed(event: Event, logger, algorithm: Algorithm) -> bool:
+    """
+    Checks if all captures in the event have processed detections with features_2048
+    """
+    total_captures = event.captures.count()
+    logger.info(f"Checking if event {event.pk} is fully processed... Total captures: {total_captures}")
+
+    processed_captures = (
+        event.captures.filter(
+            detections__classifications__features_2048__isnull=False,
+            detections__classifications__algorithm=algorithm,
+        )
+        .distinct()
+        .count()
+    )
+
+    if processed_captures < total_captures:
+        logger.info(
+            f"Event {event.pk} is not fully processed. "
+            f"Only {processed_captures}/{total_captures} captures have processed detections."
+        )
+        return False
+
+    logger.info(f"Event {event.pk} is fully processed.")
+    return True
+
+
+def get_feature_vector(detection: Detection, algorithm: Algorithm):
+    """
+    Returns the latest non-null features_2048 vector from the given detection,
+    extracted by a specific algorithm.
+    """
+    return (
+        detection.classifications.filter(features_2048__isnull=False, algorithm=algorithm)
+        .order_by("-timestamp")
+        .values_list("features_2048", flat=True)
+        .first()
+    )
+
+
+def assign_occurrences_from_detection_chains(source_images: "list[SourceImage]", logger):
+    """
+    Walk detection chains across source images and assign a new occurrence to each chain.
+
+    Chains of a single detection are left untouched, as are chains whose detections all already
+    share one occurrence.
+    """
+    visited: set[int] = set()
+    created_occurrences_count = 0
+    existing_occurrence_count = (
+        Occurrence.objects.filter(detections__source_image__in=source_images).distinct().count()
+    )
+    for image in source_images:
+        for det in image.detections.all():
+            if det.pk in visited or getattr(det, "previous_detection", None) is not None:
+                continue  # Already processed or this is not a chain start
+
+            chain: list[Detection] = []
+            current = det
+            while current and current.pk not in visited:
+                chain.append(current)
+                visited.add(current.pk)
+                current = current.next_detection
+
+            if len(chain) < 2:
+                # Only create new occurrence if there are multiple detections in the chain
+                continue
+
+            logger.debug(
+                f"Found chain of {len(chain)} detections starting from detection {det.pk} in image {image.pk}"
+            )
+
+            old_occurrences = {d.occurrence_id for d in chain if d.occurrence_id}
+
+            # Detections without an occurrence are left out of the set above, so also require that every
+            # detection in the chain has one before treating the whole chain as already assigned.
+            if len(old_occurrences) == 1 and all(d.occurrence_id for d in chain):
+                logger.debug(
+                    f"All detections in chain already assigned to occurrence {old_occurrences.pop()}. Skipping."
+                )
+                continue
+
+            # Delete old occurrences (if any)
+            # @TODO: Consider if this is the desired behavior. Check for any history on the occurrence. Consider
+            # soft deleting or just reassign the detections to the new occurrence.
+            for occ_id in old_occurrences:
+                try:
+                    logger.debug(f"Deleting old occurrence {occ_id} before reassignment.")
+                    Occurrence.objects.filter(id=occ_id).delete()
+                except Exception as e:
+                    logger.error(f"Failed to delete occurrence {occ_id}: {e}")
+
+            occurrence = Occurrence.objects.create(
+                event=chain[0].source_image.event,
+                deployment=chain[0].source_image.deployment,
+                project=chain[0].source_image.project,
+            )
+            created_occurrences_count += 1
+
+            for d in chain:
+                d.occurrence = occurrence
+                d.save()
+
+            occurrence.save()
+
+            logger.debug(f"Assigned occurrence {occurrence.pk} to chain of {len(chain)} detections")
+
+    # @TODO report how many detections were processed, length of chains, which are solo vs. chains, etc.
+
+    new_occurrence_count = Occurrence.objects.filter(detections__source_image__in=source_images).distinct().count()
+    occurrences_removed = existing_occurrence_count - new_occurrence_count
+    if occurrences_removed > 0:
+        logger.info(f"Reduced existing occurrences by {occurrences_removed}.")
+    logger.info(
+        f"Assigned {created_occurrences_count} new occurrences to detection chains in {len(source_images)} images.\n"
+        f"Occurrences before: {existing_occurrence_count}, after: {new_occurrence_count}.\n"
+        f"Total detections processed: {len(visited)}."
+    )
+
+
+def assign_occurrences_by_tracking_images(
+    event: Event, logger, algorithm: Algorithm, params: TrackingParams = DEFAULT_TRACKING_PARAMS, job=None
+) -> None:
+    """
+    Track detections across ordered source images and assign them to occurrences.
+
+    Links detections between each pair of consecutive captures (ordered by timestamp), then turns the
+    resulting chains into occurrences. If ``job`` is given, its ``event_<pk>`` stage is updated as pairs
+    are processed.
+    """
+    from ami.jobs.models import JobState
+
+    # Materialise once: the images are walked pairwise below and reused for the occurrence assignment.
+    source_images = list(event.captures.order_by("timestamp"))
+    logger.info(f"Found {len(source_images)} source images for event {event.pk}")
+    if len(source_images) < 2:
+        logger.warning("Not enough images to perform tracking. At least 2 images are required.")
+        return
+    total_pairs = len(source_images) - 1
+    for i, (current_image, next_image) in enumerate(zip(source_images, source_images[1:])):
+        logger.debug(f"""Tracking: Processing image {i + 1} of {len(source_images)}""")
+
+        if not current_image.width or not current_image.height:
+            logger.warning(f"Image {current_image.pk} has no width and/or height. Skipping tracking for this event.")
+            return
+
+        pair_detections(
+            list(current_image.detections.all()),
+            list(next_image.detections.all()),
+            image_width=current_image.width,
+            image_height=current_image.height,
+            cost_threshold=params.cost_threshold,
+            algorithm=algorithm,
+            logger=logger,
+        )
+        if job:
+            job.progress.update_stage(
+                f"event_{event.pk}",
+                status=JobState.STARTED,
+                progress=(i + 1) / total_pairs,
+            )
+            job.save()
+
+    assign_occurrences_from_detection_chains(source_images, logger)
+    if job:
+        job.progress.update_stage(
+            f"event_{event.pk}",
+            progress=1.0,
+        )
+        job.save()
+
+
+def _detections_with_features(detections: list, algorithm, logger, label: str) -> list:
+    """Return (detection, feature vector) pairs for the detections that have both features and a bbox."""
+    usable = []
+    for det in detections:
+        vec = get_feature_vector(det, algorithm)
+        if vec is None:
+            logger.debug(f"Skipping {label} {det.id} (no features)")
+            continue
+        if not det.bbox:
+            # The cost terms index into the box, so a detection without one cannot be compared.
+            logger.debug(f"Skipping {label} {det.id} (no bounding box)")
+            continue
+        usable.append((det, vec))
+    return usable
+
+
+def pair_detections(
+    current_detections: list,
+    next_detections: list,
+    image_width: int,
+    image_height: int,
+    cost_threshold: float,
+    algorithm,
+    logger,
+) -> None:
+    """
+    Assigns next_detection for each detection in current_detections based on lowest cost match
+    from next_detections, ensuring unique assignments and no duplicates.
+
+    Only pairs with cost < threshold are considered. Detections without a feature vector from
+    ``algorithm`` or without a bounding box are never linked.
+    """
+    logger.debug(f"Pairing {len(current_detections)} - >{len(next_detections)} detections")
+
+    # None of these depend on the pair being compared, so compute them once rather than once per pair.
+    diagonal = image_diagonal(image_width, image_height)
+    current_candidates = _detections_with_features(current_detections, algorithm, logger, label="detection")
+    next_candidates = _detections_with_features(next_detections, algorithm, logger, label="next detection")
+
+    potential_matches = []
+    for det, det_vec in current_candidates:
+        for next_det, next_vec in next_candidates:
+            cost = total_cost(det_vec, next_vec, det.bbox, next_det.bbox, diagonal)
+            if cost < cost_threshold:
+                potential_matches.append((det, next_det, cost))
+
+    # Sort by cost: lower is better
+    potential_matches.sort(key=lambda x: x[2])
+
+    assigned_current_ids = set()
+    assigned_next_ids = set()
+
+    for det, next_det, cost in potential_matches:
+        if det.id in assigned_current_ids or next_det.id in assigned_next_ids:
+            continue
+        # check if next detection has a previous detection already assigned
+        previous_detection: Detection | None = getattr(next_det, "previous_detection", None)
+        if previous_detection is not None:
+            logger.debug(f"{next_det.id} already has previous detection: {next_det.previous_detection.id}")
+            previous_detection.next_detection = None
+            previous_detection.save()
+            logger.debug(f"Cleared previous detection {previous_detection.pk} -> {next_det.pk} link")
+
+        logger.debug(f"Trying to link {det.id} => {next_det.id}")
+        det.next_detection = next_det
+        det.save()
+        logger.debug(f"Linked detection {det.id} => {next_det.id} with cost {cost:.4f}")
+
+        assigned_current_ids.add(det.id)
+        assigned_next_ids.add(next_det.id)
+
+
+def perform_tracking(job: "Job"):
+    """
+    Perform detection tracking for all events in the job's source image collection.
+
+    ``job.params``, when set, overrides fields of ``DEFAULT_TRACKING_PARAMS``; a key that is not a
+    ``TrackingParams`` field makes ``dataclasses.replace`` raise ``TypeError``. An event is skipped when no
+    feature extraction algorithm is found for it, when ``skip_if_human_identifications`` is set and it has
+    identifications, or when ``require_completely_processed_session`` is set and it is not fully processed.
+    """
+
+    params = DEFAULT_TRACKING_PARAMS
+    # Override default params with job params if provided
+    if job.params:
+        params = dataclasses.replace(params, **job.params)
+
+    job.logger.info("Tracking started")
+    job.logger.info(f"Using tracking parameters: {params}")
+    collection = job.source_image_collection
+    if not collection:
+        job.logger.info("Tracking: No source image collection found. Skipping tracking.")
+        return
+    job.logger.info("Tracking: Fetching events for collection %s", collection.pk)
+    # Fetch once: the events are counted and then walked twice (stage registration, then tracking).
+    events = list(Event.objects.filter(captures__collections=collection).order_by("created_at").distinct())
+    total_events = len(events)
+    job.logger.info("Tracking: Found %d events in collection %s", total_events, collection.pk)
+
+    for event in events:
+        job.progress.add_stage(name=f"Event {event.pk}", key=f"event_{event.pk}")
+        job.save()
+
+    for idx, event in enumerate(events, start=1):
+        job.logger.info(f"Tracking: Processing event {idx} of {total_events} (Event ID: {event.pk})")
+
+        # Get the most common algorithm for the current event
+        algorithm = get_most_common_algorithm_for_event(event)
+        if algorithm is not None:
+            job.logger.info(f"Using most common feature extraction algorithm for event {event}: {algorithm.name}")
+        else:
+            job.logger.warning(
+                f"No feature extraction algorithm found for detections in event {event}. "
+                "Skipping tracking for this event."
+            )
+            continue
+
+        # Check if there are human identifications in the event
+        if (
+            params.skip_if_human_identifications
+            and Occurrence.objects.filter(event=event, identifications__isnull=False).exists()
+        ):
+            job.logger.info(f"Tracking: Skipping tracking for event {event.pk}: human identifications present.")
+            continue
+
+        # Check if the all captures in the event have processed detections with features
+        if params.require_completely_processed_session and not event_fully_processed(
+            event, logger=job.logger, algorithm=algorithm
+        ):
+            job.logger.info(
+                f"Tracking: Skipping tracking for event {event.pk}: not all detections are fully processed."
+            )
+            continue
+
+        job.logger.info(f"Tracking: Running tracking for event {event.pk}")
+
+        assign_occurrences_by_tracking_images(
+            event=event,
+            logger=job.logger,
+            params=params,
+            algorithm=algorithm,
+            job=job,
+        )
+
+    job.logger.info("Tracking: Finished tracking.")
+    job.save()
diff --git a/ui/src/pages/occurrences/occurrence-columns.tsx b/ui/src/pages/occurrences/occurrence-columns.tsx
index 09c1f5896..69227b15f 100644
--- a/ui/src/pages/occurrences/occurrence-columns.tsx
+++ b/ui/src/pages/occurrences/occurrence-columns.tsx
@@ -139,10 +139,13 @@ export const columns: (
   },
   {
     id: 'date',
-    name: translate(STRING.FIELD_LABEL_DATE_OBSERVED),
-    sortField: 'first_appearance_timestamp',
-    renderCell: (item: Occurrence) => ,
+    name: '# Detections',
+    sortField: 'detections_count',
+    renderCell: (item: Occurrence) => (
+
+    ),
   },
+
   {
     id: 'time',
     sortField: 'first_appearance_time',