Skip to content

Conversation

@bzamanlooy
Copy link
Collaborator

PR Type

Feauture:
TF Attack

Short Description

Adding TF attack to the midst toolkit. Refactored the code, made sure it works with midst toolkit, wrote an integration test.

#Next step:

  • fix typing errors
  • unit tests

@coderabbitai
Copy link

coderabbitai bot commented Dec 3, 2025

📝 Walkthrough

Walkthrough

This PR introduces a comprehensive membership inference attack (MIA) framework for tabular differentially private models. The implementation includes three new Python modules providing data loading utilities (data_utils.py), diffusion-based attack orchestration (tf_attack.py), and classifier training infrastructure (classification.py). Supporting infrastructure consists of integration and unit tests, along with JSON configuration and metadata files for test assets. The attack workflow combines dataset preparation, diffusion score computation, MLP classifier training, and performance evaluation across multiple model instances and timesteps.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • src/midst_toolkit/attacks/tf/tf_attack.py: Core attack orchestration with mixed_loss computation for diffusion models, dataset preparation, classifier training, and end-to-end attack pipeline; contains multiple interconnected functions with conditional logic and external dependencies.
  • src/midst_toolkit/attacks/tf/data_utils.py: Substantial data utilities module with 8+ functions spanning ROC plotting, multi-table dataset loading, deduplication logic, attack data preparation, and performance evaluation; complex validation and error handling patterns.
  • src/midst_toolkit/attacks/tf/classification.py: PyTorch-based neural network training module with custom loss functions, model checkpointing, and iterative training loops; moderate complexity but requires verification of optimization and loss computation correctness.
  • tests/integration/attacks/tf/test_tf_attack.py: Integration test with hardcoded performance assertions; verify expected ROC-AUC and TPR@FPR thresholds are achievable and representative.
  • JSON configuration and metadata files (trans_domain.json, args, updated_config.json variants): Homogeneous, low-complexity additions primarily for test infrastructure.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 40.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Title check ❓ Inconclusive The PR title 'Bz/tf' is a branch name reference that does not meaningfully describe the changeset. It does not convey what the PR does or what the main change is. Revise the title to clearly describe the main change, e.g., 'Add TensorFlow membership attack module to midst toolkit' or 'Implement TF attack for membership inference attacks'.
✅ Passed checks (1 passed)
Check name Status Explanation
Description check ✅ Passed The PR description covers PR Type and a brief description of the changes, but lacks detail on test coverage and is incomplete per the repository template.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch bz/tf

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 17

Note

Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.

🟡 Minor comments (4)
src/midst_toolkit/attacks/tf/data_utils.py-157-165 (1)

157-165: Deduplication key validation happens after drop_duplicates is called.

The validation for missing keys (lines 162-165) occurs after drop_duplicates is already called (lines 158-159). If keys are missing, drop_duplicates will raise a KeyError before your descriptive ValueError is reached.

+    # Ensure all keys for deduplication exist in both DataFrames before deduplication
+    missing_keys_merge = [key for key in keys_for_deduplication if key not in df_merge.columns]
+    missing_keys_challenge = [key for key in keys_for_deduplication if key not in df_challenge.columns]
+    if missing_keys_merge or missing_keys_challenge:
+        raise ValueError(f"Missing columns for deduplication: {missing_keys_merge + missing_keys_challenge}")
+
     # Deduplicate the datasets once
     df_merge = df_merge.drop_duplicates(subset=keys_for_deduplication)
     df_challenge = df_challenge.drop_duplicates(subset=keys_for_deduplication)
-
-    # Ensure all keys for deduplication exist in both DataFrames
-    missing_keys_merge = [key for key in keys_for_deduplication if key not in df_merge.columns]
-    missing_keys_challenge = [key for key in keys_for_deduplication if key not in df_challenge.columns]
-    if missing_keys_merge or missing_keys_challenge:
-        raise ValueError(f"Missing columns for deduplication: {missing_keys_merge + missing_keys_challenge}")
src/midst_toolkit/attacks/tf/data_utils.py-174-179 (1)

174-179: Potential ValueError on edge case.

If all FPR values are >= max_fpr, tpr[fpr < max_fpr] will be an empty array and max() will raise a ValueError.

 def get_tpr_at_fpr(true_membership: list[int], predictions: list[float], max_fpr: float = 0.1) -> float:
     """
     Calculates the best True Positive Rate when the False Positive Rate is at most `max_fpr`.
     """
     fpr, tpr, _ = roc_curve(true_membership, predictions)
-    return max(tpr[fpr < max_fpr])
+    valid_tpr = tpr[fpr <= max_fpr]
+    if len(valid_tpr) == 0:
+        return 0.0
+    return float(max(valid_tpr))
src/midst_toolkit/attacks/tf/classifcation.py-133-134 (1)

133-134: Condition x_val is not None is always true after tensor conversion.

On line 116, x_val is unconditionally converted to a tensor via torch.tensor(x_val, ...). If the original x_val parameter was None, this would raise an error before reaching line 133. The check should happen before tensor conversion.

+    has_validation = x_val is not None
     x_train = torch.tensor(x_train, dtype=torch.float32).to(device)
     y_train = torch.tensor(x_train_label, dtype=torch.float32).to(device)
-    x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
-    y_test = torch.tensor(x_val_label, dtype=torch.float32).to(device)
+    if has_validation:
+        x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
+        y_val = torch.tensor(x_val_label, dtype=torch.float32).to(device)

Then use has_validation in the condition on line 133.

Committable suggestion skipped: line range outside the PR's diff.

tests/integration/attacks/tf/test_tf_attack.py-51-53 (1)

51-53: Use explicit key access instead of dict unpacking for clarity.

The dictionaries returned have keys "max_tpr" and "roc_auc". While Python 3.7+ guarantees dictionary insertion order, explicit key access makes the code clearer and more maintainable.

-    tpr_at_fpr_train, roc_auc_train = mia_performance_train.values()
-    tpr_at_fpr_val, roc_auc_val = mia_performance_val.values()
-    tpr_at_fpr_test, roc_auc_test = mia_performance_test.values()
+    tpr_at_fpr_train = mia_performance_train["max_tpr"]
+    roc_auc_train = mia_performance_train["roc_auc"]
+    tpr_at_fpr_val = mia_performance_val["max_tpr"]
+    roc_auc_val = mia_performance_val["roc_auc"]
+    tpr_at_fpr_test = mia_performance_test["max_tpr"]
+    roc_auc_test = mia_performance_test["roc_auc"]
🧹 Nitpick comments (15)
src/midst_toolkit/attacks/tf/data_utils.py (2)

1-6: Excessive linting suppressions reduce code quality.

Suppressing D102, D105, D103, D200 (docstring rules) and multiple mypy error codes across the entire file is a significant code smell. As per the PR description, "fix typing errors" is listed as a next step — consider addressing these rather than suppressing them.


59-59: Unused verbose parameter.

The verbose parameter is declared but never used in the function body. Either implement verbose logging or remove the parameter.

src/midst_toolkit/attacks/tf/classifcation.py (2)

59-62: Remove dead code: x = x.float() is unused.

The variable x is reassigned on line 60 but never used afterward. This appears to be leftover code.

 def custom_loss_fn(model, x, y):
     confidences = model(x)
-    x = x.float()
     y = y.float()
     return nn.BCELoss()(confidences, y.unsqueeze(1))

114-117: Inconsistent variable naming: y_test vs x_val.

Line 117 uses y_test for labels corresponding to x_val, mixing "test" and "val" terminology. This is confusing given the function parameters use x_val and x_val_label.

     x_train = torch.tensor(x_train, dtype=torch.float32).to(device)
     y_train = torch.tensor(x_train_label, dtype=torch.float32).to(device)
     x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
-    y_test = torch.tensor(x_val_label, dtype=torch.float32).to(device)
+    y_val = torch.tensor(x_val_label, dtype=torch.float32).to(device)

Then update references on lines 134, 150 accordingly.

src/midst_toolkit/attacks/tf/tf_attack.py (11)

1-6: Clean up unused lint suppressions.

Static analysis indicates the noqa directive on line 1 is unused. Since the PR author already plans to "fix typing errors", consider removing unnecessary suppressions once proper type annotations are added rather than blanket-disabling mypy checks.


