169 changes: 126 additions & 43 deletions src/datumaro/plugins/ndr.py
@@ -4,15 +4,21 @@

import logging as log
from enum import Enum, auto
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import Any, Iterable, List, Optional, Sequence
from pathlib import Path

import cv2
import numpy as np
from scipy.linalg import orth
from tqdm import tqdm

from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset_base import DEFAULT_SUBSET_NAME
from datumaro.components.transformer import Transform
from datumaro.util import parse_str_enum_value
from datumaro.util.image import load_image


class Algorithm(Enum):
@@ -100,6 +106,8 @@ def build_cmdline_parser(cls, **kwargs):
"than result length (default: %(default)s)",
)
parser.add_argument("-s", "--seed", type=int, help="Random seed")
parser.add_argument("-S", "--save_media", action="store_true", help="Save core set images")
parser.add_argument("-o", "--output_dir", type=str, help="Directory to save images")
return parser
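For context, a small sketch of how the two new flags could be exercised through the parser assembled above; it assumes build_cmdline_parser can be called without extra arguments here, so treat it as illustrative rather than part of the change.

```python
# Illustrative only: parse the new options with the plugin's own parser.
from datumaro.plugins.ndr import NDR

parser = NDR.build_cmdline_parser()
args = parser.parse_args(["--save_media", "--output_dir", "ndr_output"])
assert args.save_media is True
assert args.output_dir == "ndr_output"
```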

def __init__(
@@ -112,6 +120,8 @@ def __init__(
over_sample=None,
under_sample=None,
seed=None,
save_media=False,
output_dir=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -141,6 +151,12 @@ def __init__(
if uniform, sample data with uniform distribution
if inverse, sample data with reciprocal of the number
of data which have same hash key
save_media: bool
Flag to indicate if media should be saved.
If True, the media files will be saved in the specified output directory.
output_dir: str, optional
Directory to save the media files.
If not provided, defaults to './output'. The directory is created if it doesn't exist.

Algorithm Specific for gradient
block_shape: tuple, (h, w)
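For context, a minimal usage sketch of the two new parameters through the dataset API; the dataset path, format, and subset name are placeholders, and it assumes the usual Dataset.transform call that forwards keyword arguments to the plugin constructor.

```python
from datumaro.components.dataset import Dataset

# Placeholder path and format; adjust to the actual data.
dataset = Dataset.import_from("path/to/dataset", "coco")

# "ndr" is the plugin defined in this module; keyword arguments are forwarded
# to NDR.__init__, so the new options can be passed directly.
dataset = dataset.transform(
    "ndr",
    working_subset="train",
    save_media=True,         # write the selected frame ID lists
    output_dir="./output",   # created if missing; defaults to ./output
)
```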
Expand Down Expand Up @@ -186,72 +202,137 @@ def __init__(
unknown_member_error="Unknown undersampling method '{value}'.",
)

self._sample_keys = []
self._embeddings = []
self._deduplicated_item_ids = None

if seed:
self.seed = seed
else:
self.seed = None
self.working_subset = working_subset
self.duplicated_subset = duplicated_subset
self.algorithm = algorithm

self.num_cut = num_cut
self.over_sample = over_sample
self.under_sample = under_sample

self.algorithm_specific = kwargs
self.kept_item_id = None
self._initialized = False
self.save_media= save_media
Member:
Space before = is missing

Author:
Hey, I'll resolve it. Can you tell me if there are any other changes needed? I can compile the changes and push them all at once. Thank you.

Member:
From my side, that's it. Perhaps @djdameln has something to add.
Also, it would be extremely convenient for us if you could provide a small video dataset to check that the plugin is not broken after the refactoring.


def _remove(self):
if self.seed:
np.random.seed(self.seed)
if save_media:
self.output_dir = Path(output_dir) if output_dir else Path('./output')
self.output_dir.mkdir(parents=True, exist_ok=True)

working_subset = self._extractor.get_subset(self.working_subset)
having_image = []
all_imgs = []
for item in working_subset:
if item.media.has_data:
having_image.append(item)
img = item.media.data
# Not handle empty image, as utils/image.py if check empty
if len(img.shape) == 2:
img = np.stack((img,) * 3, axis=-1)
elif len(img.shape) == 3:
if img.shape[2] == 1:
img = np.stack((img[:, :, 0],) * 3, axis=-1)
elif img.shape[2] == 4:
img = img[..., :3]
elif img.shape[2] == 3:
pass
else:
raise ValueError(
"Item %s: invalid image shape: "
"unexpected number of channels (%s)" % (item.id, img.shape[2])
)
else:
raise ValueError(
"Item %s: invalid image shape: "
"unexpected number of dimensions (%s)" % (item.id, len(img.shape))
)

if self.algorithm == Algorithm.gradient:
# Calculate gradient
img = self._cgrad_feature(img)
else:
raise NotImplementedError()
all_imgs.append(img)
def get_deduplicated_item_ids(self) -> Sequence[str]:
"""Returns the list of deduplicated items, before resolving under-/oversampling conditions"""
if not self._initialized:
raise Exception("The index is not initialized yet.")
return sorted(self._deduplicated_item_ids)

def save_deduplicated_item_ids(self):
"""Saves list of deduplicated frame IDs (before sampling) as deduplicated.list"""
with (self.output_dir / "deduplicated.list").open('w') as f:
for sample_id in self.get_deduplicated_item_ids():
print(sample_id, file=f)

def get_core_set_item_ids(self):
"""Returns the list of core set frame ids after deduplication and sampling/cutting """

if not self._initialized:
raise Exception("The index is not initialized yet.")
return sorted(self.kept_item_id)

def save_core_set_item_ids(self):
""" Saves list of final selected frame IDs (after sampling) as core_set_frames.list """
with (self.output_dir / "core_set_frames.list").open('w') as f:
for sample_id in self.get_core_set_item_ids():
print(sample_id, file=f)
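A short sketch of the intended call sequence for these helpers; `extractor` is a placeholder, and it assumes the removal runs lazily on the first iteration of the transform (as in the existing NDR plugin) so that the index is initialized before the getters are called.

```python
# Assumes `extractor` is an existing Datumaro dataset/extractor; the subset name is a placeholder.
ndr = NDR(extractor, working_subset="train", save_media=True, output_dir="./ndr_output")

# The removal runs lazily, so iterate (or export) the transform once to build the index.
for _ in ndr:
    pass

kept_before_sampling = ndr.get_deduplicated_item_ids()  # IDs kept before under-/oversampling
core_set = ndr.get_core_set_item_ids()                  # IDs kept after sampling/cutting

# core_set_frames.list is already written by the transform when save_media=True;
# the deduplicated list can be saved explicitly as well.
ndr.save_deduplicated_item_ids()                        # writes <output_dir>/deduplicated.list
```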

def append_state(self, values: Iterable[Any]):
"""Append precomputed state values to internal storage"""
for sample, embedding in values:
self._sample_keys.append(sample)
self._embeddings.append(embedding)

def compute_state(self, item, img):
"""Compute embedding state for a given image"""
if isinstance(img, str):
img = load_image(img)

# Empty images are not handled here; utils/image.py already checks for them
if len(img.shape) == 2:
img = np.stack((img,) * 3, axis=-1)
elif len(img.shape) == 3:
if img.shape[2] == 1:
img = np.stack((img[:, :, 0],) * 3, axis=-1)
elif img.shape[2] == 4:
img = img[..., :3]
elif img.shape[2] == 3:
pass
else:
log.debug("Item %s: unexpected number of channels (%s)", item.id, img.shape[2])
raise ValueError(
f"Item {item.id}: invalid image shape: "
f"unexpected number of channels ({img.shape[2]})")
else:
raise ValueError(
f"Item {item.id}: invalid image shape: "
f"unexpected number of dimensions ({len(img.shape)})")

if self.num_cut and self.num_cut > len(all_imgs):
if self.algorithm == Algorithm.gradient:
embedding = self._cgrad_feature(img)
else:
raise NotImplementedError()

return item.id, embedding
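A brief sketch of how precomputed embeddings could be injected with these two helpers so that the _remove() step below reuses them instead of recomputing; `ndr` and `frames` are placeholders.

```python
# Assumes `ndr` is an NDR instance and `frames` is an iterable of
# (dataset item, image path or ndarray) pairs prepared elsewhere.
precomputed = (ndr.compute_state(item, image) for item, image in frames)
ndr.append_state(precomputed)
# With the keys/embeddings cached, _remove() skips reading and embedding
# the frames again.
```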

def _remove(self):
# Uses cached states and threading
if not self._sample_keys and not self._embeddings:
if self.working_subset == DEFAULT_SUBSET_NAME:
working_subset_length = float("inf")
working_subset = self._extractor
else:
working_subset = self._extractor.get_subset(self.working_subset)
working_subset_length = len(working_subset)

""" Process with thread pool
# 1 thread reads frames (only sequentially for a video)
# 2 thread computes frame embeddings """

with ThreadPoolExecutor(2) as pool:
queue = Queue()
working_subset_iter = iter(tqdm(
((item, item.media.data) for item in working_subset),
total=working_subset_length
))

next_sample, next_sample_media = next(working_subset_iter, (None, None))
while not queue.empty() or next_sample is not None:
if not queue.full() and next_sample is not None:
queue.put(pool.submit(self.compute_state, next_sample, next_sample_media))
next_sample, next_sample_media = next(working_subset_iter, (None, None))

# Each queued entry is a Future; wait for its (item id, embedding) result
processed_sample, embedding = queue.get().result()
self._sample_keys.append(processed_sample)
self._embeddings.append(embedding)

if self.num_cut and self.num_cut > len(self._embeddings):
raise ValueError("The number of images is smaller than the cut you want")

if self.seed:
np.random.seed(self.seed)

if self.algorithm == Algorithm.gradient:
all_key, fidx, kept_index, key_counter, removed_index_with_sim = self._gradient_based(
all_imgs, **self.algorithm_specific
self._embeddings, **self.algorithm_specific
)
else:
raise NotImplementedError()

self._deduplicated_item_ids = set(self._sample_keys[ii] for ii in kept_index)

kept_index = self._keep_cut(
self.num_cut,
all_key,
@@ -262,7 +343,9 @@ def _remove(self):
self.over_sample,
self.under_sample,
)
self.kept_item_id = set(having_image[ii].id for ii in kept_index)
self.kept_item_id = set(self._sample_keys[ii] for ii in kept_index)
if self.save_media:
self.save_core_set_item_ids()

def _gradient_based(self, all_imgs, block_shape=(4, 4), hash_dim=32, sim_threshold=0.5):
if len(block_shape) != 2: