Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
bd25319
feat: implement detection tracking and integrate it into save_results
mohamedelabbas1996 May 27, 2025
e06f79f
used the latest classification features vector instead of the detecti…
mohamedelabbas1996 May 28, 2025
853607d
feat: added tracking stage, per-session logic, optimal detection pairing
mohamedelabbas1996 May 29, 2025
50ce8fc
test: added testing for tracking
mohamedelabbas1996 Jun 2, 2025
0ebbea5
skip sessions with human identifications
mohamedelabbas1996 Jun 3, 2025
3f4acac
added missing migration
mohamedelabbas1996 Jun 9, 2025
a4eeede
restored migrations
mohamedelabbas1996 Jun 9, 2025
8801f89
fixed migration issues
mohamedelabbas1996 Jun 10, 2025
d6d481a
fix: pin minio containers to known working versions
mihow Jun 18, 2025
8991b44
fix: pin minio container versions in CI stack
mihow Jun 18, 2025
ff3d1bb
moved tracking to a separate job
mohamedelabbas1996 Jun 19, 2025
1ea816d
Merge branch 'feat/restore-tracking' of https://github.com/RolnickLab…
mohamedelabbas1996 Jun 19, 2025
c688e68
removed call to tracking from ml job
mohamedelabbas1996 Jun 20, 2025
db5f104
passed tracking cost threshold as a job param
mohamedelabbas1996 Jun 20, 2025
d3a7b8c
fix: assigned occurrence project and deployment
mohamedelabbas1996 Jun 20, 2025
000c247
changed the observed date field in the occurrence list view to show t…
mohamedelabbas1996 Jun 20, 2025
be69807
fixed frontend code formatting
mohamedelabbas1996 Jun 20, 2025
1526cfb
improved logging and job progress tracking
mohamedelabbas1996 Jun 20, 2025
b82f175
fix: use features from the same algo when comparing detections
mihow Jun 25, 2025
9d0fd96
feat: update some type annotations & logging, resolve warnings
mihow Jun 25, 2025
87c91c1
feat: validate tracking parameters, add more parameters
mihow Jun 25, 2025
a194ffd
fix: only assign new occurrences to tracks with >1 detection
mihow Jun 25, 2025
4c4e4c9
feat: log number of occurrences reduced, and other things.
mihow Jun 25, 2025
b7f28ea
feat: skip chains that don't need new occurrences (len 1 or all same)
mihow Jun 25, 2025
30cc279
feat: don't require fully processed sessions for now
mihow Jul 2, 2025
ed35121
Merge branch 'deployments/ood.antenna.insectai.org' into feat/restore…
mihow Aug 21, 2025
785a8f2
Merge branch 'deployments/ood.antenna.insectai.org' into feat/restore…
mihow Aug 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions ami/jobs/migrations/0019_alter_job_job_type_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Generated by Django 4.2.10 on 2025-06-19 08:33

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("jobs", "0018_alter_job_job_type_key"),
]

operations = [
migrations.AlterField(
model_name="job",
name="job_type_key",
field=models.CharField(
choices=[
("ml", "ML pipeline"),
("populate_captures_collection", "Populate captures collection"),
("data_storage_sync", "Data storage sync"),
("unknown", "Unknown"),
("data_export", "Data Export"),
("detection_clustering", "Detection Feature Clustering"),
("tracking", "Occurrence Tracking"),
],
default="unknown",
max_length=255,
verbose_name="Job Type",
),
),
]
21 changes: 21 additions & 0 deletions ami/jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ami.jobs.tasks import run_job
from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection
from ami.ml.models import Pipeline
from ami.ml.tracking import perform_tracking
from ami.utils.schemas import OrderedEnum

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -673,6 +674,25 @@ def run(cls, job: "Job"):
job.save()


class TrackingJob(JobType):
name = "Occurrence Tracking"
key = "tracking"

@classmethod
def run(cls, job: "Job"):
job.logger.info("Starting tracking job")
job.update_status(JobState.STARTED)
job.started_at = datetime.datetime.now()
job.finished_at = None

perform_tracking(job)

job.update_status(JobState.SUCCESS)
job.logger.info("Tracking job finished successfully.")
job.finished_at = datetime.datetime.now()
job.save()


class UnknownJobType(JobType):
name = "Unknown"
key = "unknown"
Expand All @@ -689,6 +709,7 @@ def run(cls, job: "Job"):
UnknownJobType,
DataExportJob,
DetectionClusteringJob,
TrackingJob,
]