37-82: Several issues in mixed_loss function.

  1. Unused parameter: no_mean is never used (Line 44).
  2. Redundant device assignment: device is assigned on line 51 but immediately overwritten on line 61.
  3. Typo: "defeualt" → "default" (Line 63).
  4. Redundant conditional: The check if not return_random: on line 74 is always True since we already returned on line 67 when return_random=True.
 def mixed_loss(
     diffusion,
     x,
     out_dict,
     noise=None,
     t=None,
     return_random=False,
-    no_mean=False,
     parallel_batch=None,
     addt_value=None,
 ):
     x_num = x[:, : diffusion.num_numerical_features]
     x_cat = x[:, diffusion.num_numerical_features :]

-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = x.device
     noise_tensor = torch.tensor(noise, device=device, dtype=torch.float)
     batch_noise = noise_tensor.repeat(x_num.shape[0], 1)

     # there is actually no categorical classes, as we have examined the DM, so we just ignore x_cat here and later
     x_num = x_num.repeat_interleave(parallel_batch, dim=0)
     x_cat = x_cat.repeat_interleave(parallel_batch, dim=0)

     b = x_num.shape[0]

-    device = x.device
     if t is None:
-        # the defeualt is uniform sampling
+        # the default is uniform sampling
         t, pt = diffusion.sample_time(b, device)

     if return_random:
         return noise, t, pt

     additional_t = t * 0 + addt_value

     # forward x_num_t with (t+additional_t) timestamps
     x_num_t = diffusion.gaussian_q_sample(x_num, t + additional_t, noise=batch_noise)

-    if not return_random:
-        current_t = t
-        # predict noises with t timestamps
-        predicted_noise = diffusion._denoise_fn(x_num_t, current_t, **out_dict)
-        current_loss = diffusion._gaussian_loss(predicted_noise, batch_noise, batch_noise, current_t, batch_noise)
-        transformed_current_loss = current_loss.reshape(-1, parallel_batch)
+    current_t = t
+    # predict noises with t timestamps
+    predicted_noise = diffusion._denoise_fn(x_num_t, current_t, **out_dict)
+    current_loss = diffusion._gaussian_loss(predicted_noise, batch_noise, batch_noise, current_t, batch_noise)
+    transformed_current_loss = current_loss.reshape(-1, parallel_batch)

     return transformed_current_loss * 0, transformed_current_loss

136-142: Prefix unused variables with underscore.

label_encoders and column_orders are unpacked but never used. Per convention, prefix with underscore to indicate intentional non-use.

-        dataset, label_encoders, column_orders = Dataset.from_df(
+        dataset, _label_encoders, _column_orders = Dataset.from_df(

207-229: Misplaced docstring.

The docstring is placed in the middle of the function after executable code (lines 201-205). It should be the first statement after the function signature to be recognized by documentation tools and IDEs.

Move the docstring to immediately after line 198 (def get_score(...):) before any executable statements.


247-252: Remove debug print and simplify loop logic.

  1. Debug print: print(iter_max) should be removed or replaced with proper logging.
  2. Confusing loop: The assert iter_max == 1 followed by while iter_id < iter_max: means the loop always runs exactly once. Consider simplifying to remove the loop entirely or documenting why this structure exists for future extensibility.
-        print(iter_max)
         iter_max = iter_max // batch_size
         return_res = torch.zeros([batch_size, parallel_batch])
         assert iter_max == 1
-        iter_id = 0
-        while iter_id < iter_max:
+        # Process single batch (currently only supports iter_max == 1)
+        x, out_dict = next(train_loader)
         ...
-            iter_id += 1

236-236: Prefix unused variable with underscore.

challenge_dataset is unpacked but never used.

-        train_loader, iter_max, challenge_dataset = train_loader_list[loader_count]
+        train_loader, iter_max, _challenge_dataset = train_loader_list[loader_count]

261-269: Prefix unused variables with underscore.

noise and pt are unpacked but never used in this context.

-                noise, t_cur, pt = mixed_loss(
+                _noise, t_cur, _pt = mixed_loss(

309-316: Hardcoded deduplication keys reduce reusability.

The keys ["trans_id", "balance"] are hardcoded in both prepare_data_for_attack calls. Consider making these configurable via a parameter.


479-479: Document magic number for noise dimension.

The noise dimension size=8 is hardcoded without explanation. Consider extracting this to a named constant or parameter with documentation explaining why 8 is the appropriate value.

+    NOISE_DIMENSION = 8  # Must match the diffusion model's expected noise dimension
-    input_noise: list[list[float]] = [np.random.normal(size=8).tolist() for _ in range(num_noise_per_time_step)]
+    input_noise: list[list[float]] = [np.random.normal(size=NOISE_DIMENSION).tolist() for _ in range(num_noise_per_time_step)]

187-198: Unused phase parameter.

The phase parameter is accepted but never used in the function body. Either remove it or implement the intended behavior.

 def get_score(
     data_path,
     save_dir,
     input_noise,
     type="tabddpm",
-    phase=None,
     challenge_name=None,
     batch_size=None,
     parallel_batch=None,
     addt_value=None,
     t_value=None,
 ):

If phase is intended for future use, add a TODO comment or raise NotImplementedError when a non-None value is passed.


191-191: Avoid shadowing built-in type.

Using type as a parameter name shadows Python's built-in type() function, which can cause subtle bugs if the built-in is needed within this function.

 def get_score(
     data_path,
     save_dir,
     input_noise,
-    type="tabddpm",
+    model_type="tabddpm",
     ...
 ):
-    if type == "tabddpm":
+    if model_type == "tabddpm":
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6eb18e9 and 62a6fec.

⛔ Files ignored due to path filters (49)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/data_for_training_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/data_for_training_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/data_for_validating_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/data_for_validating_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_synthetic.csv is excluded by !**/*.csv
📒 Files selected for processing (22)
  • src/midst_toolkit/attacks/tf/classifcation.py (1 hunks)
  • src/midst_toolkit/attacks/tf/data_utils.py (1 hunks)
  • src/midst_toolkit/attacks/tf/tf_attack.py (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/data_configs/dataset_meta.json (1 hunks)
  • tests/integration/attacks/tf/data_configs/trans.json (1 hunks)
  • tests/integration/attacks/tf/data_configs/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/test_tf_attack.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (13)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (2)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
  • test_save_additional_tabddpm_config (19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (2)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
src/midst_toolkit/models/clavaddpm/model.py (1)
  • DiffusionParameters (19-31)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/updated_config.json (3)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
  • test_save_additional_tabddpm_config (19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/workspace/train_1/args (3)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
  • test_save_additional_tabddpm_config (19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (2)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
  • test_save_additional_tabddpm_config (19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/data_configs/trans_domain.json (1)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
tests/integration/attacks/tf/data_configs/trans.json (2)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
src/midst_toolkit/attacks/ensemble/blending.py (1)
  • __init__ (27-66)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/updated_config.json (3)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
  • test_save_additional_tabddpm_config (19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/updated_config.json (3)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
  • test_save_additional_tabddpm_config (19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/test_tf_attack.py (2)
src/midst_toolkit/attacks/tf/tf_attack.py (1)
  • tf_attack (459-559)
src/midst_toolkit/common/random.py (2)
  • set_all_random_seeds (11-55)
  • unset_all_random_seeds (58-67)
src/midst_toolkit/attacks/tf/data_utils.py (3)
src/midst_toolkit/models/clavaddpm/dataset.py (1)
  • Dataset (77-397)
src/midst_toolkit/models/clavaddpm/enumerations.py (1)
  • Normalization (58-63)
src/midst_toolkit/models/clavaddpm/data_loaders.py (1)
  • FastTensorDataLoader (473-537)
🪛 Ruff (0.14.7)
src/midst_toolkit/attacks/tf/tf_attack.py

1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200, PLR0915)

Remove unused noqa directive

(RUF100)


44-44: Unused function argument: no_mean

(ARG001)


85-85: Avoid specifying long messages outside the exception class

(TRY003)


136-136: Unpacked variable label_encoders is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


136-136: Unpacked variable column_orders is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


192-192: Unused function argument: phase

(ARG001)


203-203: Avoid specifying long messages outside the exception class

(TRY003)


236-236: Unpacked variable challenge_dataset is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


261-261: Unpacked variable noise is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


261-261: Unpacked variable pt is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

src/midst_toolkit/attacks/tf/data_utils.py

1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200)

Remove unused noqa directive

(RUF100)


59-59: Unused function argument: verbose

(ARG001)


75-75: Avoid specifying long messages outside the exception class

(TRY003)


95-95: Avoid specifying long messages outside the exception class

(TRY003)


101-101: Avoid specifying long messages outside the exception class

(TRY003)


143-143: Avoid specifying long messages outside the exception class

(TRY003)


165-165: Avoid specifying long messages outside the exception class

(TRY003)


190-190: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: run-code-check
  • GitHub Check: unit-tests
  • GitHub Check: integration-tests
🔇 Additional comments (16)
src/midst_toolkit/attacks/tf/data_utils.py (1)

218-261: LGTM!

The FastTensorDataLoader implementation correctly handles batching and shuffling. The pattern matches the reference implementation from src/midst_toolkit/models/clavaddpm/data_loaders.py.

tests/integration/attacks/tf/data_configs/trans.json (1)

1-50: LGTM!

This configuration uses relative paths consistently, making it portable across different environments unlike some of the other tabddpm config files.

tests/integration/attacks/tf/data_configs/dataset_meta.json (1)

1-1: ✓ Test metadata structure is appropriate.

The JSON correctly describes a single-table dataset with no relationships, which aligns with the trans domain schema and test configuration patterns.

tests/integration/attacks/tf/data_configs/trans_domain.json (1)

1-1: ✓ Domain schema is consistent and well-formed.

The feature definitions (8 fields with appropriate continuous/discrete classifications) align with the dataset_meta.json and are consistent across all tabddpm model variants used in tests.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1)

1-1: ✓ Consistent test asset.

Domain schema matches across all tabddpm variants, which is appropriate for uniform test configuration.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1)

1-1: ✓ Consistent test asset.

Matches schema across all tabddpm variants for uniform testing.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1)

1-1: ✓ Consistent test asset.

Matches schema across all tabddpm variants for uniform testing.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1)

1-1: ✓ Consistent test asset.

Matches schema across all tabddpm variants for uniform testing.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/updated_config.json (2)

1-8: Verify absolute workspace_dir path will not break tests.

Line 5 contains an absolute path /projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_2/workspace that does not exist in typical CI/test environments. This may cause test failures when the configuration is used to create directories or write outputs.

Check whether:

  • The integration test overrides this path at runtime (e.g., using save_additional_tabddpm_config)
  • The code gracefully handles missing parent directories
  • Tests are expected to run in an environment where /projects/ exists

The relative paths for data_dir and test_data_dir (lines 3, 7) are appropriate for test portability.


14-42: ✓ Hyperparameter choices are appropriate for fast integration testing.

Minimal values (iterations=2–3, batch_size=1, num_timesteps=3) prioritize test speed while maintaining realistic configuration structure. This aligns with typical test fixtures for ML frameworks.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (2)

1-8: Verify absolute workspace_dir path will not break tests.

Line 5 contains an absolute path /projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace that does not exist in typical CI/test environments. This matches the same issue in tabddpm_2/updated_config.json and suggests a systematic configuration pattern.

Verify that:

  • The integration test overrides workspace_dir paths at runtime (expected based on save_additional_tabddpm_config utility)
  • Absolute paths are intentionally used as templates and not written to during tests

The relative test data paths (lines 3, 7) are appropriate for portability.


14-42: ✓ Hyperparameter choices are appropriate for fast integration testing.

Configuration values match the appropriate minimal test settings (iterations=2–3, batch_size=1, num_timesteps=3), balancing test speed with realistic model structure.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1)

1-1: LGTM!

The domain schema defines the expected field structure with appropriate size and type metadata for the tabddpm_4 model configuration.

tests/integration/attacks/tf/test_tf_attack.py (1)

20-28: Good use of deterministic seeding for reproducible tests.

The combination of set_all_random_seeds with use_deterministic_torch_algos=True and CUBLAS_WORKSPACE_CONFIG environment variable ensures reproducible results, which is essential for testing ML code with fixed expected values.

src/midst_toolkit/attacks/tf/classifcation.py (1)

8-40: LGTM: MLP architecture is straightforward and well-documented.

The 3-layer MLP with tanh activations and sigmoid output is appropriate for binary classification in the attack context.

src/midst_toolkit/attacks/tf/tf_attack.py (1)

81-81: Clarify intent of returning zeroed tensor.

The return statement returns transformed_current_loss * 0 as the first element, which is always a zero tensor. If this is intentional (e.g., for API compatibility or specific algorithm requirements), consider adding a comment explaining why. Otherwise, this may be leftover debugging code.

)
return loss.item(), tpr

best_model_path = best_model_dir / "best_model.pt"
Copy link

@coderabbitai coderabbitai bot Dec 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Potential TypeError if best_model_dir is None.

If best_model_dir is not provided (defaults to None), this line will raise TypeError: unsupported operand type(s) for /: 'NoneType' and 'str'. Add a guard or require the parameter.

+    if best_model_dir is None:
+        best_model_dir = Path(".")  # or raise ValueError
     best_model_path = best_model_dir / "best_model.pt"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
best_model_path = best_model_dir / "best_model.pt"
if best_model_dir is None:
best_model_dir = Path(".") # or raise ValueError
best_model_path = best_model_dir / "best_model.pt"
🤖 Prompt for AI Agents
In src/midst_toolkit/attacks/tf/classifcation.py around line 110, the expression
best_model_dir / "best_model.pt" will raise a TypeError if best_model_dir is
None; add a guard to validate best_model_dir before using the / operator (either
raise a clear ValueError if None, coerce a default Path, or require the
parameter to be non-None via function signature), and ensure best_model_dir is a
pathlib.Path (or convert it with Path(best_model_dir)) before constructing
best_model_path.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is related to my comment above as well 🙂

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Fetch all review comments from emersodb on PR #104
gh pr view 104 --json comments,reviews --jq '.reviews[] | select(.author.login == "emersodb") | {body: .body, state: .state}' 2>/dev/null || echo "No reviews found"

# Also check individual review comments
gh api repos/VectorInstitute/midst-toolkit/pulls/104/comments --jq '.[] | select(.user.login == "emersodb") | {path: .path, line: .line, body: .body}' 2>/dev/null || echo "No review comments found"

Length of output: 2192


🏁 Script executed:

#!/bin/bash
# Get the full content of the classification.py file to see the logic around lines 110-120
cat src/midst_toolkit/attacks/tf/classification.py | sed -n '105,120p'

Length of output: 625


@emersodb Ah yes, I see both issues now! 🎯

Looking at lines 112-116:

if use_best_checkpoint and best_model_dir is not None:
    best_model_dir = Path(".")  # Line 113 - overwrites the user's input!
    print(f"Best model will be saved to: {best_model_dir}")

best_model_path = best_model_dir / "best_model.pt"  # Line 116 - fails if None

Two bugs here:

  1. Your point (line 113): When best_model_dir is provided, it gets unconditionally overwritten with Path("."), ignoring the user's input completely.

  2. CodeRabbit's point (line 116): When best_model_dir is None, the / operator raises a TypeError.

Suggested fix:

if use_best_checkpoint:
    if best_model_dir is None:
        raise ValueError("best_model_dir must be provided when use_best_checkpoint is True")
    best_model_dir = Path(best_model_dir)  # Convert to Path without overwriting
    print(f"Best model will be saved to: {best_model_dir}")
else:
    best_model_dir = Path(".")  # Default only when checkpointing is disabled

best_model_path = best_model_dir / "best_model.pt"

Or, if you want to allow a default:

if use_best_checkpoint:
    best_model_dir = Path(best_model_dir) if best_model_dir else Path(".")
    print(f"Best model will be saved to: {best_model_dir}")
else:
    best_model_dir = Path(".")

best_model_path = best_model_dir / "best_model.pt"

This also aligns with your suggestion to use Path objects only (from your comment on line 90).

Comment on lines +148 to +152
for t in indices:
base_path = models_base_dir / f"{model_type}_{t}"
df_merge_list.append(pd.read_csv(os.path.join(base_path, "train_with_id.csv")))
df_challenge_list.append(pd.read_csv(os.path.join(base_path, "challenge_with_id.csv")))
df_challenge_labels = pd.read_csv(os.path.join(base_path, "challenge_label.csv"))
Copy link

@coderabbitai coderabbitai bot Dec 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if challenge_label.csv files differ across tabddpm models
for dir in tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_*/; do
  if [ -f "$dir/challenge_label.csv" ]; then
    echo "=== $(basename $dir) ==="
    head -5 "$dir/challenge_label.csv"
  fi
done

Repository: VectorInstitute/midst-toolkit

Length of output: 247


🏁 Script executed:

#!/bin/bash
# Read the full prepare_data_for_attack function to see return statement and full context
cat -n src/midst_toolkit/attacks/tf/data_utils.py | sed -n '138,171p'

Repository: VectorInstitute/midst-toolkit

Length of output: 1999


🏁 Script executed:

#!/bin/bash
# Search for callers of prepare_data_for_attack to understand how returned values are used
rg "prepare_data_for_attack" --type py -A 5

Repository: VectorInstitute/midst-toolkit

Length of output: 1934


df_challenge_labels is overwritten in each iteration; only the last model's labels are returned.

The loop reads challenge_label.csv for each index but overwrites df_challenge_labels on every iteration (line 152). Challenge labels differ across models, so this causes data loss. Either append labels to a list and return concatenated data, or document if only the last model's labels are intentional. Note: current callers ignore this return value by using _, but the bug still exists.

🤖 Prompt for AI Agents
In src/midst_toolkit/attacks/tf/data_utils.py around lines 148 to 152,
df_challenge_labels is overwritten on every loop iteration so only the last
model's labels are kept; change the code to collect labels per-iteration (e.g.,
create a df_challenge_labels_list before the loop, append each pd.read_csv(...)
to that list) and after the loop concatenate them with pd.concat(...) (or
otherwise aggregate appropriately) and return the combined DataFrame;
alternatively, if returning only the last model's labels was intended, add an
explicit comment and keep behavior but ensure callers expect that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an important one to address. We're not appending the labels to this list here. Is this expected behavior? If so, we should at least have a comment that explains why we're not appending.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if challenge_label.csv files exist in the test assets and if they differ
echo "=== Checking for challenge_label.csv files in test assets ==="
fd -t f "challenge_label.csv" tests/integration/attacks/tf/assets/

echo -e "\n=== Checking file sizes and content samples ==="
for file in $(fd -t f "challenge_label.csv" tests/integration/attacks/tf/assets/); do
  if [ -f "$file" ]; then
    echo "File: $file"
    echo "Size: $(wc -l < "$file") lines"
    echo "First 3 lines:"
    head -3 "$file"
    echo "---"
  fi
done

Length of output: 1190


🏁 Script executed:

#!/bin/bash
# Search for usage of challenge_label and challenge labels in the codebase
echo "=== Searching for 'challenge_label' references in Python files ==="
rg "challenge_label" --type py -C 2

echo -e "\n=== Checking if challenge labels are documented anywhere ==="
rg "challenge.*label" --type md -i

Length of output: 5699


🏁 Script executed:

#!/bin/bash
# Check the full function implementation and see how deduplication works
echo "=== Full prepare_data_for_attack function ==="
sed -n '/^def prepare_data_for_attack/,/^def \|^class \|^$/p' src/midst_toolkit/attacks/tf/data_utils.py | head -50

Length of output: 345


🏁 Script executed:

#!/bin/bash
# Get the complete prepare_data_for_attack function
echo "=== Complete prepare_data_for_attack function ==="
awk '/^def prepare_data_for_attack/,/^def [^_]|^class / {print NR": "$0}' src/midst_toolkit/attacks/tf/data_utils.py | head -40

Length of output: 145


🏁 Script executed:

#!/bin/bash
# Check how the function is actually called in tf_attack.py
echo "=== Usage in tf_attack.py ==="
rg "prepare_data_for_attack" src/midst_toolkit/attacks/tf/tf_attack.py -B 3 -A 5

Length of output: 694


🏁 Script executed:

#!/bin:bash
# Compare first and last challenge_label files to confirm they differ
echo "=== Comparing challenge_label files ==="
echo "First model (tabddpm_1):"
head -10 tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_label.csv
echo -e "\nLast model (tabddpm_5):"
head -10 tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_label.csv
echo -e "\nFull diff check:"
diff tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_label.csv tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_label.csv | head -20

Length of output: 258


@emersodb, excellent catch! This is definitely a bug. Here's what I found:

Evidence

The challenge_label.csv files differ significantly across models:

tabddpm_1 (rows 2-4):

1
1
1

tabddpm_5 (rows 2-4):

0
1
0

Running a full diff shows dozens of differences between these files.

The Problem

At line 152, df_challenge_labels is overwritten in each iteration:

for t in indices:
    base_path = models_base_dir / f"{model_type}_{t}"
    df_merge_list.append(pd.read_csv(...))           # ✅ appended
    df_challenge_list.append(pd.read_csv(...))       # ✅ appended
    df_challenge_labels = pd.read_csv(...)           # ❌ overwritten!

This means only the last model's labels are returned, discarding all previous models' labels—even though those labels contain different data.

Current Impact

The callers in tf_attack.py currently ignore this return value with _, _, _, so the bug isn't causing runtime issues yet. However, the function's contract is broken.

Recommended Fix

Change line 152 and the code after the loop to match the pattern used for the other two DataFrames:

df_merge_list = []
df_challenge_list = []
df_challenge_labels_list = []  # Add this

for t in indices:
    base_path = models_base_dir / f"{model_type}_{t}"
    df_merge_list.append(pd.read_csv(os.path.join(base_path, "train_with_id.csv")))
    df_challenge_list.append(pd.read_csv(os.path.join(base_path, "challenge_with_id.csv")))
    df_challenge_labels_list.append(pd.read_csv(os.path.join(base_path, "challenge_label.csv")))  # Change this

df_merge = pd.concat(df_merge_list, ignore_index=True)
df_challenge = pd.concat(df_challenge_list, ignore_index=True)
df_challenge_labels = pd.concat(df_challenge_labels_list, ignore_index=True)  # Add this

Then proceed with deduplication as before. This will ensure all challenge labels from all models are properly collected and returned.

Should I open a follow-up issue for this fix, or would you like to address it in this PR?

Comment on lines 287 to 294
if dataset.numerical_feautures is not None:
x = torch.from_numpy(
np.concatenate([dataset.numerical_feautures[split], dataset.categorical_features[split]], axis=1)
).float()
else:
x = torch.from_numpy(dataset.categorical_features[split]).float()
else:
x = torch.from_numpy(dataset.numerical_features[split]).float()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Typo in attribute name causes runtime AttributeError.

Line 287 uses numerical_feautures (misspelled with 'au') while line 294 correctly uses numerical_features. This will raise an AttributeError when categorical_features is not None but numerical_features is present.

     if dataset.categorical_features is not None:
-        if dataset.numerical_feautures is not None:
+        if dataset.numerical_features is not None:
             x = torch.from_numpy(
-                np.concatenate([dataset.numerical_feautures[split], dataset.categorical_features[split]], axis=1)
+                np.concatenate([dataset.numerical_features[split], dataset.categorical_features[split]], axis=1)
             ).float()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if dataset.numerical_feautures is not None:
x = torch.from_numpy(
np.concatenate([dataset.numerical_feautures[split], dataset.categorical_features[split]], axis=1)
).float()
else:
x = torch.from_numpy(dataset.categorical_features[split]).float()
else:
x = torch.from_numpy(dataset.numerical_features[split]).float()
if dataset.numerical_features is not None:
x = torch.from_numpy(
np.concatenate([dataset.numerical_features[split], dataset.categorical_features[split]], axis=1)
).float()
else:
x = torch.from_numpy(dataset.categorical_features[split]).float()
else:
x = torch.from_numpy(dataset.numerical_features[split]).float()
🤖 Prompt for AI Agents
In src/midst_toolkit/attacks/tf/data_utils.py around lines 287 to 294, there's a
misspelled attribute `numerical_feautures` that will raise AttributeError;
change that misspelling to `numerical_features` so the code accesses the correct
attribute when concatenating numerical and categorical arrays, and ensure any
other references in this block use the correct `numerical_features` name.

"general": {
"data_dir": "tests/integration/attacks/ensemble/assets/shadow_models_data",
"exp_name": "train_1",
"workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_3/workspace",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Hardcoded absolute path will break portability.

The workspace_dir uses an absolute path that is machine-specific. This should be updated to use a relative path or a configurable value to ensure the tests work across different environments.

🤖 Prompt for AI Agents
In
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args
around line 5, the workspace_dir is a hardcoded absolute path which breaks
portability; change it to a relative path or configuration-driven value (e.g.,
use a path relative to the repository root or an environment variable like
WORKSPACE_DIR) and update any test setup to resolve the chosen
relative/configurable path at runtime so the tests work across different
machines.

"general": {
"data_dir": "tests/integration/attacks/ensemble/assets/shadow_models_data",
"exp_name": "train_1",
"workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_4/workspace",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Hardcoded absolute path will break portability.

The workspace_dir contains an absolute path /projects/midst-experiments/... that is specific to a particular machine. This will fail in CI environments and for other developers.

Consider using a relative path or a path that can be configured dynamically during test execution, similar to how data_dir uses a relative path.

🤖 Prompt for AI Agents
In
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json
around line 5, the workspace_dir is a hardcoded absolute path
(/projects/midst-experiments/...), which breaks portability; change it to a
relative path (e.g., "./workspace" or "train/tabddpm_4/workspace") or a
configurable placeholder (e.g., "${WORKSPACE_DIR}") and update the test setup to
inject/resolve that variable at runtime so CI and other developers can run the
tests without machine-specific paths.

"general": {
"data_dir": "tests/integration/attacks/ensemble/assets/shadow_models_data",
"exp_name": "train_1",
"workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_4/workspace",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Hardcoded absolute path will break test portability.

The workspace_dir contains an absolute path (/projects/midst-experiments/...) which will fail on different machines or CI environments. Consider using a relative path or a placeholder that gets resolved at runtime, similar to how data_dir uses a relative path.

-        "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_4/workspace",
+        "workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace",
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_4/workspace",
"workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace",
🤖 Prompt for AI Agents
In
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args
around line 5, the workspace_dir is a hardcoded absolute path which breaks
portability; change it to a relative path or placeholder (e.g., ./workspace or
${WORKSPACE_DIR}) and update test/bootstrap logic to resolve that placeholder at
runtime (or set the env var in test setup) so the path is computed
per-environment rather than fixed.

"general": {
"data_dir": "tests/integration/attacks/ensemble/assets/shadow_models_data",
"exp_name": "train_1",
"workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Hardcoded absolute path will break portability.

Same issue as other workspace configs — the absolute workspace_dir path should be made relative or configurable.

🤖 Prompt for AI Agents
In
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args
around line 5, the workspace_dir is a hardcoded absolute path which breaks
portability; change it to a relative path or read from a configurable source
(env var or test fixture) instead — replace the absolute path with a
project-relative path (e.g., ./workspace or ../<appropriate> path) or use a
placeholder that the test harness fills from an environment variable or config
so CI and other machines can resolve the workspace directory at runtime.

Comment on lines 9 to 11
# Add paths
sys.path.append("/h/behnzaman/")
sys.path.insert(0, "/h/behnzaman/midst-experiments/deps/TF_attack/")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Remove hardcoded developer-specific paths from sys.path.

These absolute paths are specific to a developer's home directory and will cause import failures for other developers and in CI environments. If these dependencies are required, they should be properly installed as packages or the paths should be configurable via environment variables.

-# Add paths
-sys.path.append("/h/behnzaman/")
-sys.path.insert(0, "/h/behnzaman/midst-experiments/deps/TF_attack/")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Add paths
sys.path.append("/h/behnzaman/")
sys.path.insert(0, "/h/behnzaman/midst-experiments/deps/TF_attack/")
🤖 Prompt for AI Agents
In tests/integration/attacks/tf/test_tf_attack.py around lines 9-11, remove the
two hardcoded sys.path.append/sys.path.insert calls that reference a developer's
home directory; instead ensure the dependency is available via normal packaging
(add it to requirements/dev and pip install it or use an editable install during
development) or make the path configurable (read a TF_ATTACK_PATH env var and
validate it, or compute a project-relative path using pathlib from the
repository root). Update the test to import the package normally once installed,
or if a temporary path injection is unavoidable, derive the path from an
environment variable or project root instead of using absolute user-specific
paths.

@bzamanlooy
Copy link
Collaborator Author

@coderabbitai full review

@coderabbitai
Copy link

coderabbitai bot commented Dec 4, 2025

✅ Actions performed

Full review triggered.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

♻️ Duplicate comments (7)
src/midst_toolkit/attacks/tf/data_utils.py (1)

138-171: Still only returning the last model’s challenge_label.csv

Inside prepare_data_for_attack, df_challenge_labels is reassigned on every iteration (Line 152) and never accumulated. With multiple indices, you return labels only from the last model, while df_merge and df_challenge are concatenated across all models. This is the same issue previously flagged.

If you intend to aggregate across models, collect labels in a list and concatenate:

 def prepare_data_for_attack(indices, model_type, models_base_dir, keys_for_deduplication):
@@
-    df_challenge_list = []
+    df_challenge_list = []
+    df_challenge_labels_list = []
@@
-        df_challenge_list.append(pd.read_csv(os.path.join(base_path, "challenge_with_id.csv")))
-        df_challenge_labels = pd.read_csv(os.path.join(base_path, "challenge_label.csv"))
+        df_challenge_list.append(pd.read_csv(os.path.join(base_path, "challenge_with_id.csv")))
+        df_challenge_labels_list.append(pd.read_csv(os.path.join(base_path, "challenge_label.csv")))
@@
-    df_challenge = pd.concat(df_challenge_list, ignore_index=True)
+    df_challenge = pd.concat(df_challenge_list, ignore_index=True)
+    df_challenge_labels = pd.concat(df_challenge_labels_list, ignore_index=True)
@@
-    return df_merge_without_challenge, df_challenge, df_challenge_labels
+    return df_merge_without_challenge, df_challenge, df_challenge_labels

If only the last model’s labels are truly desired, add an explicit comment and enforce that indices has length 1.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args (1)

2-8: Replace hardcoded absolute workspace_dir with a project-relative path

workspace_dir is an absolute path (Line 5), which will fail on other machines and CI. For a test asset, it should be under the repo with a relative path, similar to data_dir and other configs.

-        "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_4/workspace",
+        "workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace",
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (1)

5-5: Hardcoded absolute path breaks CI/portability.

This issue was previously flagged. The workspace_dir contains an absolute path specific to a particular machine, which will fail in CI and for other developers.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1)

5-5: Hardcoded absolute path breaks CI/portability.

This issue was previously flagged. The workspace_dir contains an absolute path specific to a particular machine, which will fail in CI and for other developers.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (1)

3-7: Avoid hardcoded absolute workspace_dir in test config.

"workspace_dir" is an absolute, developer/cluster-specific path and will not exist on other machines/CI. Use a project-relative path (like the other fields) or a placeholder that your test harness fills in.

-        "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace",
+        "workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace",
src/midst_toolkit/attacks/tf/tf_attack.py (2)

383-435: Use addt_value_list when building training/validation features; current loop ignores it.

Here you still hardcode for addt_value in [0]:, so addt_value_list is effectively unused. At the same time, the feature matrix width is computed as len(input_noise) * len(timesteps_list) * len(addt_value_list), meaning any addt_value_list with length > 1 leaves trailing zero columns, and the classifier input_dim is larger than the actually populated features.

Refactor the nested loop to iterate over addt_value_list and index feature blocks accordingly:

-        t_value_count = 0
-        for t_value in timesteps_list:
-            for addt_value in [0]:
+        feature_block_index = 0
+        for t_value in timesteps_list:
+            for addt_value in addt_value_list:
                 if model_number in train_indices:
@@
-                    x_train[
-                        samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1),
-                        t_value_count * num_noise_per_time_step : (t_value_count + 1) * num_noise_per_time_step,
-                    ] = predictions.detach().squeeze().cpu().numpy()
+                    start = feature_block_index * num_noise_per_time_step
+                    end = (feature_block_index + 1) * num_noise_per_time_step
+                    x_train[
+                        samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1),
+                        start:end,
+                    ] = predictions.detach().squeeze().cpu().numpy()
@@
-                    x_train_label[
-                        samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1)
-                    ] = np.concatenate([np.zeros(samples_per_train_model), np.ones(samples_per_train_model)])
-                    t_value_count += 1
+                    x_train_label[
+                        samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1)
+                    ] = np.concatenate([np.zeros(samples_per_train_model), np.ones(samples_per_train_model)])
+                    feature_block_index += 1
@@
-                    x_val[
-                        sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1),
-                        t_value_count * num_noise_per_time_step : (t_value_count + 1) * num_noise_per_time_step,
-                    ] = predictions.detach().squeeze().cpu().numpy()
+                    start = feature_block_index * num_noise_per_time_step
+                    end = (feature_block_index + 1) * num_noise_per_time_step
+                    x_val[
+                        sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1),
+                        start:end,
+                    ] = predictions.detach().squeeze().cpu().numpy()
@@
-                    x_val_label[sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1)] = (
-                        np.concatenate([np.zeros(sample_per_val_model), np.ones(sample_per_val_model)])
-                    )
-                    t_value_count += 1
+                    x_val_label[sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1)] = (
+                        np.concatenate([np.zeros(sample_per_val_model), np.ones(sample_per_val_model)])
+                    )
+                    feature_block_index += 1

