Skip to content

Commit 6b2da77

Browse files
committed
removed hard coded class number
Signed-off-by: elitap <[email protected]>
1 parent 1e7f3cc commit 6b2da77

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

training/example_train_script.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ python main_2pt5d.py --max_epochs 100 --val_every 1 --optim_lr 0.000005 \
88
--logdir finetune_ckpt_example --point_prompt --label_prompt --distributed --seed 12346 \
99
--iterative_training_warm_up_epoch 50 --reuse_img_embedding \
1010
--label_prompt_warm_up_epoch 25 \
11-
--checkpoint ./runs/9s_2dembed_model.pt
11+
--checkpoint ./runs/9s_2dembed_model.pt \
12+
--num_classes 105

training/main_2pt5d.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@
113113
parser.add_argument("--skip_bk", action="store_true", help="skip background (0) during training")
114114
parser.add_argument("--patch_embed_3d", action="store_true", help="using 3d patch embedding layer")
115115

116+
parser.add_argument("--num_classes", default=105, type=int, help="number of output classes")
117+
116118

117119
def start_tb(log_dir):
118120
cmd = ["tensorboard", "--logdir", log_dir]
@@ -123,6 +125,10 @@ def main():
123125
args = parser.parse_args()
124126
args.amp = not args.noamp
125127
args.logdir = "./runs/" + args.logdir
128+
129+
if args.num_classes == 0:
130+
warnings.warn("consider setting the correct number of classes")
131+
126132
# start_tb(args.logdir)
127133
if args.seed > -1:
128134
set_determinism(seed=args.seed)
@@ -135,6 +141,9 @@ def main():
135141
main_worker(gpu=0, args=args)
136142

137143

144+
145+
146+
138147
def main_worker(gpu, args):
139148
if args.distributed:
140149
torch.multiprocessing.set_start_method("fork", force=True)
@@ -162,7 +171,7 @@ def main_worker(gpu, args):
162171

163172
dice_loss = DiceCELoss(sigmoid=True)
164173

165-
post_label = AsDiscrete(to_onehot=105)
174+
post_label = AsDiscrete(to_onehot=args.num_classes)
166175
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
167176
dice_acc = DiceMetric(include_background=False, reduction=MetricReduction.MEAN, get_not_nans=True)
168177

training/trainer_2pt5d.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def prepare_sam_training_input(inputs, labels, args, model):
129129
unique_labels = unique_labels[: args.num_prompt]
130130

131131
# add 4 background labels to every batch
132-
background_labels = list(set([i for i in range(1, 105)]) - set(unique_labels.cpu().numpy()))
132+
background_labels = list(set([i for i in range(1, args.num_classes)]) - set(unique_labels.cpu().numpy()))
133133
random.shuffle(background_labels)
134134
unique_labels = torch.cat([unique_labels, torch.tensor(background_labels[:4]).cuda(args.rank)])
135135

@@ -375,7 +375,7 @@ def train_epoch_iterative(model, loader, optimizer, scaler, epoch, loss_func, ar
375375

376376

377377
def prepare_sam_test_input(inputs, labels, args, previous_pred=None):
378-
unique_labels = torch.tensor([i for i in range(1, 105)]).cuda(args.rank)
378+
unique_labels = torch.tensor([i for i in range(1, args.num_classes)]).cuda(args.rank)
379379

380380
# preprocess make the size of lable same as high_res_logit
381381
batch_labels = torch.stack([labels == unique_labels[i] for i in range(len(unique_labels))], dim=0).float()
@@ -400,7 +400,7 @@ def prepare_sam_test_input(inputs, labels, args, previous_pred=None):
400400

401401
def prepare_sam_val_input_cp_only(inputs, labels, args):
402402
# Don't exclude background in val but will ignore it in metric calculation
403-
unique_labels = torch.tensor([i for i in range(1, 105)]).cuda(args.rank)
403+
unique_labels = torch.tensor([i for i in range(1, args.num_classes)]).cuda(args.rank)
404404

405405
# preprocess make the size of lable same as high_res_logit
406406
batch_labels = torch.stack([labels == unique_labels[i] for i in range(len(unique_labels))], dim=0).float()
@@ -460,7 +460,7 @@ def val_epoch(model, loader, epoch, acc_func, args, iterative=False, post_label=
460460
acc_batch = compute_dice(y_pred=y_pred, y=target)
461461
acc_sum, not_nans = (
462462
torch.nansum(acc_batch).item(),
463-
104 - torch.sum(torch.isnan(acc_batch).float()).item(),
463+
(args.num_classes-1) - torch.sum(torch.isnan(acc_batch).float()).item(),
464464
)
465465
acc_sum_total += acc_sum
466466
not_nans_total += not_nans

0 commit comments

Comments
 (0)