Expand Down
38 changes: 37 additions & 1 deletion ami/main/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,43 @@ def cluster_detections(self, request: HttpRequest, queryset: QuerySet[SourceImag

self.message_user(request, f"Clustered {queryset.count()} collection(s).")

actions = [populate_collection, populate_collection_async, cluster_detections, create_clustering_job]
@admin.action(description="Create tracking job (but don't run it)")
def create_tracking_job(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
from ami.jobs.models import Job, TrackingJob
from ami.ml.tracking import DEFAULT_TRACKING_PARAMS

for collection in queryset:
job = Job.objects.create(
name=f"Tracking for collection {collection.pk}",
project=collection.project,
source_image_collection=collection,
job_type_key=TrackingJob.key,
params=DEFAULT_TRACKING_PARAMS.__dict__,
)
self.message_user(request, f"Tracking job #{job.pk} created for collection #{collection.pk}")

@admin.action(description="Run tracking job")
def run_tracking_job(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
from ami.jobs.models import Job, TrackingJob

for collection in queryset:
job = Job.objects.create(
name=f"Tracking for collection {collection.pk}",
project=collection.project,
source_image_collection=collection,
job_type_key=TrackingJob.key,
)
job.enqueue()
self.message_user(request, f"Tracking job #{job.pk} started for collection #{collection.pk}")

actions = [
populate_collection,
populate_collection_async,
cluster_detections,
create_clustering_job,
create_tracking_job,
run_tracking_job,
]

# Hide images many-to-many field from form. This would list all source images in the database.
exclude = ("images",)
Expand Down
25 changes: 25 additions & 0 deletions ami/main/migrations/0069_detection_next_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Generated by Django 4.2.10 on 2025-06-10 10:19

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):
dependencies = [
("main", "0068_allow_taxa_without_project"),
]

operations = [
migrations.AddField(
model_name="detection",
name="next_detection",
field=models.OneToOneField(
blank=True,
help_text="The detection that follows this one in the tracking sequence.",
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="previous_detection",
to="main.detection",
),
),
]
10 changes: 10 additions & 0 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2344,6 +2344,14 @@ class Detection(BaseModel):
classifications: models.QuerySet["Classification"]
source_image_id: int
detection_algorithm_id: int
next_detection = models.OneToOneField(
"self",
on_delete=models.SET_NULL,
null=True,
blank=True,
related_name="previous_detection",
help_text="The detection that follows this one in the tracking sequence.",
)

def get_bbox(self):
if self.bbox:
Expand Down Expand Up @@ -2388,6 +2396,8 @@ def height(self) -> int | None:
if self.bbox and len(self.bbox) == 4:
return self.bbox[3] - self.bbox[1]

occurrence_id: int | None = None

class Meta:
ordering = [
"frame_num",
Expand Down
98 changes: 98 additions & 0 deletions ami/main/tests/test_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import logging
from collections import defaultdict

import numpy as np
from django.test import TestCase
from django.utils import timezone

from ami.main.models import Classification, Detection, Occurrence, Project
from ami.ml.models import Algorithm
from ami.ml.tracking import assign_occurrences_by_tracking_images

logger = logging.getLogger(__name__)


class TestTracking(TestCase):
def setUp(self) -> None:
self.project = Project.objects.first()
self.event = self.project.events.first()
self.source_images = list(self.event.captures.order_by("timestamp"))
self.assign_mock_features_to_occurrence_detections(self.event)
# Save ground truth occurrence groupings
self.ground_truth_groups = defaultdict(set)
for occ in Occurrence.objects.filter(event=self.event):
det_ids = Detection.objects.filter(occurrence=occ).values_list("id", flat=True)
for det_id in det_ids:
self.ground_truth_groups[occ.pk].add(det_id)

# Clear existing tracking data (next_detection + occurrence)
Detection.objects.filter(source_image__event=self.event).update(next_detection=None)

def assign_mock_features_to_occurrence_detections(self, event, algorithm_name="MockTrackingAlgorithm"):
algorithm, _ = Algorithm.objects.get_or_create(name=algorithm_name)

for occurrence in event.occurrences.all():
base_vector = np.random.rand(2048) # Base feature for this occurrence group

for det in occurrence.detections.all():
feature_vector = base_vector + np.random.normal(0, 0.001, size=2048) # Add slight variation
Classification.objects.update_or_create(
detection=det,
algorithm=algorithm,
defaults={
"timestamp": timezone.now(),
"features_2048": feature_vector.tolist(),
"terminal": True,
"score": 1.0,
},
)

def test_tracking_exactly_reproduces_occurrences(self):
# Clear previous detection chains and occurrences
for det in Detection.objects.filter(source_image__event=self.event):
det.occurrence = None
det.next_detection = None
det.save()

Occurrence.objects.filter(event=self.event).delete()

# Run the tracking algorithm to regenerate occurrences
assign_occurrences_by_tracking_images(self.event, logger)

# Capture new tracking-generated occurrence groups
new_groups = {
occ.pk: set(Detection.objects.filter(occurrence=occ).values_list("id", flat=True))
for occ in Occurrence.objects.filter(event=self.event)
}

# Assert that the number of new groups equals the number of ground truth groups
self.assertEqual(
len(new_groups),
len(self.ground_truth_groups),
f"Expected {len(self.ground_truth_groups)} groups, but got {len(new_groups)}",
)

# Assert each new group exactly matches one of the original ground truth groups
unmatched_groups = [
new_set for new_set in new_groups.values() if new_set not in self.ground_truth_groups.values()
]

self.assertEqual(
len(unmatched_groups),
0,
f"{len(unmatched_groups)} of the new groups do not exactly match any ground truth group",
)
logger.info(
f"All {len(new_groups)} new groups match the ground truth groups exactly.",
)
logger.info(f"new groups: {new_groups}")
# Assert that each ground truth group is present in the new tracking results
for gt_set in self.ground_truth_groups.values():
logger.info(
f"Checking ground truth group: {gt_set}",
)
self.assertIn(
gt_set,
new_groups.values(),
f"Ground truth group {gt_set} not found in new tracking results",
)
Loading
Loading