This way, x_train/x_val and the MLP’s input_dim stay consistent for arbitrary addt_value_list.


503-528: Respect addt_value_list, avoid hardcoded batch_size, and guard min–max normalization against zero range.

Three issues remain in tf_attack’s scoring loop:

  1. Hardcoded batch_size = 200:

    • This magic number is baked into the function, making it harder to reuse in other scenarios or tests.
  2. addt_value_list is ignored:

    • The loop still uses for addt_value in [0]:, so callers can’t vary addt_value despite it being a parameter.
  3. Potential division by zero in min–max normalization:

    • If all model outputs are identical (min_output == max_output), (max_output - min_output) is zero, leading to nans and failing the subsequent assertion.

A possible fix:

-def tf_attack(
+def tf_attack(
@@
-    classifier_hidden_dim: int,
-    addt_value_list: list[int],
-    meta_dir: Path,
-) -> tuple[Any, Any, Any]:
+    classifier_hidden_dim: int,
+    addt_value_list: list[int],
+    meta_dir: Path,
+    batch_size: int = 200,
+) -> tuple[Any, Any, Any]:
@@
-        model_dir = tabddpm_data_dir / model_folder
-        model_path = model_dir / target_model_subdir
-        batch_size = 200
-        t_value_count = 0
+        model_dir = tabddpm_data_dir / model_folder
+        model_path = model_dir / target_model_subdir
         current_input = []
         for t_value in timesteps_list:
-            for addt_value in [0]:
+            for addt_value in addt_value_list:
                 predictions: torch.Tensor = get_score(
@@
-                )
-                t_value_count += 1
-                current_input.append(predictions)
+                )
+                current_input.append(predictions)
@@
-        predictions = regression_model(predictions).detach().cpu().numpy()
-        # clip to [0, 1]
-        min_output, max_output = np.min(predictions), np.max(predictions)
-        predictions = (predictions - min_output) / (max_output - min_output)
+        predictions = regression_model(predictions).detach().cpu().numpy()
+        # Rescale to [0, 1] safely
+        min_output, max_output = np.min(predictions), np.max(predictions)
+        range_output = max_output - min_output
+        if range_output > 0:
+            predictions = (predictions - min_output) / range_output
+        else:
+            # All outputs identical; fall back to a neutral constant
+            predictions = np.full_like(predictions, 0.5)
         predictions = torch.tensor(predictions)

This keeps the existing integration test behavior (batch_size defaults to 200, addt_value_list=[0]) while making the function robust and reusable.

🧹 Nitpick comments (8)
src/midst_toolkit/attacks/tf/data_utils.py (3)

1-18: Tighten lint/typing pragmas and clean leftover comment

  • Line 1: # ruff: noqa: D102, D105, D103, D200 is flagged as unused; if these rules aren’t enabled, drop this directive instead of carrying dead config.
  • Lines 2–6: multiple broad mypy disables (no-untyped-def, has-type, index, attr-defined, assignment) effectively opt this module out of type checking. Given the PR TODO “fix typing errors”, it would be better long‑term to narrow these or remove them once annotations are in place.
  • Line 16: Comment # at very top of file (optional but helpful) looks like a review note rather than code documentation and can be removed.
