Bz/tf #104

base: main

Conversation
📝 Walkthrough

This PR introduces a comprehensive membership inference attack (MIA) framework for tabular differentially private models. The implementation includes three new Python modules providing data loading utilities (…)

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 17
Note: Due to the large number of review comments, Critical and Major severity comments were prioritized as inline comments.
🟡 Minor comments (4)
src/midst_toolkit/attacks/tf/data_utils.py (1)

157-165: Deduplication key validation happens after `drop_duplicates` is called.

The validation for missing keys (lines 162-165) occurs after `drop_duplicates` is already called (lines 158-159). If keys are missing, `drop_duplicates` will raise a `KeyError` before your descriptive `ValueError` is reached.

```diff
+ # Ensure all keys for deduplication exist in both DataFrames before deduplication
+ missing_keys_merge = [key for key in keys_for_deduplication if key not in df_merge.columns]
+ missing_keys_challenge = [key for key in keys_for_deduplication if key not in df_challenge.columns]
+ if missing_keys_merge or missing_keys_challenge:
+     raise ValueError(f"Missing columns for deduplication: {missing_keys_merge + missing_keys_challenge}")
+
  # Deduplicate the datasets once
  df_merge = df_merge.drop_duplicates(subset=keys_for_deduplication)
  df_challenge = df_challenge.drop_duplicates(subset=keys_for_deduplication)
-
- # Ensure all keys for deduplication exist in both DataFrames
- missing_keys_merge = [key for key in keys_for_deduplication if key not in df_merge.columns]
- missing_keys_challenge = [key for key in keys_for_deduplication if key not in df_challenge.columns]
- if missing_keys_merge or missing_keys_challenge:
-     raise ValueError(f"Missing columns for deduplication: {missing_keys_merge + missing_keys_challenge}")
```

src/midst_toolkit/attacks/tf/data_utils.py (1)

174-179: Potential `ValueError` on edge case.

If all FPR values are at or above `max_fpr`, `tpr[fpr < max_fpr]` will be an empty array and `max()` will raise a `ValueError`.

```diff
 def get_tpr_at_fpr(true_membership: list[int], predictions: list[float], max_fpr: float = 0.1) -> float:
     """
     Calculates the best True Positive Rate when the False Positive Rate is at most `max_fpr`.
     """
     fpr, tpr, _ = roc_curve(true_membership, predictions)
-    return max(tpr[fpr < max_fpr])
+    valid_tpr = tpr[fpr <= max_fpr]
+    if len(valid_tpr) == 0:
+        return 0.0
+    return float(max(valid_tpr))
```
src/midst_toolkit/attacks/tf/classifcation.py (1)

133-134: Condition `x_val is not None` is always true after tensor conversion.

On line 116, `x_val` is unconditionally converted to a tensor via `torch.tensor(x_val, ...)`. If the original `x_val` parameter was `None`, this would raise an error before reaching line 133. The check should happen before tensor conversion.

```diff
+ has_validation = x_val is not None
  x_train = torch.tensor(x_train, dtype=torch.float32).to(device)
  y_train = torch.tensor(x_train_label, dtype=torch.float32).to(device)
- x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
- y_test = torch.tensor(x_val_label, dtype=torch.float32).to(device)
+ if has_validation:
+     x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
+     y_val = torch.tensor(x_val_label, dtype=torch.float32).to(device)
```

Then use `has_validation` in the condition on line 133.

Committable suggestion skipped: line range outside the PR's diff.
tests/integration/attacks/tf/test_tf_attack.py (1)

51-53: Use explicit key access instead of dict unpacking for clarity.

The dictionaries returned have keys `"max_tpr"` and `"roc_auc"`. While Python 3.7+ guarantees dictionary insertion order, explicit key access makes the code clearer and more maintainable.

```diff
- tpr_at_fpr_train, roc_auc_train = mia_performance_train.values()
- tpr_at_fpr_val, roc_auc_val = mia_performance_val.values()
- tpr_at_fpr_test, roc_auc_test = mia_performance_test.values()
+ tpr_at_fpr_train = mia_performance_train["max_tpr"]
+ roc_auc_train = mia_performance_train["roc_auc"]
+ tpr_at_fpr_val = mia_performance_val["max_tpr"]
+ roc_auc_val = mia_performance_val["roc_auc"]
+ tpr_at_fpr_test = mia_performance_test["max_tpr"]
+ roc_auc_test = mia_performance_test["roc_auc"]
```
🧹 Nitpick comments (15)
src/midst_toolkit/attacks/tf/data_utils.py (2)

1-6: Excessive linting suppressions reduce code quality.

Suppressing D102, D105, D103, and D200 (docstring rules) and multiple mypy error codes across the entire file is a significant code smell. As per the PR description, "fix typing errors" is listed as a next step; consider addressing these rather than suppressing them.

59-59: Unused `verbose` parameter.

The `verbose` parameter is declared but never used in the function body. Either implement verbose logging or remove the parameter.
src/midst_toolkit/attacks/tf/classifcation.py (2)

59-62: Remove dead code: `x = x.float()` is unused.

The variable `x` is reassigned on line 60 but never used afterward. This appears to be leftover code.

```diff
 def custom_loss_fn(model, x, y):
     confidences = model(x)
-    x = x.float()
     y = y.float()
     return nn.BCELoss()(confidences, y.unsqueeze(1))
```

114-117: Inconsistent variable naming: `y_test` vs `x_val`.

Line 117 uses `y_test` for labels corresponding to `x_val`, mixing "test" and "val" terminology. This is confusing given the function parameters use `x_val` and `x_val_label`.

```diff
 x_train = torch.tensor(x_train, dtype=torch.float32).to(device)
 y_train = torch.tensor(x_train_label, dtype=torch.float32).to(device)
 x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
-y_test = torch.tensor(x_val_label, dtype=torch.float32).to(device)
+y_val = torch.tensor(x_val_label, dtype=torch.float32).to(device)
```

Then update references on lines 134 and 150 accordingly.
src/midst_toolkit/attacks/tf/tf_attack.py (11)

1-6: Clean up unused lint suppressions.

Static analysis indicates the `noqa` directive on line 1 is unused. Since the PR author already plans to "fix typing errors", consider removing unnecessary suppressions once proper type annotations are added rather than blanket-disabling mypy checks.
37-82: Several issues in the `mixed_loss` function.

- Unused parameter: `no_mean` is never used (line 44).
- Redundant device assignment: `device` is assigned on line 51 but immediately overwritten on line 61.
- Typo: "defeualt" → "default" (line 63).
- Redundant conditional: the check `if not return_random:` on line 74 is always `True`, since we already returned on line 67 when `return_random=True`.

```diff
 def mixed_loss(
     diffusion,
     x,
     out_dict,
     noise=None,
     t=None,
     return_random=False,
-    no_mean=False,
     parallel_batch=None,
     addt_value=None,
 ):
     x_num = x[:, : diffusion.num_numerical_features]
     x_cat = x[:, diffusion.num_numerical_features :]
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = x.device
     noise_tensor = torch.tensor(noise, device=device, dtype=torch.float)
     batch_noise = noise_tensor.repeat(x_num.shape[0], 1)

     # there is actually no categorical classes, as we have examined the DM, so we just ignore x_cat here and later
     x_num = x_num.repeat_interleave(parallel_batch, dim=0)
     x_cat = x_cat.repeat_interleave(parallel_batch, dim=0)

     b = x_num.shape[0]
-    device = x.device

     if t is None:
-        # the defeualt is uniform sampling
+        # the default is uniform sampling
         t, pt = diffusion.sample_time(b, device)

     if return_random:
         return noise, t, pt

     additional_t = t * 0 + addt_value
     # forward x_num_t with (t+additional_t) timestamps
     x_num_t = diffusion.gaussian_q_sample(x_num, t + additional_t, noise=batch_noise)
-    if not return_random:
-        current_t = t
-        # predict noises with t timestamps
-        predicted_noise = diffusion._denoise_fn(x_num_t, current_t, **out_dict)
-        current_loss = diffusion._gaussian_loss(predicted_noise, batch_noise, batch_noise, current_t, batch_noise)
-        transformed_current_loss = current_loss.reshape(-1, parallel_batch)
+    current_t = t
+    # predict noises with t timestamps
+    predicted_noise = diffusion._denoise_fn(x_num_t, current_t, **out_dict)
+    current_loss = diffusion._gaussian_loss(predicted_noise, batch_noise, batch_noise, current_t, batch_noise)
+    transformed_current_loss = current_loss.reshape(-1, parallel_batch)

     return transformed_current_loss * 0, transformed_current_loss
```
136-142: Prefix unused variables with underscore.

`label_encoders` and `column_orders` are unpacked but never used. Per convention, prefix them with an underscore to indicate intentional non-use.

```diff
- dataset, label_encoders, column_orders = Dataset.from_df(
+ dataset, _label_encoders, _column_orders = Dataset.from_df(
```
207-229: Misplaced docstring.

The docstring is placed in the middle of the function after executable code (lines 201-205). It should be the first statement after the function signature to be recognized by documentation tools and IDEs.

Move the docstring to immediately after line 198 (`def get_score(...):`), before any executable statements.
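As a quick illustration of the convention (the body below is simplified and hypothetical, not the actual function):

```python
def get_score(data_path, save_dir):
    """Compute per-record scores for the attack.

    Placed as the first statement, this docstring is picked up by help(),
    IDEs, and documentation tooling; a string literal appearing after other
    code is just an unused expression.
    """
    if data_path is None:
        raise ValueError("data_path is required")
    # ... rest of the function ...
```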
247-252: Remove debug print and simplify loop logic.

- Debug print: `print(iter_max)` should be removed or replaced with proper logging.
- Confusing loop: the `assert iter_max == 1` followed by `while iter_id < iter_max:` means the loop always runs exactly once. Consider simplifying to remove the loop entirely, or documenting why this structure exists for future extensibility.

```diff
- print(iter_max)
  iter_max = iter_max // batch_size
  return_res = torch.zeros([batch_size, parallel_batch])
  assert iter_max == 1
- iter_id = 0
- while iter_id < iter_max:
+ # Process single batch (currently only supports iter_max == 1)
+ x, out_dict = next(train_loader)
  ...
- iter_id += 1
```
236-236: Prefix unused variable with underscore.

`challenge_dataset` is unpacked but never used.

```diff
- train_loader, iter_max, challenge_dataset = train_loader_list[loader_count]
+ train_loader, iter_max, _challenge_dataset = train_loader_list[loader_count]
```
261-269: Prefix unused variables with underscore.

`noise` and `pt` are unpacked but never used in this context.

```diff
- noise, t_cur, pt = mixed_loss(
+ _noise, t_cur, _pt = mixed_loss(
```
309-316: Hardcoded deduplication keys reduce reusability.

The keys `["trans_id", "balance"]` are hardcoded in both `prepare_data_for_attack` calls. Consider making these configurable via a parameter, as in the sketch below.
479-479: Document magic number for noise dimension.

The noise dimension `size=8` is hardcoded without explanation. Consider extracting this to a named constant or parameter, with documentation explaining why 8 is the appropriate value.

```diff
+ NOISE_DIMENSION = 8  # Must match the diffusion model's expected noise dimension
- input_noise: list[list[float]] = [np.random.normal(size=8).tolist() for _ in range(num_noise_per_time_step)]
+ input_noise: list[list[float]] = [np.random.normal(size=NOISE_DIMENSION).tolist() for _ in range(num_noise_per_time_step)]
```
187-198: Unused `phase` parameter.

The `phase` parameter is accepted but never used in the function body. Either remove it or implement the intended behavior.

```diff
 def get_score(
     data_path,
     save_dir,
     input_noise,
     type="tabddpm",
-    phase=None,
     challenge_name=None,
     batch_size=None,
     parallel_batch=None,
     addt_value=None,
 ):
```

If `phase` is intended for future use, add a TODO comment or raise `NotImplementedError` when a non-None value is passed.
191-191: Avoid shadowing built-in `type`.

Using `type` as a parameter name shadows Python's built-in `type()` function, which can cause subtle bugs if the built-in is needed within this function.

```diff
 def get_score(
     data_path,
     save_dir,
     input_noise,
-    type="tabddpm",
+    model_type="tabddpm",
     ...
 ):
-    if type == "tabddpm":
+    if model_type == "tabddpm":
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (49)

- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/None_trans_ckpt.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_label.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_with_id.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/cluster_ckpt.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/data_for_training_MIA.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/predictions_test_2.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/predictions_test_222.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/train_with_id.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_label_encoders.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_synthetic.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/None_trans_ckpt.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/challenge_label.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/challenge_with_id.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/cluster_ckpt.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/data_for_training_MIA.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/predictions_test_2.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/predictions_test_222.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/train_with_id.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_label_encoders.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_synthetic.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/None_trans_ckpt.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/challenge_label.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/challenge_with_id.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/cluster_ckpt.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/data_for_validating_MIA.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/predictions_test_2.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/predictions_test_222.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/train_with_id.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_label_encoders.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_synthetic.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/None_trans_ckpt.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/challenge_label.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/challenge_with_id.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/cluster_ckpt.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/data_for_validating_MIA.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/predictions_test_2.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/predictions_test_222.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/train_with_id.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_label_encoders.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_synthetic.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/None_trans_ckpt.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_label.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_with_id.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/cluster_ckpt.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/predictions_test_2.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/predictions_test_222.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/train_with_id.csv` is excluded by `!**/*.csv`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_label_encoders.pkl` is excluded by `!**/*.pkl`
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_synthetic.csv` is excluded by `!**/*.csv`
📒 Files selected for processing (22)

- `src/midst_toolkit/attacks/tf/classifcation.py` (1 hunks)
- `src/midst_toolkit/attacks/tf/data_utils.py` (1 hunks)
- `src/midst_toolkit/attacks/tf/tf_attack.py` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/updated_config.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/workspace/train_1/args` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/updated_config.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/workspace/train_1/args` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/updated_config.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args` (1 hunks)
- `tests/integration/attacks/tf/data_configs/dataset_meta.json` (1 hunks)
- `tests/integration/attacks/tf/data_configs/trans.json` (1 hunks)
- `tests/integration/attacks/tf/data_configs/trans_domain.json` (1 hunks)
- `tests/integration/attacks/tf/test_tf_attack.py` (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (13)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (2)

- tests/unit/attacks/ensemble/test_shadow_model_utils.py (1): `test_save_additional_tabddpm_config` (19-54)
- src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1): `save_additional_tabddpm_config` (36-76)

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (2)

- src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1): `save_additional_tabddpm_config` (36-76)
- src/midst_toolkit/models/clavaddpm/model.py (1): `DiffusionParameters` (19-31)

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/updated_config.json (3)

- tests/integration/attacks/ensemble/test_shadow_model_training.py (1): `test_train_and_fine_tune_tabddpm` (135-187)
- tests/unit/attacks/ensemble/test_shadow_model_utils.py (1): `test_save_additional_tabddpm_config` (19-54)
- src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1): `save_additional_tabddpm_config` (36-76)

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1)

- src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1): `save_additional_tabddpm_config` (36-76)

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/workspace/train_1/args (3)

- tests/integration/attacks/ensemble/test_shadow_model_training.py (1): `test_train_and_fine_tune_tabddpm` (135-187)
- tests/unit/attacks/ensemble/test_shadow_model_utils.py (1): `test_save_additional_tabddpm_config` (19-54)
- src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1): `save_additional_tabddpm_config` (36-76)

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (2)

- tests/unit/attacks/ensemble/test_shadow_model_utils.py (1): `test_save_additional_tabddpm_config` (19-54)
- src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1): `save_additional_tabddpm_config` (36-76)

tests/integration/attacks/tf/data_configs/trans_domain.json (1)

- tests/integration/attacks/ensemble/test_shadow_model_training.py (1): `test_train_and_fine_tune_tabddpm` (135-187)

tests/integration/attacks/tf/data_configs/trans.json (2)

- tests/integration/attacks/ensemble/test_shadow_model_training.py (1): `test_train_and_fine_tune_tabddpm` (135-187)
- src/midst_toolkit/attacks/ensemble/blending.py (1): `__init__` (27-66)

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/updated_config.json (3)

- tests/integration/attacks/ensemble/test_shadow_model_training.py (1): `test_train_and_fine_tune_tabddpm` (135-187)
- tests/unit/attacks/ensemble/test_shadow_model_utils.py (1): `test_save_additional_tabddpm_config` (19-54)
- src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1): `save_additional_tabddpm_config` (36-76)

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1)

- tests/integration/attacks/ensemble/test_shadow_model_training.py (1): `test_train_and_fine_tune_tabddpm` (135-187)

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/updated_config.json (3)

- tests/integration/attacks/ensemble/test_shadow_model_training.py (1): `test_train_and_fine_tune_tabddpm` (135-187)
- tests/unit/attacks/ensemble/test_shadow_model_utils.py (1): `test_save_additional_tabddpm_config` (19-54)
- src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1): `save_additional_tabddpm_config` (36-76)

tests/integration/attacks/tf/test_tf_attack.py (2)

- src/midst_toolkit/attacks/tf/tf_attack.py (1): `tf_attack` (459-559)
- src/midst_toolkit/common/random.py (2): `set_all_random_seeds` (11-55), `unset_all_random_seeds` (58-67)

src/midst_toolkit/attacks/tf/data_utils.py (3)

- src/midst_toolkit/models/clavaddpm/dataset.py (1): `Dataset` (77-397)
- src/midst_toolkit/models/clavaddpm/enumerations.py (1): `Normalization` (58-63)
- src/midst_toolkit/models/clavaddpm/data_loaders.py (1): `FastTensorDataLoader` (473-537)
🪛 Ruff (0.14.7)
src/midst_toolkit/attacks/tf/tf_attack.py
1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200, PLR0915)
Remove unused noqa directive
(RUF100)
44-44: Unused function argument: no_mean
(ARG001)
85-85: Avoid specifying long messages outside the exception class
(TRY003)
136-136: Unpacked variable label_encoders is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
136-136: Unpacked variable column_orders is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
192-192: Unused function argument: phase
(ARG001)
203-203: Avoid specifying long messages outside the exception class
(TRY003)
236-236: Unpacked variable challenge_dataset is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
261-261: Unpacked variable noise is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
261-261: Unpacked variable pt is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
src/midst_toolkit/attacks/tf/data_utils.py
1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200)
Remove unused noqa directive
(RUF100)
59-59: Unused function argument: verbose
(ARG001)
75-75: Avoid specifying long messages outside the exception class
(TRY003)
95-95: Avoid specifying long messages outside the exception class
(TRY003)
101-101: Avoid specifying long messages outside the exception class
(TRY003)
143-143: Avoid specifying long messages outside the exception class
(TRY003)
165-165: Avoid specifying long messages outside the exception class
(TRY003)
190-190: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: run-code-check
- GitHub Check: unit-tests
- GitHub Check: integration-tests
🔇 Additional comments (16)
src/midst_toolkit/attacks/tf/data_utils.py (1)

218-261: LGTM!

The `FastTensorDataLoader` implementation correctly handles batching and shuffling. The pattern matches the reference implementation from `src/midst_toolkit/models/clavaddpm/data_loaders.py`.

tests/integration/attacks/tf/data_configs/trans.json (1)
1-50: LGTM!

This configuration uses relative paths consistently, making it portable across different environments, unlike some of the other tabddpm config files.
tests/integration/attacks/tf/data_configs/dataset_meta.json (1)

1-1: ✓ Test metadata structure is appropriate.

The JSON correctly describes a single-table dataset with no relationships, which aligns with the trans domain schema and test configuration patterns.
tests/integration/attacks/tf/data_configs/trans_domain.json (1)

1-1: ✓ Domain schema is consistent and well-formed.

The feature definitions (8 fields with appropriate continuous/discrete classifications) align with dataset_meta.json and are consistent across all tabddpm model variants used in tests.
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1)

1-1: ✓ Consistent test asset.

The domain schema matches across all tabddpm variants, which is appropriate for uniform test configuration.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1)

1-1: ✓ Consistent test asset.

Matches the schema across all tabddpm variants for uniform testing.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1)

1-1: ✓ Consistent test asset.

Matches the schema across all tabddpm variants for uniform testing.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1)

1-1: ✓ Consistent test asset.

Matches the schema across all tabddpm variants for uniform testing.
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/updated_config.json (2)

1-8: Verify the absolute workspace_dir path will not break tests.

Line 5 contains an absolute path `/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_2/workspace` that does not exist in typical CI/test environments. This may cause test failures when the configuration is used to create directories or write outputs.

Check whether:

- The integration test overrides this path at runtime (e.g., using `save_additional_tabddpm_config`)
- The code gracefully handles missing parent directories
- Tests are expected to run in an environment where `/projects/` exists

The relative paths for `data_dir` and `test_data_dir` (lines 3, 7) are appropriate for test portability.

14-42: ✓ Hyperparameter choices are appropriate for fast integration testing.

Minimal values (iterations=2–3, batch_size=1, num_timesteps=3) prioritize test speed while maintaining a realistic configuration structure. This aligns with typical test fixtures for ML frameworks.
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (2)

1-8: Verify the absolute workspace_dir path will not break tests.

Line 5 contains an absolute path `/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace` that does not exist in typical CI/test environments. This matches the same issue in `tabddpm_2/updated_config.json` and suggests a systematic configuration pattern.

Verify that:

- The integration test overrides workspace_dir paths at runtime (expected, based on the `save_additional_tabddpm_config` utility)
- Absolute paths are intentionally used as templates and not written to during tests

The relative test data paths (lines 3, 7) are appropriate for portability; one way to keep the committed JSON assets portable is sketched below.
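A minimal sketch of a runtime override, assuming only that the config has a `general.workspace_dir` key as shown in the snippets elsewhere in this review; the helper name and fixture wiring are hypothetical:

```python
import json
from pathlib import Path


def load_config_with_workspace(config_path: Path, workspace_dir: Path) -> dict:
    """Load a tabddpm config and point its workspace at a per-test directory."""
    config = json.loads(config_path.read_text())
    # Override the committed (possibly machine-specific) path at setup time.
    config["general"]["workspace_dir"] = str(workspace_dir)
    return config


# Example with pytest's tmp_path fixture:
# config = load_config_with_workspace(asset_dir / "updated_config.json", tmp_path / "workspace")
```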
14-42: ✓ Hyperparameter choices are appropriate for fast integration testing.

Configuration values match the minimal test settings used elsewhere (iterations=2–3, batch_size=1, num_timesteps=3), balancing test speed with a realistic model structure.
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1)

1-1: LGTM!

The domain schema defines the expected field structure with appropriate size and type metadata for the tabddpm_4 model configuration.
tests/integration/attacks/tf/test_tf_attack.py (1)

20-28: Good use of deterministic seeding for reproducible tests.

The combination of `set_all_random_seeds` with `use_deterministic_torch_algos=True` and the `CUBLAS_WORKSPACE_CONFIG` environment variable ensures reproducible results, which is essential for testing ML code with fixed expected values.
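The pattern looks roughly like the following sketch; the module path and `use_deterministic_torch_algos` flag come from this review, while the `seed` keyword and value 42 are assumptions for illustration:

```python
import os

from midst_toolkit.common.random import set_all_random_seeds, unset_all_random_seeds

# Required for deterministic cuBLAS kernels when CUDA is in use.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

set_all_random_seeds(seed=42, use_deterministic_torch_algos=True)
try:
    pass  # run the attack and compare against fixed expected values
finally:
    unset_all_random_seeds()  # avoid leaking determinism settings into other tests
```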
src/midst_toolkit/attacks/tf/classifcation.py (1)

8-40: LGTM: MLP architecture is straightforward and well-documented.

The 3-layer MLP with tanh activations and a sigmoid output is appropriate for binary classification in the attack context.
src/midst_toolkit/attacks/tf/tf_attack.py (1)

81-81: Clarify intent of returning a zeroed tensor.

The return statement returns `transformed_current_loss * 0` as the first element, which is always a zero tensor. If this is intentional (e.g., for API compatibility or specific algorithm requirements), consider adding a comment explaining why. Otherwise, this may be leftover debugging code.
Code context (src/midst_toolkit/attacks/tf/classifcation.py):

```python
    )
    return loss.item(), tpr


best_model_path = best_model_dir / "best_model.pt"
```
Potential TypeError if best_model_dir is None.

If `best_model_dir` is not provided (defaults to `None`), this line will raise `TypeError: unsupported operand type(s) for /: 'NoneType' and 'str'`. Add a guard or require the parameter.

```diff
+ if best_model_dir is None:
+     best_model_dir = Path(".")  # or raise ValueError
  best_model_path = best_model_dir / "best_model.pt"
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
if best_model_dir is None:
    best_model_dir = Path(".")  # or raise ValueError
best_model_path = best_model_dir / "best_model.pt"
```
🤖 Prompt for AI Agents

```
In src/midst_toolkit/attacks/tf/classifcation.py around line 110, the expression
best_model_dir / "best_model.pt" will raise a TypeError if best_model_dir is
None; add a guard to validate best_model_dir before using the / operator (either
raise a clear ValueError if None, coerce a default Path, or require the
parameter to be non-None via the function signature), and ensure best_model_dir
is a pathlib.Path (or convert it with Path(best_model_dir)) before constructing
best_model_path.
```
This is related to my comment above as well 🙂
🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Fetch all review comments from emersodb on PR #104
gh pr view 104 --json comments,reviews --jq '.reviews[] | select(.author.login == "emersodb") | {body: .body, state: .state}' 2>/dev/null || echo "No reviews found"

# Also check individual review comments
gh api repos/VectorInstitute/midst-toolkit/pulls/104/comments --jq '.[] | select(.user.login == "emersodb") | {path: .path, line: .line, body: .body}' 2>/dev/null || echo "No review comments found"
```

Length of output: 2192

🏁 Script executed:

```bash
#!/bin/bash
# Get the full content of the classification.py file to see the logic around lines 110-120
cat src/midst_toolkit/attacks/tf/classification.py | sed -n '105,120p'
```

Length of output: 625
@emersodb Ah yes, I see both issues now! 🎯

Looking at lines 112-116:

```python
if use_best_checkpoint and best_model_dir is not None:
    best_model_dir = Path(".")  # Line 113 - overwrites the user's input!
    print(f"Best model will be saved to: {best_model_dir}")

best_model_path = best_model_dir / "best_model.pt"  # Line 116 - fails if None
```

Two bugs here:

1. Your point (line 113): when `best_model_dir` is provided, it gets unconditionally overwritten with `Path(".")`, ignoring the user's input completely.
2. CodeRabbit's point (line 116): when `best_model_dir` is `None`, the `/` operator raises a `TypeError`.

Suggested fix:

```python
if use_best_checkpoint:
    if best_model_dir is None:
        raise ValueError("best_model_dir must be provided when use_best_checkpoint is True")
    best_model_dir = Path(best_model_dir)  # Convert to Path without overwriting
    print(f"Best model will be saved to: {best_model_dir}")
else:
    best_model_dir = Path(".")  # Default only when checkpointing is disabled

best_model_path = best_model_dir / "best_model.pt"
```

Or, if you want to allow a default:

```python
if use_best_checkpoint:
    best_model_dir = Path(best_model_dir) if best_model_dir else Path(".")
    print(f"Best model will be saved to: {best_model_dir}")
else:
    best_model_dir = Path(".")

best_model_path = best_model_dir / "best_model.pt"
```

This also aligns with your suggestion to use Path objects only (from your comment on line 90).
Code context (src/midst_toolkit/attacks/tf/data_utils.py):

```python
for t in indices:
    base_path = models_base_dir / f"{model_type}_{t}"
    df_merge_list.append(pd.read_csv(os.path.join(base_path, "train_with_id.csv")))
    df_challenge_list.append(pd.read_csv(os.path.join(base_path, "challenge_with_id.csv")))
    df_challenge_labels = pd.read_csv(os.path.join(base_path, "challenge_label.csv"))
```
🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Check if challenge_label.csv files differ across tabddpm models
for dir in tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_*/; do
  if [ -f "$dir/challenge_label.csv" ]; then
    echo "=== $(basename $dir) ==="
    head -5 "$dir/challenge_label.csv"
  fi
done
```

Repository: VectorInstitute/midst-toolkit
Length of output: 247

🏁 Script executed:

```bash
#!/bin/bash
# Read the full prepare_data_for_attack function to see return statement and full context
cat -n src/midst_toolkit/attacks/tf/data_utils.py | sed -n '138,171p'
```

Repository: VectorInstitute/midst-toolkit
Length of output: 1999

🏁 Script executed:

```bash
#!/bin/bash
# Search for callers of prepare_data_for_attack to understand how returned values are used
rg "prepare_data_for_attack" --type py -A 5
```

Repository: VectorInstitute/midst-toolkit
Length of output: 1934
`df_challenge_labels` is overwritten in each iteration; only the last model's labels are returned.

The loop reads challenge_label.csv for each index but overwrites `df_challenge_labels` on every iteration (line 152). Challenge labels differ across models, so this causes data loss. Either append labels to a list and return the concatenated data, or document whether keeping only the last model's labels is intentional. Note: current callers ignore this return value by using `_`, but the bug still exists.
🤖 Prompt for AI Agents

```
In src/midst_toolkit/attacks/tf/data_utils.py around lines 148 to 152,
df_challenge_labels is overwritten on every loop iteration so only the last
model's labels are kept; change the code to collect labels per iteration (e.g.,
create a df_challenge_labels_list before the loop, append each pd.read_csv(...)
to that list) and after the loop concatenate them with pd.concat(...) (or
otherwise aggregate appropriately) and return the combined DataFrame;
alternatively, if returning only the last model's labels was intended, add an
explicit comment and keep the behavior, but ensure callers expect that.
```
This is an important one to address. We're not appending the labels to this list here. Is this expected behavior? If so, we should at least have a comment that explains why we're not appending.
🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Check if challenge_label.csv files exist in the test assets and if they differ
echo "=== Checking for challenge_label.csv files in test assets ==="
fd -t f "challenge_label.csv" tests/integration/attacks/tf/assets/
echo -e "\n=== Checking file sizes and content samples ==="
for file in $(fd -t f "challenge_label.csv" tests/integration/attacks/tf/assets/); do
  if [ -f "$file" ]; then
    echo "File: $file"
    echo "Size: $(wc -l < "$file") lines"
    echo "First 3 lines:"
    head -3 "$file"
    echo "---"
  fi
done
```

Length of output: 1190

🏁 Script executed:

```bash
#!/bin/bash
# Search for usage of challenge_label and challenge labels in the codebase
echo "=== Searching for 'challenge_label' references in Python files ==="
rg "challenge_label" --type py -C 2
echo -e "\n=== Checking if challenge labels are documented anywhere ==="
rg "challenge.*label" --type md -i
```

Length of output: 5699

🏁 Script executed:

```bash
#!/bin/bash
# Check the full function implementation and see how deduplication works
echo "=== Full prepare_data_for_attack function ==="
sed -n '/^def prepare_data_for_attack/,/^def \|^class \|^$/p' src/midst_toolkit/attacks/tf/data_utils.py | head -50
```

Length of output: 345

🏁 Script executed:

```bash
#!/bin/bash
# Get the complete prepare_data_for_attack function
echo "=== Complete prepare_data_for_attack function ==="
awk '/^def prepare_data_for_attack/,/^def [^_]|^class / {print NR": "$0}' src/midst_toolkit/attacks/tf/data_utils.py | head -40
```

Length of output: 145

🏁 Script executed:

```bash
#!/bin/bash
# Check how the function is actually called in tf_attack.py
echo "=== Usage in tf_attack.py ==="
rg "prepare_data_for_attack" src/midst_toolkit/attacks/tf/tf_attack.py -B 3 -A 5
```

Length of output: 694

🏁 Script executed:

```bash
#!/bin/bash
# Compare first and last challenge_label files to confirm they differ
echo "=== Comparing challenge_label files ==="
echo "First model (tabddpm_1):"
head -10 tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_label.csv
echo -e "\nLast model (tabddpm_5):"
head -10 tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_label.csv
echo -e "\nFull diff check:"
diff tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_label.csv tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_label.csv | head -20
```

Length of output: 258
@emersodb, excellent catch! This is definitely a bug. Here's what I found:

Evidence

The challenge_label.csv files differ significantly across models.

tabddpm_1 (rows 2-4):

```
1
1
1
```

tabddpm_5 (rows 2-4):

```
0
1
0
```

Running a full diff shows dozens of differences between these files.

The Problem

At line 152, `df_challenge_labels` is overwritten in each iteration:

```python
for t in indices:
    base_path = models_base_dir / f"{model_type}_{t}"
    df_merge_list.append(pd.read_csv(...))      # ✅ appended
    df_challenge_list.append(pd.read_csv(...))  # ✅ appended
    df_challenge_labels = pd.read_csv(...)      # ❌ overwritten!
```

This means only the last model's labels are returned, discarding all previous models' labels, even though those labels contain different data.

Current Impact

The callers in tf_attack.py currently ignore this return value with `_, _, _`, so the bug isn't causing runtime issues yet. However, the function's contract is broken.

Recommended Fix

Change line 152 and the code after the loop to match the pattern used for the other two DataFrames:

```python
df_merge_list = []
df_challenge_list = []
df_challenge_labels_list = []  # Add this
for t in indices:
    base_path = models_base_dir / f"{model_type}_{t}"
    df_merge_list.append(pd.read_csv(os.path.join(base_path, "train_with_id.csv")))
    df_challenge_list.append(pd.read_csv(os.path.join(base_path, "challenge_with_id.csv")))
    df_challenge_labels_list.append(pd.read_csv(os.path.join(base_path, "challenge_label.csv")))  # Change this

df_merge = pd.concat(df_merge_list, ignore_index=True)
df_challenge = pd.concat(df_challenge_list, ignore_index=True)
df_challenge_labels = pd.concat(df_challenge_labels_list, ignore_index=True)  # Add this
```

Then proceed with deduplication as before. This will ensure all challenge labels from all models are properly collected and returned.

Should I open a follow-up issue for this fix, or would you like to address it in this PR?
Code context (src/midst_toolkit/attacks/tf/data_utils.py):

```python
    if dataset.numerical_feautures is not None:
        x = torch.from_numpy(
            np.concatenate([dataset.numerical_feautures[split], dataset.categorical_features[split]], axis=1)
        ).float()
    else:
        x = torch.from_numpy(dataset.categorical_features[split]).float()
else:
    x = torch.from_numpy(dataset.numerical_features[split]).float()
```
Typo in attribute name causes runtime AttributeError.

Line 287 uses `numerical_feautures` (misspelled with 'au') while line 294 correctly uses `numerical_features`. This will raise an `AttributeError` when `categorical_features` is not None but `numerical_features` is present.

```diff
 if dataset.categorical_features is not None:
-    if dataset.numerical_feautures is not None:
+    if dataset.numerical_features is not None:
         x = torch.from_numpy(
-            np.concatenate([dataset.numerical_feautures[split], dataset.categorical_features[split]], axis=1)
+            np.concatenate([dataset.numerical_features[split], dataset.categorical_features[split]], axis=1)
         ).float()
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
    if dataset.numerical_features is not None:
        x = torch.from_numpy(
            np.concatenate([dataset.numerical_features[split], dataset.categorical_features[split]], axis=1)
        ).float()
    else:
        x = torch.from_numpy(dataset.categorical_features[split]).float()
else:
    x = torch.from_numpy(dataset.numerical_features[split]).float()
```
🤖 Prompt for AI Agents

```
In src/midst_toolkit/attacks/tf/data_utils.py around lines 287 to 294, there's a
misspelled attribute `numerical_feautures` that will raise an AttributeError;
change that misspelling to `numerical_features` so the code accesses the correct
attribute when concatenating numerical and categorical arrays, and ensure any
other references in this block use the correct `numerical_features` name.
```
| "general": { | ||
| "data_dir": "tests/integration/attacks/ensemble/assets/shadow_models_data", | ||
| "exp_name": "train_1", | ||
| "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_3/workspace", |
Hardcoded absolute path will break portability.
The workspace_dir uses an absolute path that is machine-specific. This should be updated to use a relative path or a configurable value to ensure the tests work across different environments.
🤖 Prompt for AI Agents

```
In tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args
around line 5, the workspace_dir is a hardcoded absolute path which breaks
portability; change it to a relative path or configuration-driven value (e.g.,
use a path relative to the repository root or an environment variable like
WORKSPACE_DIR) and update any test setup to resolve the chosen
relative/configurable path at runtime so the tests work across different
machines.
```
| "general": { | ||
| "data_dir": "tests/integration/attacks/ensemble/assets/shadow_models_data", | ||
| "exp_name": "train_1", | ||
| "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_4/workspace", |
Hardcoded absolute path will break portability.
The workspace_dir contains an absolute path /projects/midst-experiments/... that is specific to a particular machine. This will fail in CI environments and for other developers.
Consider using a relative path or a path that can be configured dynamically during test execution, similar to how data_dir uses a relative path.
🤖 Prompt for AI Agents

```
In tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json
around line 5, the workspace_dir is a hardcoded absolute path
(/projects/midst-experiments/...), which breaks portability; change it to a
relative path (e.g., "./workspace" or "train/tabddpm_4/workspace") or a
configurable placeholder (e.g., "${WORKSPACE_DIR}") and update the test setup to
inject/resolve that variable at runtime so CI and other developers can run the
tests without machine-specific paths.
```
| "general": { | ||
| "data_dir": "tests/integration/attacks/ensemble/assets/shadow_models_data", | ||
| "exp_name": "train_1", | ||
| "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_4/workspace", |
Hardcoded absolute path will break test portability.
The workspace_dir contains an absolute path (/projects/midst-experiments/...) which will fail on different machines or CI environments. Consider using a relative path or a placeholder that gets resolved at runtime, similar to how data_dir uses a relative path.
- "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_4/workspace",
+ "workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace",📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_4/workspace", | |
| "workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace", |
🤖 Prompt for AI Agents
In
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args
around line 5, the workspace_dir is a hardcoded absolute path which breaks
portability; change it to a relative path or placeholder (e.g., ./workspace or
${WORKSPACE_DIR}) and update test/bootstrap logic to resolve that placeholder at
runtime (or set the env var in test setup) so the path is computed
per-environment rather than fixed.
| "general": { | ||
| "data_dir": "tests/integration/attacks/ensemble/assets/shadow_models_data", | ||
| "exp_name": "train_1", | ||
| "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace", |
Hardcoded absolute path will break portability.
Same issue as other workspace configs — the absolute workspace_dir path should be made relative or configurable.
🤖 Prompt for AI Agents

```
In tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args
around line 5, the workspace_dir is a hardcoded absolute path which breaks
portability; change it to a relative path or read it from a configurable source
(env var or test fixture) instead: replace the absolute path with a
project-relative path (e.g., ./workspace or an appropriate relative path) or use
a placeholder that the test harness fills from an environment variable or
config, so CI and other machines can resolve the workspace directory at runtime.
```
Code context (tests/integration/attacks/tf/test_tf_attack.py):

```python
# Add paths
sys.path.append("/h/behnzaman/")
sys.path.insert(0, "/h/behnzaman/midst-experiments/deps/TF_attack/")
```
Remove hardcoded developer-specific paths from sys.path.

These absolute paths are specific to a developer's home directory and will cause import failures for other developers and in CI environments. If these dependencies are required, they should be properly installed as packages, or the paths should be configurable via environment variables.

```diff
-# Add paths
-sys.path.append("/h/behnzaman/")
-sys.path.insert(0, "/h/behnzaman/midst-experiments/deps/TF_attack/")
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

(The committable suggestion removes the three lines shown above.)
🤖 Prompt for AI Agents

```
In tests/integration/attacks/tf/test_tf_attack.py around lines 9-11, remove the
two hardcoded sys.path.append/sys.path.insert calls that reference a developer's
home directory; instead ensure the dependency is available via normal packaging
(add it to requirements/dev and pip install it, or use an editable install during
development) or make the path configurable (read a TF_ATTACK_PATH env var and
validate it, or compute a project-relative path using pathlib from the
repository root). Update the test to import the package normally once installed,
or, if a temporary path injection is unavoidable, derive the path from an
environment variable or project root instead of using absolute user-specific
paths.
```
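If a temporary path injection really is needed during a transition period, an environment-driven variant avoids the user-specific constants; this is a sketch, and `TF_ATTACK_PATH` is the hypothetical variable name from the prompt above:

```python
import os
import sys
from pathlib import Path

# Resolve an optional dependency location from the environment instead of
# hardcoding a developer's home directory; silently skip if unset or invalid.
tf_attack_path = os.environ.get("TF_ATTACK_PATH")
if tf_attack_path and Path(tf_attack_path).is_dir():
    sys.path.insert(0, tf_attack_path)
```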
@coderabbitai full review

✅ Actions performed

Full review triggered.
Actionable comments posted: 4
♻️ Duplicate comments (7)
src/midst_toolkit/attacks/tf/data_utils.py (1)

138-171: Still only returning the last model's challenge_label.csv.

Inside `prepare_data_for_attack`, `df_challenge_labels` is reassigned on every iteration (line 152) and never accumulated. With multiple `indices`, you return labels only from the last model, while `df_merge` and `df_challenge` are concatenated across all models. This is the same issue previously flagged.

If you intend to aggregate across models, collect labels in a list and concatenate:

```diff
 def prepare_data_for_attack(indices, model_type, models_base_dir, keys_for_deduplication):
 @@
-    df_challenge_list = []
+    df_challenge_list = []
+    df_challenge_labels_list = []
 @@
-        df_challenge_list.append(pd.read_csv(os.path.join(base_path, "challenge_with_id.csv")))
-        df_challenge_labels = pd.read_csv(os.path.join(base_path, "challenge_label.csv"))
+        df_challenge_list.append(pd.read_csv(os.path.join(base_path, "challenge_with_id.csv")))
+        df_challenge_labels_list.append(pd.read_csv(os.path.join(base_path, "challenge_label.csv")))
 @@
-    df_challenge = pd.concat(df_challenge_list, ignore_index=True)
+    df_challenge = pd.concat(df_challenge_list, ignore_index=True)
+    df_challenge_labels = pd.concat(df_challenge_labels_list, ignore_index=True)
 @@
-    return df_merge_without_challenge, df_challenge, df_challenge_labels
+    return df_merge_without_challenge, df_challenge, df_challenge_labels
```

If only the last model's labels are truly desired, add an explicit comment and enforce that `indices` has length 1.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args (1)

2-8: Replace hardcoded absolute `workspace_dir` with a project-relative path.

`workspace_dir` is an absolute path (line 5), which will fail on other machines and CI. For a test asset, it should live under the repo with a relative path, similar to `data_dir` and the other configs.

```diff
- "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_4/workspace",
+ "workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace",
```

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (1)

5-5: Hardcoded absolute path breaks CI/portability.

This issue was previously flagged. The `workspace_dir` contains an absolute path specific to a particular machine, which will fail in CI and for other developers.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1)

5-5: Hardcoded absolute path breaks CI/portability.

This issue was previously flagged. The `workspace_dir` contains an absolute path specific to a particular machine, which will fail in CI and for other developers.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (1)

3-7: Avoid hardcoded absolute `workspace_dir` in test config.

`"workspace_dir"` is an absolute, developer/cluster-specific path and will not exist on other machines/CI. Use a project-relative path (like the other fields) or a placeholder that your test harness fills in.

```diff
- "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace",
+ "workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace",
```

src/midst_toolkit/attacks/tf/tf_attack.py (2)
383-435: Use `addt_value_list` when building training/validation features; the current loop ignores it.

Here you still hardcode `for addt_value in [0]:`, so `addt_value_list` is effectively unused. At the same time, the feature matrix width is computed as `len(input_noise) * len(timesteps_list) * len(addt_value_list)`, meaning any `addt_value_list` with length > 1 leaves trailing zero columns, and the classifier input_dim is larger than the actually populated features.

Refactor the nested loop to iterate over `addt_value_list` and index feature blocks accordingly:

```diff
-        t_value_count = 0
-        for t_value in timesteps_list:
-            for addt_value in [0]:
+        feature_block_index = 0
+        for t_value in timesteps_list:
+            for addt_value in addt_value_list:
                 if model_number in train_indices:
 @@
-                    x_train[
-                        samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1),
-                        t_value_count * num_noise_per_time_step : (t_value_count + 1) * num_noise_per_time_step,
-                    ] = predictions.detach().squeeze().cpu().numpy()
+                    start = feature_block_index * num_noise_per_time_step
+                    end = (feature_block_index + 1) * num_noise_per_time_step
+                    x_train[
+                        samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1),
+                        start:end,
+                    ] = predictions.detach().squeeze().cpu().numpy()
 @@
-                    x_train_label[
-                        samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1)
-                    ] = np.concatenate([np.zeros(samples_per_train_model), np.ones(samples_per_train_model)])
-                    t_value_count += 1
+                    x_train_label[
+                        samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1)
+                    ] = np.concatenate([np.zeros(samples_per_train_model), np.ones(samples_per_train_model)])
+                    feature_block_index += 1
 @@
-                    x_val[
-                        sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1),
-                        t_value_count * num_noise_per_time_step : (t_value_count + 1) * num_noise_per_time_step,
-                    ] = predictions.detach().squeeze().cpu().numpy()
+                    start = feature_block_index * num_noise_per_time_step
+                    end = (feature_block_index + 1) * num_noise_per_time_step
+                    x_val[
+                        sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1),
+                        start:end,
+                    ] = predictions.detach().squeeze().cpu().numpy()
 @@
-                    x_val_label[sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1)] = (
-                        np.concatenate([np.zeros(sample_per_val_model), np.ones(sample_per_val_model)])
-                    )
-                    t_value_count += 1
+                    x_val_label[sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1)] = (
+                        np.concatenate([np.zeros(sample_per_val_model), np.ones(sample_per_val_model)])
+                    )
+                    feature_block_index += 1
```

This way, `x_train`/`x_val` and the MLP's `input_dim` stay consistent for arbitrary `addt_value_list`.
503-528: Respectaddt_value_list, avoid hardcodedbatch_size, and guard min–max normalization against zero range.Three issues remain in
tf_attack’s scoring loop:
Hardcoded
batch_size = 200:
- This magic number is baked into the function, making it harder to reuse in other scenarios or tests.
addt_value_listis ignored:
- The loop still uses
for addt_value in [0]:, so callers can’t varyaddt_valuedespite it being a parameter.Potential division by zero in min–max normalization:
- If all model outputs are identical (
min_output == max_output),(max_output - min_output)is zero, leading tonans and failing the subsequent assertion.A possible fix:
```diff
-def tf_attack(
+def tf_attack(
@@
-    classifier_hidden_dim: int,
-    addt_value_list: list[int],
-    meta_dir: Path,
-) -> tuple[Any, Any, Any]:
+    classifier_hidden_dim: int,
+    addt_value_list: list[int],
+    meta_dir: Path,
+    batch_size: int = 200,
+) -> tuple[Any, Any, Any]:
@@
-        model_dir = tabddpm_data_dir / model_folder
-        model_path = model_dir / target_model_subdir
-        batch_size = 200
-        t_value_count = 0
+        model_dir = tabddpm_data_dir / model_folder
+        model_path = model_dir / target_model_subdir
         current_input = []
         for t_value in timesteps_list:
-            for addt_value in [0]:
+            for addt_value in addt_value_list:
                 predictions: torch.Tensor = get_score(
@@
-                )
-                t_value_count += 1
-                current_input.append(predictions)
+                )
+                current_input.append(predictions)
@@
-        predictions = regression_model(predictions).detach().cpu().numpy()
-        # clip to [0, 1]
-        min_output, max_output = np.min(predictions), np.max(predictions)
-        predictions = (predictions - min_output) / (max_output - min_output)
+        predictions = regression_model(predictions).detach().cpu().numpy()
+        # Rescale to [0, 1] safely
+        min_output, max_output = np.min(predictions), np.max(predictions)
+        range_output = max_output - min_output
+        if range_output > 0:
+            predictions = (predictions - min_output) / range_output
+        else:
+            # All outputs identical; fall back to a neutral constant
+            predictions = np.full_like(predictions, 0.5)
         predictions = torch.tensor(predictions)
```

This keeps the existing integration test behavior (`batch_size` defaults to 200, `addt_value_list=[0]`) while making the function robust and reusable.
🧹 Nitpick comments (8)
src/midst_toolkit/attacks/tf/data_utils.py (3)
1-18: Tighten lint/typing pragmas and clean up a leftover comment.

- Line 1: `# ruff: noqa: D102, D105, D103, D200` is flagged as unused; if these rules aren't enabled, drop this directive instead of carrying dead config.
- Lines 2–6: multiple broad mypy disables (`no-untyped-def`, `has-type`, `index`, `attr-defined`, `assignment`) effectively opt this module out of type checking. Given the PR TODO "fix typing errors", it would be better long-term to narrow these or remove them once annotations are in place.
- Line 16: the comment `# at very top of file (optional but helpful)` looks like a review note rather than code documentation and can be removed.

```diff
-# ruff: noqa: D102, D105, D103, D200
-# mypy: disable-error-code=no-untyped-def
-# mypy: disable-error-code=has-type
-# mypy: disable-error-code=index
-# mypy: disable-error-code=attr-defined
-# mypy: disable-error-code=assignment
-from __future__ import annotations
-
-# at very top of file (optional but helpful)
+from __future__ import annotations
@@
-from typing import Any, Literal
+from typing import Any, Literal
```
59-76:verboseparameter inload_multi_table_customizedis currently unusedThe
verboseargument (Line 59) is never read, which is flagged by static analysis and can confuse callers expecting logging or extra checks.Either remove the parameter if not needed:
-def load_multi_table_customized(data_dir, meta_dir=None, train_name="train.csv", verbose=True): +def load_multi_table_customized(data_dir, meta_dir=None, train_name="train.csv"):or actually use it (e.g., to print/log dataset/table info when
verboseis True).
267-304: `prepare_fast_dataloader` behavior (infinite batches, feature concatenation) is intentional but worth documenting clearly.

- Correctly handles three cases: numerical+categorical, categorical-only, and numerical-only features.
- Returns an endless stream of `(x, y)` batches via `while True: yield from dataloader`, reshuffling each epoch when `split == "train"`.

Consider making the "infinite generator" behavior even more explicit in the docstring (e.g., "This is an infinite generator; callers must bound iteration externally") to avoid misuse in code that expects a single finite epoch.
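If it helps, here is a minimal, self-contained sketch of the infinite-generator pattern in question; the dataset, names, and sizes are illustrative and not the toolkit's API:

```python
import itertools

import torch
from torch.utils.data import DataLoader, TensorDataset


def infinite_batches(dataset, batch_size):
    # Endless stream of batches; reshuffles every epoch and never raises StopIteration.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    while True:
        yield from loader


dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
# Callers must bound iteration externally, e.g. with itertools.islice:
for x, y in itertools.islice(infinite_batches(dataset, 32), 10):
    pass  # exactly 10 batches, independent of epoch boundaries
```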
tests/unit/evaluation/quality/test_alpha_precision_naive.py (1)
53-58: Redundant conditional branches with identical assertions.

Both the `if is_apple_silicon()` and `else` branches contain identical assertions. Either the branching is unnecessary and can be simplified, or the expected values for the non-Apple Silicon path should differ. If the values are truly architecture-independent for naive metrics, simplify:

```diff
-    if is_apple_silicon():
-        assert pytest.approx(0.05994074074074074, abs=1e-8) == quality_results["delta_precision_alpha_naive"]
-        assert pytest.approx(0.005229629629629584, abs=1e-8) == quality_results["delta_coverage_beta_naive"]
-    else:
-        assert pytest.approx(0.05994074074074074, abs=1e-8) == quality_results["delta_precision_alpha_naive"]
-        assert pytest.approx(0.005229629629629584, abs=1e-8) == quality_results["delta_coverage_beta_naive"]
+    assert pytest.approx(0.05994074074074074, abs=1e-8) == quality_results["delta_precision_alpha_naive"]
+    assert pytest.approx(0.005229629629629584, abs=1e-8) == quality_results["delta_coverage_beta_naive"]
```

tests/integration/attacks/tf/test_tf_attack.py (2)
45-54: Remove redundant `.values()` unpacking before explicit key-based access.

You first unpack `mia_performance_*.values()` into variables, then immediately overwrite them using explicit key lookups. The unpacking is unnecessary and relies on dict insertion order; just keep the key-based access:

```diff
-    mia_performance_train, mia_performance_val, mia_performance_test = tf_attack(**config)
-    tpr_at_fpr_train, roc_auc_train = mia_performance_train.values()
-    tpr_at_fpr_val, roc_auc_val = mia_performance_val.values()
-    tpr_at_fpr_test, roc_auc_test = mia_performance_test.values()
-    tpr_at_fpr_train = mia_performance_train["max_tpr"]
-    roc_auc_train = mia_performance_train["roc_auc"]
-    tpr_at_fpr_val = mia_performance_val["max_tpr"]
-    roc_auc_val = mia_performance_val["roc_auc"]
-    tpr_at_fpr_test = mia_performance_test["max_tpr"]
-    roc_auc_test = mia_performance_test["roc_auc"]
+    mia_performance_train, mia_performance_val, mia_performance_test = tf_attack(**config)
+    tpr_at_fpr_train = mia_performance_train["max_tpr"]
+    roc_auc_train = mia_performance_train["roc_auc"]
+    tpr_at_fpr_val = mia_performance_val["max_tpr"]
+    roc_auc_val = mia_performance_val["roc_auc"]
+    tpr_at_fpr_test = mia_performance_test["max_tpr"]
+    roc_auc_test = mia_performance_test["roc_auc"]
```
69-119: Consider deleting the large commented-out alternative test with absolute paths.

This block is dead code and contains hardcoded, developer-specific absolute paths. Keeping it commented out adds noise and may confuse future readers about which setup is authoritative. Prefer removing it or moving the scenario into a separate, properly parameterized test/config.
src/midst_toolkit/attacks/tf/tf_attack.py (2)
1-7: Remove or narrow the unused `ruff: noqa` directive.

Ruff reports this `noqa` as unused for `D102, D105, D103, D200, PLR0915`. If those checks aren't actually enabled, the directive is unnecessary noise. Either remove it or restrict it to the specific rules you need to silence:

```diff
-# ruff: noqa: D102, D105, D103, D200, PLR0915
+# ruff: noqa
```

or simply delete the line if you don't need it at all.
113-181: Guard `get_dataset` against missing `target_model_dir`/`batch_size` instead of relying on fragile defaults.

`get_dataset` defaults `target_model_dir=None` and `batch_size=None`, but the body assumes both are set:

- `os.path.join(target_model_dir, ...)` will fail if `target_model_dir` is `None`.
- `prepare_fast_dataloader(..., batch_size=batch_size, ...)` likely expects an integer.

From this module it's always called with non-`None` arguments, but the signature suggests they're optional. Either make them required or fail fast with clear errors:

```diff
-def get_dataset(data_path, target_model_dir=None, train_name="train_with_id.csv", batch_size=None, meta_dir=""):
+def get_dataset(data_path, target_model_dir=None, train_name="train_with_id.csv", batch_size=None, meta_dir=""):
@@
-    tables, relation_order, _ = load_multi_table_customized(
+    if target_model_dir is None:
+        raise ValueError("target_model_dir must be provided.")
+    if batch_size is None:
+        raise ValueError("batch_size must be provided.")
+
+    tables, relation_order, _ = load_multi_table_customized(
         data_path,
         meta_dir=meta_dir,
         train_name=train_name,
```

This makes incorrect external use of `get_dataset` fail deterministically instead of with a less obvious `TypeError`.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (52)
All of these live under `tests/integration/attacks/tf/assets/tabddpm_models/` and are excluded by the `!**/*.pkl` and `!**/*.csv` path filters:

- `tabddpm_1/`: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, data_for_training_MIA.csv, predictions_test_2.csv, predictions_test_222.csv, test.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
- `tabddpm_2/`: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, data_for_training_MIA.csv, test.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
- `tabddpm_3/`: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, data_for_validating_MIA.csv, predictions_test_2.csv, predictions_test_222.csv, test.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
- `tabddpm_4/`: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, data_for_validating_MIA.csv, predictions_test_2.csv, predictions_test_222.csv, test.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
- `tabddpm_5/`: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, predictions_test_2.csv, predictions_test_222.csv, test.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
📒 Files selected for processing (18)
- `src/midst_toolkit/attacks/tf/classification.py` (1 hunks)
- `src/midst_toolkit/attacks/tf/data_utils.py` (1 hunks)
- `src/midst_toolkit/attacks/tf/tf_attack.py` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json` (1 hunks)
- `tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args` (1 hunks)
- `tests/integration/attacks/tf/data_configs/dataset_meta.json` (1 hunks)
- `tests/integration/attacks/tf/data_configs/trans.json` (1 hunks)
- `tests/integration/attacks/tf/data_configs/trans_domain.json` (1 hunks)
- `tests/integration/attacks/tf/test_tf_attack.py` (1 hunks)
- `tests/unit/evaluation/quality/test_alpha_precision_naive.py` (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/unit/evaluation/quality/test_alpha_precision_naive.py (4)
src/midst_toolkit/common/random.py (2)
- `set_all_random_seeds` (11-55)
- `unset_all_random_seeds` (58-67)

src/midst_toolkit/data_processing/midst_data_processing.py (2)
- `load_midst_data` (94-121)
- `process_midst_data_for_alpha_precision_evaluation` (17-91)

src/midst_toolkit/evaluation/utils.py (2)
- `extract_columns_based_on_meta_info` (45-87)
- `one_hot_encode_categoricals_and_merge_with_numerical` (90-128)

tests/utils/architecture.py (1)
- `is_apple_silicon` (4-6)
src/midst_toolkit/attacks/tf/data_utils.py (2)
src/midst_toolkit/models/clavaddpm/dataset.py (1)
- `Dataset` (77-397)

src/midst_toolkit/models/clavaddpm/enumerations.py (1)
- `Normalization` (58-63)
src/midst_toolkit/attacks/tf/tf_attack.py (4)
src/midst_toolkit/attacks/tf/classification.py (2)
- `MLP` (10-42)
- `fitmodel` (67-161)

src/midst_toolkit/attacks/tf/data_utils.py (6)
- `CustomUnpickler` (51-56)
- `TaskType` (42-48)
- `evaluate_attack_performance` (185-218)
- `load_multi_table_customized` (59-103)
- `prepare_data_for_attack` (138-171)
- `prepare_fast_dataloader` (267-304)

src/midst_toolkit/models/clavaddpm/dataset.py (3)
- `Dataset` (77-397)
- `from_df` (276-397)
- `size` (199-211)

src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py (3)
- `sample_time` (945-982)
- `gaussian_q_sample` (302-321)
- `_gaussian_loss` (505-543)
🪛 Ruff (0.14.7)
src/midst_toolkit/attacks/tf/data_utils.py
1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200)
Remove unused noqa directive
(RUF100)
59-59: Unused function argument: verbose
(ARG001)
75-75: Avoid specifying long messages outside the exception class
(TRY003)
95-95: Avoid specifying long messages outside the exception class
(TRY003)
101-101: Avoid specifying long messages outside the exception class
(TRY003)
143-143: Avoid specifying long messages outside the exception class
(TRY003)
165-165: Avoid specifying long messages outside the exception class
(TRY003)
193-193: Avoid specifying long messages outside the exception class
(TRY003)
src/midst_toolkit/attacks/tf/tf_attack.py
1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200, PLR0915)
Remove unused noqa directive
(RUF100)
82-82: Avoid specifying long messages outside the exception class
(TRY003)
222-222: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (11)
src/midst_toolkit/attacks/tf/data_utils.py (3)
70-103: Table loading and validation logic looks sound for single-table use.

The per-table loading/validation flow (train CSV + `{table}_domain.json`, ID column removal, `'?'` value check, and numeric-column string guard) is consistent and defensive. For the current test dataset (a single `"trans"` table), this is straightforward and appropriate.
185-219: Attack evaluation aggregation and metric computation look coherent.

The aggregation in `evaluate_attack_performance` (looping over `indices`, skipping missing files, concatenating predictions and labels, then computing max TPR at a fixed FPR and ROC AUC) is consistent and matches the intended evaluation logic. Guard clauses for empty `indices` and no predictions give safe fallbacks.

Please double-check that the shape of `challenge_label.csv` (loaded via `np.loadtxt(..., skiprows=1)`) matches the expected 1D label array so `roc_curve` receives the correct `y_true`.
221-265: `FastTensorDataLoader` implementation is minimal and correct.

The loader enforces equal leading dimensions, supports optional shuffling via a permuted index, and computes the correct number of batches including a final partial batch. The `__iter__`/`__next__` semantics are standard and should integrate cleanly with `for` loops or `yield from`.
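As a rough illustration of that pattern, here is a stripped-down sketch (not the actual class; names and defaults are invented):

```python
import torch


class TinyTensorLoader:
    def __init__(self, *tensors, batch_size=32, shuffle=False):
        # All tensors must share the same leading (sample) dimension.
        assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
        self.tensors, self.batch_size, self.shuffle = tensors, batch_size, shuffle
        # Ceil division so the final partial batch is counted.
        self.n_batches = (tensors[0].shape[0] + batch_size - 1) // batch_size

    def __iter__(self):
        n = self.tensors[0].shape[0]
        # Optional shuffling via a permuted index, as in the reviewed loader.
        idx = torch.randperm(n) if self.shuffle else torch.arange(n)
        for start in range(0, n, self.batch_size):
            batch_idx = idx[start : start + self.batch_size]
            yield tuple(t[batch_idx] for t in self.tensors)
```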
tests/integration/attacks/tf/data_configs/dataset_meta.json (1)

1-1: Dataset meta format is minimal but consistent with loader expectations.

The single-table `relation_order` and `tables.trans` structure align with how `load_multi_table_customized` reads `dataset_meta.json`. No issues spotted.

tests/integration/attacks/tf/data_configs/trans.json (1)

1-50: Config parameters and relative paths look appropriate for tests.

The general/clustering/diffusion/classifier/sampling/matching sections are consistent and use repo-relative paths (including `workspace_dir`), which is good for portability in CI and local runs.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1)

1-1: Domain descriptor matches the expected schema.

Field names, sizes, and types mirror the other `trans_domain.json` assets, so this should integrate smoothly with the domain-aware tooling.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1)

1-1: Tabddpm_1 domain metadata is consistent with the other variants.

Schema and typing for all fields match the other `trans_domain.json` assets, which keeps the attack tests' domain assumptions uniform.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1)

1-1: Tabddpm_5 domain descriptor aligns with the rest of the suite.

Field set and type declarations are consistent with the other tabddpm model domain files, so downstream utilities can treat all models uniformly.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1)

1-1: LGTM! Valid JSON domain metadata defining feature sizes and types for test fixtures. Structure is consistent with the other `trans_domain.json` files in the PR.

tests/integration/attacks/tf/data_configs/trans_domain.json (1)

1-1: LGTM! Valid JSON domain metadata for the test data configuration.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1)

1-1: LGTM! Valid JSON domain metadata consistent with the other tabddpm model test assets.
```python
if use_best_checkpoint and best_model_dir is not None:
    best_model_dir = Path(".")  # or raise ValueError
    print(f"Best model will be saved to: {best_model_dir}")

best_model_path = best_model_dir / "best_model.pt"
optimizer = optim.Adam(regression_model.parameters(), lr=learning_rate)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

has_validation = x_val is not None
x_train = torch.tensor(x_train, dtype=torch.float32).to(device)
y_train = torch.tensor(x_train_label, dtype=torch.float32).to(device)
if has_validation:
    x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
    y_val = torch.tensor(x_val_label, dtype=torch.float32).to(device)

indices = torch.randperm(x_train.size(0))
x_train, y_train = x_train[indices], y_train[indices]

regression_model.train()
best_tpr, best_model_exists = 0.0, False

for epoch in range(num_epochs):
    optimizer.zero_grad()
    loss = custom_loss_fn(regression_model, x_train, y_train)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        train_loss, train_tpr = evaluate_model(regression_model, x_train, y_train)
        if x_val is not None:
            test_loss, test_tpr = evaluate_model(regression_model, x_val, y_val)
            if test_tpr > best_tpr:
                best_tpr = test_tpr
                save_best_model(regression_model, best_model_path)
                best_model_exists = True
            print(
                f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss}, "
                f"Test Loss: {test_loss}, Train TPR: {train_tpr}, Test TPR: {test_tpr}"
            )
        else:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss}, Train TPR: {train_tpr}")

if use_best_checkpoint and best_model_exists:
    load_best_model(regression_model, best_model_path, device)
```
Fix `fitmodel`'s `best_model_dir` handling; current defaults crash and override caller paths.
Currently:
- With default arguments (`use_best_checkpoint=None`, `best_model_dir=None`), `best_model_path = best_model_dir / "best_model.pt"` raises a `TypeError`.
- When `use_best_checkpoint` is truthy and a non-`None` `best_model_dir` is passed, it's silently replaced with `Path(".")`, so the best model is saved/loaded from the CWD instead of the caller-specified directory.
Make the requirement explicit and avoid overriding the caller:
```diff
-    if use_best_checkpoint and best_model_dir is not None:
-        best_model_dir = Path(".")  # or raise ValueError
-        print(f"Best model will be saved to: {best_model_dir}")
-
-    best_model_path = best_model_dir / "best_model.pt"
+    best_model_path: Path | None = None
+    if use_best_checkpoint:
+        if best_model_dir is None:
+            raise ValueError("best_model_dir must be provided when use_best_checkpoint=True.")
+        best_model_dir = Path(best_model_dir)
+        best_model_dir.mkdir(parents=True, exist_ok=True)
+        best_model_path = best_model_dir / "best_model.pt"
+        print(f"Best model will be saved to: {best_model_path}")
@@
-                if test_tpr > best_tpr:
+                if test_tpr > best_tpr and best_model_path is not None:
                     best_tpr = test_tpr
                     save_best_model(regression_model, best_model_path)
                     best_model_exists = True
@@
-    if use_best_checkpoint and best_model_exists:
-        load_best_model(regression_model, best_model_path, device)
+    if use_best_checkpoint and best_model_exists and best_model_path is not None:
+        load_best_model(regression_model, best_model_path, device)
```

This preserves existing behavior for the integration path (where a `results_path` is passed) and makes misuse fail fast with a clear error.
🤖 Prompt for AI Agents
In src/midst_toolkit/attacks/tf/classification.py around lines 112 to 155, the
code currently overwrites caller-provided best_model_dir with Path(".") and
computes best_model_path when best_model_dir may be None, causing TypeError and
silent misuse; instead, do not override best_model_dir, require/validate it when
use_best_checkpoint is truthy (e.g., if use_best_checkpoint: if best_model_dir
is None: raise ValueError("best_model_dir must be provided when
use_best_checkpoint is True")), convert best_model_dir to a Path (best_model_dir
= Path(best_model_dir)) and only compute best_model_path = best_model_dir /
"best_model.pt" and call save_best_model/load_best_model when
use_best_checkpoint is truthy and a valid best_model_dir is present; leave
behavior unchanged when use_best_checkpoint is falsy.
| "general": { | ||
| "data_dir": "tests/integration/attacks/ensemble/assets/shadow_models_data", | ||
| "exp_name": "train_1", | ||
| "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace", | ||
| "sample_prefix": "", | ||
| "test_data_dir": "tests/integration/attacks/ensemble/assets/test" | ||
| }, |
Avoid hardcoded absolute `workspace_dir` in test config.

`workspace_dir` is an absolute path (Line 5), which will break portability across environments and CI. For test assets, prefer a project-relative path under the repo (e.g. the `tabddpm_5` workspace folder).
- "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace",
+ "workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace",🤖 Prompt for AI Agents
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json
lines 2-8: the "workspace_dir" field is an absolute path which breaks
portability/CI; replace it with a repo-relative path (no leading slash) pointing
to a workspace folder under the test asset (for example the tabddpm_5 workspace
directory within the repo), and update the JSON value accordingly so tests use a
project-relative workspace path.
```python
    set_all_random_seeds(
        seed=133742,
        use_deterministic_torch_algos=True,
        disable_torch_benchmarking=True,
    )

    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
```
Ensure RNG and env cleanup even when the test fails.
If any assertion fails before the end of the test, `unset_all_random_seeds()` and the `CUBLAS_WORKSPACE_CONFIG` cleanup won't run, leaving global state polluted for subsequent tests. Wrap the body in a try/finally:
```diff
 def test_tf_attack_whitebox_small_config():
-    # Set deterministic behavior
-    set_all_random_seeds(
-        seed=133742,
-        use_deterministic_torch_algos=True,
-        disable_torch_benchmarking=True,
-    )
-
-    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-
-    base_path = ...
-    ...
-    assert tpr_at_fpr_test == pytest.approx(0.12, abs=1e-8)
-
-    unset_all_random_seeds()
-    os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
+    # Set deterministic behavior
+    set_all_random_seeds(
+        seed=133742,
+        use_deterministic_torch_algos=True,
+        disable_torch_benchmarking=True,
+    )
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+    try:
+        base_path = ...
+        ...
+        assert tpr_at_fpr_test == pytest.approx(0.12, abs=1e-8)
+    finally:
+        unset_all_random_seeds()
+        os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
```
+ os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)Also applies to: 65-66
🤖 Prompt for AI Agents
In tests/integration/attacks/tf/test_tf_attack.py around lines 15-21 (and
similarly at lines 65-66), the test sets global RNG seeds and the
CUBLAS_WORKSPACE_CONFIG env var but does not guarantee cleanup if the test
fails; wrap the test body that calls set_all_random_seeds and sets
os.environ["CUBLAS_WORKSPACE_CONFIG"] in a try/finally so that
unset_all_random_seeds() and deletion (or restoration) of
CUBLAS_WORKSPACE_CONFIG always run in the finally block; apply the same
try/finally pattern to the other occurrence at lines 65-66 to ensure global
state is restored even on assertion errors.
```python
def test_alpha_precision_evaluation() -> None:
    # Setting the paramters to True helps get consistent output on the same architecture for the _OC metrics
```
Typo: "paramters" should be "parameters".
```diff
-    # Setting the paramters to True helps get consistent output on the same architecture for the _OC metrics
+    # Setting the parameters to True helps get consistent output on the same architecture for the _OC metrics
```
+ # Setting the parameters to True helps get consistent output on the same architecture for the _OC metrics📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    # Setting the parameters to True helps get consistent output on the same architecture for the _OC metrics
```
🤖 Prompt for AI Agents
tests/unit/evaluation/quality/test_alpha_precision_naive.py lines 22-22: The
comment has a typo "paramters" — change it to "parameters" so the comment reads
"Setting the parameters to True helps get consistent output on the same
architecture for the _OC metrics".
emersodb
left a comment
I left some preliminary comments for you to consider. I know this is a first pass at getting this code into the toolkit, so we don't need to address everything. However, I think most of my comments are at least worth thinking about addressing to try to improve the clarity of the code.
We'll certainly have to work on making it a bit easier to use.
```python
        Initializes the MLP (Multi-Layer Perceptron) model.

        Args:
            input_dim (int): The number of input features.
```
For our docs, we're not requiring the type annotations here because we're already strongly typing the method/class signatures etc. So they are a bit redundant. You can shut off the auto-generation of these types in your VS code settings.
So these annotations should be dropped throughout.
```python
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
```
super minor, but perhaps change this to `input`?
```python
    confidences = model(x)
    x = x.float()
    y = y.float()
    return nn.BCELoss()(confidences, y.unsqueeze(1))
```
I may be missing something, but how is this loss any different than a standard BCELoss? I guess it takes in a model and an input instead of an output.
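For comparison, a quick sanity check suggests the wrapper computes exactly a standard `BCELoss` on the model's sigmoid outputs; this is a hypothetical reconstruction with an invented toy model, not the module's code:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,)).float()

wrapped = nn.BCELoss()(model(x), y.unsqueeze(1))  # what the custom loss computes
direct = nn.BCELoss()(model(x).squeeze(1), y)  # the same value without the wrapper
assert torch.allclose(wrapped, direct)
```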
```python
        return torch.sigmoid(self.fc3(residual))


def custom_loss_fn(model, x, y):
```
The function name here could be more descriptive. Maybe `bce_loss_from_model_and_input`?
```python
def fitmodel(
    regression_model,
    x_train,
```
maybe `x` -> `input` here in the names? It's also a little odd to have `x_train_label`, I think. Colloquially, `x` normally indicates inputs and `y` labels.
```python
    # get the diffusion model
    with open(filepath, "rb") as f:
        model = CustomUnpickler(f).load()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
We already do this above.
```python
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    diffusion = model.diffusion.to(device)

    iter_max = iter_max // batch_size
```
What is happening here? Are we making sure that the size of the dataset is exactly the same as the `batch_size`? Why bother with a `batch_size` then?
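One possible reading, with hypothetical numbers, is that `iter_max` starts out as a total-sample budget and the division converts it into a loader-step count:

```python
iter_max_samples = 2000  # iter_max interpreted as a total-sample budget
batch_size = 200
num_batches = iter_max_samples // batch_size  # 10 iterations of batch_size samples each
```

If that is the intent, a second variable name would make it explicit rather than reusing `iter_max` for both meanings.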
```python
        x, out_dict = next(train_loader)
        out_dict = {"y": out_dict}
        x = x.to(device)
        for k in out_dict:
```
We don't need to iterate through the `out_dict` here, since we know its structure above. Why not just do

```python
out_dict = {"y": out_dict["y"].long().to(device)}
```
I'm also uncertain why this works at all. Doesn't the loader return a tuple of tensors? Is .long valid to call on such a tuple?
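A quick standalone check confirms the concern for a plain Python tuple (this snippet is independent of the actual loader):

```python
import torch

batch = (torch.zeros(2), torch.zeros(2))  # a tuple of tensors
try:
    batch.long()
except AttributeError as e:
    print(e)  # 'tuple' object has no attribute 'long'
```

So this only works if the loader yields a tensor (or something tensor-like) rather than a tuple.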
```python
            parallel_batch=parallel_batch,
            addt_value=addt_value,
        )
        t = t_cur * 0 + t_value
```
This `* 0` is, again, quite confusing to me...
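For what it's worth, the zero-multiply trick builds a constant tensor matching `t_cur`'s shape, dtype, and device; `torch.full_like` expresses the same intent directly (shapes here are illustrative):

```python
import torch

t_cur = torch.randint(0, 1000, (8,))
t_value = 50
t_a = t_cur * 0 + t_value  # the idiom used in the code
t_b = torch.full_like(t_cur, t_value)  # equivalent and explicit
assert torch.equal(t_a, t_b)
```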
```python
        keys_for_deduplication=["trans_id", "balance"],
    )

    n_feutures = [col for col in df_train_merge.columns if "_id" not in col]
```
`feutures` -> `features`.
Also, this isn't the number of features, but rather a list of columns that are not id columns. So this name is a bit misleading.
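e.g., a rename that says what the list actually holds (hypothetical name, same logic):

```python
# Columns that are not identifier columns, i.e. the candidate feature columns.
feature_columns = [col for col in df_train_merge.columns if "_id" not in col]
```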
PR Type
Feature:
TF Attack
Short Description
Adds the TF attack to the midst toolkit. Refactored the code, made sure it works with the midst toolkit, and wrote an integration test.
Next step: