@@ -129,7 +129,7 @@ def prepare_sam_training_input(inputs, labels, args, model):
129129 unique_labels = unique_labels [: args .num_prompt ]
130130
131131 # add 4 background labels to every batch
132- background_labels = list (set ([i for i in range (1 , 105 )]) - set (unique_labels .cpu ().numpy ()))
132+ background_labels = list (set ([i for i in range (1 , args . num_classes )]) - set (unique_labels .cpu ().numpy ()))
133133 random .shuffle (background_labels )
134134 unique_labels = torch .cat ([unique_labels , torch .tensor (background_labels [:4 ]).cuda (args .rank )])
135135
@@ -375,7 +375,7 @@ def train_epoch_iterative(model, loader, optimizer, scaler, epoch, loss_func, ar
375375
376376
377377def prepare_sam_test_input (inputs , labels , args , previous_pred = None ):
378- unique_labels = torch .tensor ([i for i in range (1 , 105 )]).cuda (args .rank )
378+ unique_labels = torch .tensor ([i for i in range (1 , args . num_classes )]).cuda (args .rank )
379379
380380 # preprocess make the size of lable same as high_res_logit
381381 batch_labels = torch .stack ([labels == unique_labels [i ] for i in range (len (unique_labels ))], dim = 0 ).float ()
@@ -400,7 +400,7 @@ def prepare_sam_test_input(inputs, labels, args, previous_pred=None):
400400
401401def prepare_sam_val_input_cp_only (inputs , labels , args ):
402402 # Don't exclude background in val but will ignore it in metric calculation
403- unique_labels = torch .tensor ([i for i in range (1 , 105 )]).cuda (args .rank )
403+ unique_labels = torch .tensor ([i for i in range (1 , args . num_classes )]).cuda (args .rank )
404404
405405 # preprocess make the size of lable same as high_res_logit
406406 batch_labels = torch .stack ([labels == unique_labels [i ] for i in range (len (unique_labels ))], dim = 0 ).float ()
@@ -460,7 +460,7 @@ def val_epoch(model, loader, epoch, acc_func, args, iterative=False, post_label=
460460 acc_batch = compute_dice (y_pred = y_pred , y = target )
461461 acc_sum , not_nans = (
462462 torch .nansum (acc_batch ).item (),
463- 104 - torch .sum (torch .isnan (acc_batch ).float ()).item (),
463+ ( args . num_classes - 1 ) - torch .sum (torch .isnan (acc_batch ).float ()).item (),
464464 )
465465 acc_sum_total += acc_sum
466466 not_nans_total += not_nans
0 commit comments