-# ruff: noqa: D102, D105, D103, D200
-# mypy: disable-error-code=no-untyped-def
-# mypy: disable-error-code=has-type
-# mypy: disable-error-code=index
-# mypy: disable-error-code=attr-defined
-# mypy: disable-error-code=assignment
-from __future__ import annotations
-
-# at very top of file (optional but helpful)
+from __future__ import annotations
@@
-from typing import Any, Literal
+from typing import Any, Literal

(And later, once typing is addressed, consider re‑enabling mypy checks instead of disabling them globally for this file.)


59-76: verbose parameter in load_multi_table_customized is currently unused

The verbose argument (Line 59) is never read, which is flagged by static analysis and can confuse callers expecting logging or extra checks.

Either remove the parameter if not needed:

-def load_multi_table_customized(data_dir, meta_dir=None, train_name="train.csv", verbose=True):
+def load_multi_table_customized(data_dir, meta_dir=None, train_name="train.csv"):

or actually use it (e.g., to print/log dataset/table info when verbose is True).


267-304: prepare_fast_dataloader behavior (infinite batches, feature concatenation) is intentional but worth documenting clearly

  • Correctly handles three cases: both numerical+categorical, categorical‑only, and numerical‑only.
  • Returns an endless stream of (x, y) batches via while True: yield from dataloader, reshuffling each epoch when split == "train".

Consider making the “infinite generator” behavior even more explicit in the docstring (e.g., “This is an infinite generator; callers must bound iteration externally”) to avoid misuse in code that expects a single finite epoch.

tests/unit/evaluation/quality/test_alpha_precision_naive.py (1)

53-58: Redundant conditional branches with identical assertions.

Both the if is_apple_silicon() and else branches contain identical assertions. Either the branching is unnecessary and can be simplified, or the expected values for the non-Apple Silicon path should differ.

If the values are truly architecture-independent for naive metrics, simplify:

-    if is_apple_silicon():
-        assert pytest.approx(0.05994074074074074, abs=1e-8) == quality_results["delta_precision_alpha_naive"]
-        assert pytest.approx(0.005229629629629584, abs=1e-8) == quality_results["delta_coverage_beta_naive"]
-    else:
-        assert pytest.approx(0.05994074074074074, abs=1e-8) == quality_results["delta_precision_alpha_naive"]
-        assert pytest.approx(0.005229629629629584, abs=1e-8) == quality_results["delta_coverage_beta_naive"]
+    assert pytest.approx(0.05994074074074074, abs=1e-8) == quality_results["delta_precision_alpha_naive"]
+    assert pytest.approx(0.005229629629629584, abs=1e-8) == quality_results["delta_coverage_beta_naive"]
tests/integration/attacks/tf/test_tf_attack.py (2)

45-54: Remove redundant .values() unpacking before explicit key-based access.

You first unpack mia_performance_* .values() into variables, then immediately overwrite them using explicit key lookups. The unpacking is unnecessary and relies on dict insertion order; just keep the key-based access:

-    mia_performance_train, mia_performance_val, mia_performance_test = tf_attack(**config)
-    tpr_at_fpr_train, roc_auc_train = mia_performance_train.values()
-    tpr_at_fpr_val, roc_auc_val = mia_performance_val.values()
-    tpr_at_fpr_test, roc_auc_test = mia_performance_test.values()
-    tpr_at_fpr_train = mia_performance_train["max_tpr"]
-    roc_auc_train = mia_performance_train["roc_auc"]
-    tpr_at_fpr_val = mia_performance_val["max_tpr"]
-    roc_auc_val = mia_performance_val["roc_auc"]
-    tpr_at_fpr_test = mia_performance_test["max_tpr"]
-    roc_auc_test = mia_performance_test["roc_auc"]
+    mia_performance_train, mia_performance_val, mia_performance_test = tf_attack(**config)
+    tpr_at_fpr_train = mia_performance_train["max_tpr"]
+    roc_auc_train = mia_performance_train["roc_auc"]
+    tpr_at_fpr_val = mia_performance_val["max_tpr"]
+    roc_auc_val = mia_performance_val["roc_auc"]
+    tpr_at_fpr_test = mia_performance_test["max_tpr"]
+    roc_auc_test = mia_performance_test["roc_auc"]

69-119: Consider deleting the large commented-out alternative test with absolute paths.

This block is dead code and contains hardcoded, developer-specific absolute paths. Keeping it commented out adds noise and may confuse future readers about which setup is authoritative. Prefer removing it or moving the scenario into a separate, properly parameterized test/config.

src/midst_toolkit/attacks/tf/tf_attack.py (2)

1-7: Remove or narrow unused ruff: noqa directive.

Ruff reports this noqa as unused for D102, D105, D103, D200, PLR0915. If those checks aren’t actually enabled, the directive is unnecessary noise. Either remove it or restrict it to the specific rules you need to silence.

-# ruff: noqa: D102, D105, D103, D200, PLR0915
+# ruff: noqa

or simply delete the line if you don’t need it at all.


113-181: Guard get_dataset against missing target_model_dir/batch_size instead of relying on fragile defaults.

get_dataset defaults target_model_dir=None and batch_size=None, but the body assumes both are set:

  • os.path.join(target_model_dir, ...) will fail if target_model_dir is None.
  • prepare_fast_dataloader(..., batch_size=batch_size, ...) likely expects an integer.

From this module it’s always called with non-None arguments, but the signature suggests they’re optional. Either make them required or fail fast with clear errors:

-def get_dataset(data_path, target_model_dir=None, train_name="train_with_id.csv", batch_size=None, meta_dir=""):
+def get_dataset(data_path, target_model_dir=None, train_name="train_with_id.csv", batch_size=None, meta_dir=""):
@@
-    tables, relation_order, _ = load_multi_table_customized(
+    if target_model_dir is None:
+        raise ValueError("target_model_dir must be provided.")
+    if batch_size is None:
+        raise ValueError("batch_size must be provided.")
+
+    tables, relation_order, _ = load_multi_table_customized(
         data_path,
         meta_dir=meta_dir,
         train_name=train_name,

This makes incorrect external use of get_dataset fail deterministically instead of with a less obvious TypeError.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6eb18e9 and b1d9f8a.

⛔ Files ignored due to path filters (52)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/data_for_training_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/test.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/data_for_training_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/test.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/data_for_validating_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/test.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/data_for_validating_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/test.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/test.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_synthetic.csv is excluded by !**/*.csv
📒 Files selected for processing (18)
  • src/midst_toolkit/attacks/tf/classification.py (1 hunks)
  • src/midst_toolkit/attacks/tf/data_utils.py (1 hunks)
  • src/midst_toolkit/attacks/tf/tf_attack.py (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/data_configs/dataset_meta.json (1 hunks)
  • tests/integration/attacks/tf/data_configs/trans.json (1 hunks)
  • tests/integration/attacks/tf/data_configs/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/test_tf_attack.py (1 hunks)
  • tests/unit/evaluation/quality/test_alpha_precision_naive.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/unit/evaluation/quality/test_alpha_precision_naive.py (4)
src/midst_toolkit/common/random.py (2)
  • set_all_random_seeds (11-55)
  • unset_all_random_seeds (58-67)
src/midst_toolkit/data_processing/midst_data_processing.py (2)
  • load_midst_data (94-121)
  • process_midst_data_for_alpha_precision_evaluation (17-91)
src/midst_toolkit/evaluation/utils.py (2)
  • extract_columns_based_on_meta_info (45-87)
  • one_hot_encode_categoricals_and_merge_with_numerical (90-128)
tests/utils/architecture.py (1)
  • is_apple_silicon (4-6)
src/midst_toolkit/attacks/tf/data_utils.py (2)
src/midst_toolkit/models/clavaddpm/dataset.py (1)
  • Dataset (77-397)
src/midst_toolkit/models/clavaddpm/enumerations.py (1)
  • Normalization (58-63)
src/midst_toolkit/attacks/tf/tf_attack.py (4)
src/midst_toolkit/attacks/tf/classification.py (2)
  • MLP (10-42)
  • fitmodel (67-161)
src/midst_toolkit/attacks/tf/data_utils.py (6)
  • CustomUnpickler (51-56)
  • TaskType (42-48)
  • evaluate_attack_performance (185-218)
  • load_multi_table_customized (59-103)
  • prepare_data_for_attack (138-171)
  • prepare_fast_dataloader (267-304)
src/midst_toolkit/models/clavaddpm/dataset.py (3)
  • Dataset (77-397)
  • from_df (276-397)
  • size (199-211)
src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py (3)
  • sample_time (945-982)
  • gaussian_q_sample (302-321)
  • _gaussian_loss (505-543)
🪛 Ruff (0.14.7)
src/midst_toolkit/attacks/tf/data_utils.py

1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200)

Remove unused noqa directive

(RUF100)


59-59: Unused function argument: verbose

(ARG001)


75-75: Avoid specifying long messages outside the exception class

(TRY003)


95-95: Avoid specifying long messages outside the exception class

(TRY003)


101-101: Avoid specifying long messages outside the exception class

(TRY003)


143-143: Avoid specifying long messages outside the exception class

(TRY003)


165-165: Avoid specifying long messages outside the exception class

(TRY003)


193-193: Avoid specifying long messages outside the exception class

(TRY003)

src/midst_toolkit/attacks/tf/tf_attack.py

1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200, PLR0915)

Remove unused noqa directive

(RUF100)


82-82: Avoid specifying long messages outside the exception class

(TRY003)


222-222: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (11)
src/midst_toolkit/attacks/tf/data_utils.py (3)

70-103: Table loading and validation logic looks sound for single-table use

The per‑table loading/validation flow (train CSV + {table}_domain.json, ID column removal, '?' value check, and numeric‑column string guard) is consistent and defensive. For the current test dataset (single "trans" table), this is straightforward and appropriate.


185-219: Attack evaluation aggregation and metric computation look coherent

The aggregation in evaluate_attack_performance—looping over indices, skipping missing files, concatenating predictions and labels, then computing max TPR at a fixed FPR and ROC AUC—is consistent and matches the intended evaluation logic. Guard clauses for empty indices and no predictions give safe fallbacks.

Please double‑check that the shape of challenge_label.csv (loaded via np.loadtxt(..., skiprows=1)) matches the expected 1D label array so roc_curve receives the correct y_true.


221-265: FastTensorDataLoader implementation is minimal and correct

The loader enforces equal leading dimensions, supports optional shuffling via a permuted index, and computes the correct number of batches including a final partial batch. __iter__/__next__ semantics are standard and should integrate cleanly with for loops or yield from.

tests/integration/attacks/tf/data_configs/dataset_meta.json (1)

1-1: Dataset meta format is minimal but consistent with loader expectations

Single‑table relation_order and tables.trans structure align with how load_multi_table_customized reads dataset_meta.json. No issues spotted.

tests/integration/attacks/tf/data_configs/trans.json (1)

1-50: Config parameters and relative paths look appropriate for tests

The general/clustering/diffusion/classifier/sampling/matching sections are consistent and use repo‑relative paths (including workspace_dir), which is good for portability in CI and local runs.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1)

1-1: Domain descriptor matches the expected schema

Field names, sizes, and types mirror the other trans_domain.json assets, so this should integrate smoothly with the domain‑aware tooling.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1)

1-1: Tabddpm_1 domain metadata is consistent with other variants

Schema and typing for all fields match the other trans_domain.json assets, which keeps the attack tests’ domain assumptions uniform.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1)

1-1: Tabddpm_5 domain descriptor aligns with the rest of the suite

Field set and type declarations are consistent with the other tabddpm model domain files, so downstream utilities can treat all models uniformly.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1)

1-1: LGTM!

Valid JSON domain metadata defining feature sizes and types for test fixtures. Structure is consistent with other trans_domain.json files in the PR.

tests/integration/attacks/tf/data_configs/trans_domain.json (1)

1-1: LGTM!

Valid JSON domain metadata for test data configuration.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1)

1-1: LGTM!

Valid JSON domain metadata consistent with other tabddpm model test assets.

Comment on lines +112 to +155
if use_best_checkpoint and best_model_dir is not None:
best_model_dir = Path(".") # or raise ValueError
print(f"Best model will be saved to: {best_model_dir}")

best_model_path = best_model_dir / "best_model.pt"
optimizer = optim.Adam(regression_model.parameters(), lr=learning_rate)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

has_validation = x_val is not None
x_train = torch.tensor(x_train, dtype=torch.float32).to(device)
y_train = torch.tensor(x_train_label, dtype=torch.float32).to(device)
if has_validation:
x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
y_val = torch.tensor(x_val_label, dtype=torch.float32).to(device)

indices = torch.randperm(x_train.size(0))
x_train, y_train = x_train[indices], y_train[indices]

regression_model.train()
best_tpr, best_model_exists = 0.0, False

for epoch in range(num_epochs):
optimizer.zero_grad()
loss = custom_loss_fn(regression_model, x_train, y_train)
loss.backward()
optimizer.step()

if (epoch + 1) % 100 == 0:
train_loss, train_tpr = evaluate_model(regression_model, x_train, y_train)
if x_val is not None:
test_loss, test_tpr = evaluate_model(regression_model, x_val, y_val)
if test_tpr > best_tpr:
best_tpr = test_tpr
save_best_model(regression_model, best_model_path)
best_model_exists = True
print(
f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss}, "
f"Test Loss: {test_loss}, Train TPR: {train_tpr}, Test TPR: {test_tpr}"
)
else:
print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss}, Train TPR: {train_tpr}")

if use_best_checkpoint and best_model_exists:
load_best_model(regression_model, best_model_path, device)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix fitmodel’s best_model_dir handling; current defaults crash and override caller paths.

Currently:

  • With default arguments (use_best_checkpoint=None, best_model_dir=None), best_model_path = best_model_dir / "best_model.pt" raises a TypeError.
  • When use_best_checkpoint is truthy and a non-None best_model_dir is passed, it’s silently replaced with Path("."), so the best model is saved/loaded from the CWD instead of the caller-specified directory.

Make the requirement explicit and avoid overriding the caller:

-    if use_best_checkpoint and best_model_dir is not None:
-        best_model_dir = Path(".")  # or raise ValueError
-        print(f"Best model will be saved to: {best_model_dir}")
-
-    best_model_path = best_model_dir / "best_model.pt"
+    best_model_path: Path | None = None
+    if use_best_checkpoint:
+        if best_model_dir is None:
+            raise ValueError("best_model_dir must be provided when use_best_checkpoint=True.")
+        best_model_dir = Path(best_model_dir)
+        best_model_dir.mkdir(parents=True, exist_ok=True)
+        best_model_path = best_model_dir / "best_model.pt"
+        print(f"Best model will be saved to: {best_model_path}")
@@
-                if test_tpr > best_tpr:
+                if test_tpr > best_tpr and best_model_path is not None:
                     best_tpr = test_tpr
                     save_best_model(regression_model, best_model_path)
                     best_model_exists = True
@@
-    if use_best_checkpoint and best_model_exists:
-        load_best_model(regression_model, best_model_path, device)
+    if use_best_checkpoint and best_model_exists and best_model_path is not None:
+        load_best_model(regression_model, best_model_path, device)

This preserves existing behavior for the integration path (where a results_path is passed) and makes misuse fail fast with a clear error.

🤖 Prompt for AI Agents
In src/midst_toolkit/attacks/tf/classification.py around lines 112 to 155, the
code currently overwrites caller-provided best_model_dir with Path(".") and
computes best_model_path when best_model_dir may be None, causing TypeError and
silent misuse; instead, do not override best_model_dir, require/validate it when
use_best_checkpoint is truthy (e.g., if use_best_checkpoint: if best_model_dir
is None: raise ValueError("best_model_dir must be provided when
use_best_checkpoint is True")), convert best_model_dir to a Path (best_model_dir
= Path(best_model_dir)) and only compute best_model_path = best_model_dir /
"best_model.pt" and call save_best_model/load_best_model when
use_best_checkpoint is truthy and a valid best_model_dir is present; leave
behavior unchanged when use_best_checkpoint is falsy.

Comment on lines +2 to +8
"general": {
"data_dir": "tests/integration/attacks/ensemble/assets/shadow_models_data",
"exp_name": "train_1",
"workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace",
"sample_prefix": "",
"test_data_dir": "tests/integration/attacks/ensemble/assets/test"
},
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Avoid hardcoded absolute workspace_dir in test config

workspace_dir is an absolute path (Line 5), which will break portability across environments and CI. For test assets, prefer a project‑relative path under the repo (e.g. the tabddpm_5 workspace folder).

-        "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace",
+        "workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace",
🤖 Prompt for AI Agents
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json
lines 2-8: the "workspace_dir" field is an absolute path which breaks
portability/CI; replace it with a repo-relative path (no leading slash) pointing
to a workspace folder under the test asset (for example the tabddpm_5 workspace
directory within the repo), and update the JSON value accordingly so tests use a
project-relative workspace path.

Comment on lines +15 to +21
set_all_random_seeds(
seed=133742,
use_deterministic_torch_algos=True,
disable_torch_benchmarking=True,
)

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Ensure RNG and env cleanup even when the test fails.

If any assertion fails before the end of the test, unset_all_random_seeds() and the CUBLAS_WORKSPACE_CONFIG cleanup won’t run, leaving global state polluted for subsequent tests. Wrap the body in a try/finally:

 def test_tf_attack_whitebox_small_config():
-    # Set deterministic behavior
-    set_all_random_seeds(
-        seed=133742,
-        use_deterministic_torch_algos=True,
-        disable_torch_benchmarking=True,
-    )
-
-    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-
-    base_path = ...
-    ...
-    assert tpr_at_fpr_test == pytest.approx(0.12, abs=1e-8)
-
-    unset_all_random_seeds()
-    os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
+    # Set deterministic behavior
+    set_all_random_seeds(
+        seed=133742,
+        use_deterministic_torch_algos=True,
+        disable_torch_benchmarking=True,
+    )
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+    try:
+        base_path = ...
+        ...
+        assert tpr_at_fpr_test == pytest.approx(0.12, abs=1e-8)
+    finally:
+        unset_all_random_seeds()
+        os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)

Also applies to: 65-66

🤖 Prompt for AI Agents
In tests/integration/attacks/tf/test_tf_attack.py around lines 15-21 (and
similarly at lines 65-66), the test sets global RNG seeds and the
CUBLAS_WORKSPACE_CONFIG env var but does not guarantee cleanup if the test
fails; wrap the test body that calls set_all_random_seeds and sets
os.environ["CUBLAS_WORKSPACE_CONFIG"] in a try/finally so that
unset_all_random_seeds() and deletion (or restoration) of
CUBLAS_WORKSPACE_CONFIG always run in the finally block; apply the same
try/finally pattern to the other occurrence at lines 65-66 to ensure global
state is restored even on assertion errors.



def test_alpha_precision_evaluation() -> None:
# Setting the paramters to True helps get consistent output on the same architecture for the _OC metrics
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Typo: "paramters" should be "parameters".

-    # Setting the paramters to True helps get consistent output on the same architecture for the _OC metrics
+    # Setting the parameters to True helps get consistent output on the same architecture for the _OC metrics
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Setting the paramters to True helps get consistent output on the same architecture for the _OC metrics
# Setting the parameters to True helps get consistent output on the same architecture for the _OC metrics
🤖 Prompt for AI Agents
tests/unit/evaluation/quality/test_alpha_precision_naive.py lines 22-22: The
comment has a typo "paramters" — change it to "parameters" so the comment reads
"Setting the parameters to True helps get consistent output on the same
architecture for the _OC metrics".

Copy link
Collaborator

@emersodb emersodb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left some preliminary comments for you to consider. I know this is a first pass at getting this code into the toolkit. So we don't need to address everything. However, I think most of my comments are at least worth thinking about addressing to true to improve the clarity of the code.

We'll certainly have to work on making it a bit easier to use.

Initializes the MLP (Multi-Layer Perceptron) model.

Args:
input_dim (int): The number of input features.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For our docs, we're not requiring the type annotations here because we're already strongly typing the method/class signatures etc. So they are a bit redundant. You can shut off the auto-generation of these types in your VS code settings.

So these annotations should be dropped throughout.

self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, 1)

def forward(self, x):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super minor, but perhaps change this to input?

confidences = model(x)
x = x.float()
y = y.float()
return nn.BCELoss()(confidences, y.unsqueeze(1))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may be missing something, but how is this loss any different than a standard BCELoss? I guess it takes in a model and an input instead of an output.

return torch.sigmoid(self.fc3(residual))


def custom_loss_fn(model, x, y):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function name here could be more descriptive. Maybe bce_loss_from_model_and_input?


def fitmodel(
regression_model,
x_train,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe x -> input here in the names? It's also a little odd to have x_train_label I think. Colloquially, x normally indicates input and y labels I think

# get the diffusion model
with open(filepath, "rb") as f:
model = CustomUnpickler(f).load()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already do this above.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diffusion = model.diffusion.to(device)

iter_max = iter_max // batch_size
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is happening here? Are we making sure that the size of the dataset is exactly the same as the batch_size? Why bother with a batch_size then?

x, out_dict = next(train_loader)
out_dict = {"y": out_dict}
x = x.to(device)
for k in out_dict:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to iterate through the out_dict here, since we know it's structure above. Why not just do

out_dict = {"y": out_dict[k].long().to(device)}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also uncertain why this works at all. Doesn't the loader return a tuple of tensors? Is .long valid to call on such a tuple?

parallel_batch=parallel_batch,
addt_value=addt_value,
)
t = t_cur * 0 + t_value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This *0 is, again, quite confusing to me...

keys_for_deduplication=["trans_id", "balance"],
)

n_feutures = [col for col in df_train_merge.columns if "_id" not in col]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feutures -> features.

Also, this isn't the number of features, but rather a list of columns that are not id columns. So this name is a bit misleading.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants