
Commit 1ef4cbe

Merge pull request #254 from iustinsirbu13/feature/multimatch
Added MultiMatch algorithm.
2 parents aa9018c + 864f2f7 commit 1ef4cbe

10 files changed: +589 -15 lines changed
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
algorithm: multimatch
save_dir: ./saved_models/usb_nlp/multimatch
save_name: multimatch_ag_news_200_0
resume: True
load_path: ./saved_models/usb_nlp/multimatch/multimatch_ag_news_200_0/latest_model.pth
overwrite: True
use_tensorboard: True
use_wandb: False
epoch: 100
num_train_iter: 102400
num_warmup_iter: 5120
num_log_iter: 256
num_eval_iter: 2048
num_labels: 200
batch_size: 8
eval_batch_size: 8
ema_m: 0.0
hard_label: True
T: 0.5
p_cutoff: 0.95
ulb_loss_ratio: 3.0
num_heads: 3
apm_percentile: 0.05
no_low: True
apm_disagreement_weight: 3
threshold_algo: freematch
smoothness: 0.997
uratio: 1
use_cat: False
optim: AdamW
lr: 0.00005
momentum: 0.9
weight_decay: 0.0005
layer_decay: 0.65
amp: False
clip: 0.0
net: bert_base_uncased_multihead
net_from_name: False
data_dir: ./data
dataset: ag_news
train_sampler: RandomSampler
num_classes: 4
num_workers: 4
max_length: 512
seed: 0
world_size: 1
rank: 0
multiprocessing_distributed: False
dist_url: tcp://127.0.0.1:10001
dist_backend: nccl
gpu: 0
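
For context, the file above is an ordinary YAML config in the style of the other USB configs: the training entry point reads it and exposes each key as an attribute of the run's args (e.g. args.num_heads, args.threshold_algo). A minimal sketch of that mapping, using PyYAML and a plain namespace instead of the repository's own config loader (the file path below is hypothetical; the config's location in the repo is not shown in this diff):

import yaml  # requires PyYAML; illustrative only, not the repo's loader
from types import SimpleNamespace

def load_config(path):
    # Parse the YAML and expose each key as an attribute (args.num_heads, args.p_cutoff, ...).
    with open(path) as f:
        return SimpleNamespace(**yaml.safe_load(f))

args = load_config("multimatch_ag_news_200_0.yaml")
print(args.algorithm, args.num_heads, args.threshold_algo)  # -> multimatch 3 freematch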
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .multimatch import MultiMatch
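
This one-line package export, together with the @ALGORITHMS.register('multimatch') decorator in the implementation below, is what lets the `algorithm: multimatch` key in the config resolve to the new class by name. A generic sketch of that name-to-class registry pattern (illustrative only; the repository's actual Registry lives in semilearn.core.utils and may differ in detail):

# Generic registry sketch, not semilearn's implementation.
ALGORITHMS = {}

def register(name):
    def decorator(cls):
        ALGORITHMS[name] = cls   # map the string name to the class
        return cls
    return decorator

@register('multimatch')
class MultiMatch:  # stand-in for the real class added in this commit
    pass

algo_cls = ALGORITHMS['multimatch']  # looked up from the config's `algorithm` key
print(algo_cls.__name__)             # -> MultiMatch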
Lines changed: 276 additions & 0 deletions
@@ -0,0 +1,276 @@
import torch
import torch.nn.functional as F

from semilearn.algorithms.flexmatch.utils import FlexMatchThresholdingHook
from semilearn.algorithms.freematch.utils import FreeMatchThresholingHook as FreeMatchThresholdingHook
from semilearn.algorithms.multimatch.utils import MultiMatchThresholdingHook
from semilearn.algorithms.utils import SSL_Argument, str2bool
from semilearn.core.algorithmbase import AlgorithmBase
from semilearn.core.utils import ALGORITHMS


@ALGORITHMS.register('multimatch')
class MultiMatch(AlgorithmBase):
    def __init__(self, args, net_builder, tb_log=None, logger=None):

        # multihead specific arguments
        self.num_heads = args.num_heads

        # arguments used by the auxiliary thresholding (e.g. freematch)
        self.init_maskinghook_args(T=args.T, hard_label=args.hard_label, ema_p=args.ema_p, use_quantile=args.use_quantile,
                                   clip_thresh=args.clip_thresh, p_cutoff=args.p_cutoff, thresh_warmup=args.thresh_warmup,
                                   threshold_algo=args.threshold_algo)

        super().__init__(args, net_builder, tb_log, logger)

    def init_maskinghook_args(self, T, p_cutoff, hard_label=True, ema_p=0.999, use_quantile=True, clip_thresh=False, thresh_warmup=True, threshold_algo='freematch'):
        self.T = T
        self.p_cutoff = p_cutoff
        self.use_hard_label = hard_label
        self.thresh_warmup = thresh_warmup
        self.ema_p = ema_p
        self.use_quantile = use_quantile
        self.clip_thresh = clip_thresh
        self.threshold_algo = threshold_algo

    def set_model(self):
        """
        initialize model
        """
        model = self.net_builder(self.args)
        return model

    def set_ema_model(self):
        """
        initialize ema model from model
        """
        ema_model = self.net_builder(self.args)
        ema_model.load_state_dict(self.model.state_dict())
        return ema_model

    def set_hooks(self):
        self.register_hook(MultiMatchThresholdingHook(self.args), "APMHook")

        for i in range(self.num_heads):
            if self.threshold_algo == 'flexmatch':
                self.register_hook(FlexMatchThresholdingHook(ulb_dest_len=self.args.ulb_dest_len, num_classes=self.num_classes, thresh_warmup=self.args.thresh_warmup), f"MaskingHook{i}")
            elif self.threshold_algo == 'freematch':
                self.register_hook(FreeMatchThresholdingHook(num_classes=self.num_classes, momentum=self.args.ema_p), f"MaskingHook{i}")
            elif self.threshold_algo == 'none':
                pass
            else:
                raise NotImplementedError()

        super().set_hooks()

    def get_head_logits(self, head_id, logits, num_lb):
        head_logits = logits[head_id]
        logits_x_lb = head_logits[:num_lb]
        logits_x_ulb_w, logits_x_ulb_s = head_logits[num_lb:].chunk(2)
        return logits_x_lb, logits_x_ulb_w, logits_x_ulb_s

    def get_pseudo_labels(self, ulb_weak_logits):
        # index with highest probability for each logit tensor
        _, pseudo_labels = torch.max(ulb_weak_logits, dim=-1)
        return pseudo_labels

    def get_supervised_loss(self, lb_logits, lb_target):
        head_losses = [F.cross_entropy(lb_logits[head_id], lb_target) for head_id in range(self.num_heads)]
        if self.args.average_losses:
            return sum(head_losses) / len(head_losses)
        return sum(head_losses)

    def _get_auxiliary_mask(self, logits_x_ulb_w, idx_ulb, head_id):
        # calculate mask
        if self.threshold_algo == 'freematch':
            mask = self.call_hook("masking", f"MaskingHook{head_id}", logits_x_ulb=logits_x_ulb_w)
        elif self.threshold_algo == 'flexmatch':
            probs_x_ulb_w = self.compute_prob(logits_x_ulb_w.detach())
            mask = self.call_hook("masking", f"MaskingHook{head_id}", logits_x_ulb=probs_x_ulb_w, softmax_x_ulb=False, idx_ulb=idx_ulb)
        elif self.threshold_algo == 'none':
            mask = torch.ones(idx_ulb.shape[0], dtype=torch.int64).cuda(self.gpu)
        else:
            raise NotImplementedError()
        return mask

    def get_auxiliary_mask_comp(self, logits_x_ulb_w, idx_ulb, head_id1, head_id2):
        auxiliary_mask1 = self._get_auxiliary_mask(logits_x_ulb_w[head_id1], idx_ulb, head_id1)
        auxiliary_mask2 = self._get_auxiliary_mask(logits_x_ulb_w[head_id2], idx_ulb, head_id2)
        return torch.maximum(auxiliary_mask1, auxiliary_mask2)

    def get_head_unsupervised_loss(self, ulb_weak_logits, ulb_strong_logits, pseudo_labels, idx_ulb, y_ulb, head_id):
        '''
        This works only for 3 heads
        '''
        if head_id == 0:
            head_id1, head_id2 = 1, 2
        elif head_id == 1:
            head_id1, head_id2 = 0, 2
        else:
            head_id1, head_id2 = 0, 1

        num_ulb = idx_ulb.shape[0]
        multihead_labels = torch.ones(num_ulb, dtype=torch.int64).cuda(self.gpu) * -1
        multihead_agreement_types = torch.ones(num_ulb, dtype=torch.int64).cuda(self.gpu) * -1
        agreement_types_mask = torch.ones(num_ulb, dtype=torch.int64).cuda(self.gpu) * -1

        for i in range(num_ulb):
            label1 = pseudo_labels[head_id1][i]
            label2 = pseudo_labels[head_id2][i]
            multihead_labels[i], multihead_agreement_types[i], agreement_types_mask[i] = self.call_hook(
                "get_apm_label", "APMHook", head_id=head_id, head_id1=head_id1, head_id2=head_id2, idx=idx_ulb[i], label1=label1, label2=label2)

        auxiliary_mask = self.get_auxiliary_mask_comp(ulb_weak_logits, idx_ulb, head_id1, head_id2)

        multihead_labels[multihead_labels == -1] = 0  # can't have labels -1, even though the weight will be 0
        samples_weights = (agreement_types_mask == 0) * self.args.apm_disagreement_weight + (agreement_types_mask == 1) * 1

        final_weights = samples_weights * auxiliary_mask

        return (F.cross_entropy(ulb_strong_logits[head_id], multihead_labels, reduction='none') * final_weights).mean()

    def get_unsupervised_loss(self, ulb_weak_logits, ulb_strong_logits, pseudo_labels, idx_ulb, y_ulb):
        for head_id in range(self.num_heads):
            self.call_hook("update", "APMHook", logits_x_ulb_w=ulb_weak_logits[head_id], logits_x_ulb_s=ulb_strong_logits[head_id], idx_ulb=idx_ulb, head_id=head_id)

        head_losses = [self.get_head_unsupervised_loss(ulb_weak_logits, ulb_strong_logits, pseudo_labels, idx_ulb, y_ulb, head_id) for head_id in range(self.num_heads)]
        return sum(head_losses) / self.num_heads

    def get_loss(self, lb_loss, ulb_loss):
        return lb_loss + self.lambda_u * ulb_loss

    def _post_process_logits(self, logits_x_lb, logits_x_ulb_w, logits_x_ulb_s, y_lb, idx_ulb, y_ulb, feat_dict=None):
        # Supervised loss
        lb_loss = self.get_supervised_loss(logits_x_lb, y_lb)

        # Pseudo labels
        pseudo_labels = torch.stack([self.get_pseudo_labels(logits_x_ulb_w[head_id]) for head_id in range(self.num_heads)])

        # Unsupervised loss
        ulb_loss = self.get_unsupervised_loss(logits_x_ulb_w, logits_x_ulb_s, pseudo_labels, idx_ulb, y_ulb)

        # Total loss
        loss = self.get_loss(lb_loss, ulb_loss)

        if feat_dict:
            out_dict = self.process_out_dict(loss=loss, feat=feat_dict)
        else:
            out_dict = self.process_out_dict(loss=loss)
        log_dict = self.process_log_dict(sup_loss=lb_loss.item(),
                                         unsup_loss=ulb_loss.item(),
                                         total_loss=loss.item())

        return out_dict, log_dict

    def train_step_base(self, logits, y_lb, idx_ulb, y_ulb):
        num_lb = y_lb.shape[0]
        num_ulb = idx_ulb.shape[0]

        logits_x_lb = torch.zeros(self.num_heads, num_lb, self.num_classes).cuda(self.gpu)
        logits_x_ulb_w = torch.zeros(self.num_heads, num_ulb, self.num_classes).cuda(self.gpu)
        logits_x_ulb_s = torch.zeros(self.num_heads, num_ulb, self.num_classes).cuda(self.gpu)

        for head_id in range(self.num_heads):
            logits_x_lb[head_id], logits_x_ulb_w[head_id], logits_x_ulb_s[head_id] = \
                self.get_head_logits(head_id, logits, num_lb)

        return self._post_process_logits(logits_x_lb, logits_x_ulb_w, logits_x_ulb_s, y_lb, idx_ulb, y_ulb)

    # @overrides
    def train_step(self, x_lb, y_lb, x_ulb_w, x_ulb_s, idx_ulb, y_ulb=None):
        idx_ulb = idx_ulb.cuda(self.gpu)

        if self.use_cat:
            inputs = torch.cat((x_lb, x_ulb_w, x_ulb_s))
            inputs = inputs.cuda(self.gpu)
            logits = self.model(inputs)['logits']
            return self.train_step_base(logits, y_lb, idx_ulb, y_ulb)
        else:
            outs_x_lb = self.model(x_lb)
            logits_x_lb = outs_x_lb['logits']
            feats_x_lb = outs_x_lb['feat']
            outs_x_ulb_s = self.model(x_ulb_s)
            logits_x_ulb_s = outs_x_ulb_s['logits']
            feats_x_ulb_s = outs_x_ulb_s['feat']
            with torch.no_grad():
                outs_x_ulb_w = self.model(x_ulb_w)
                logits_x_ulb_w = outs_x_ulb_w['logits']
                feats_x_ulb_w = outs_x_ulb_w['feat']
            feat_dict = {'x_lb': feats_x_lb, 'x_ulb_w': feats_x_ulb_w, 'x_ulb_s': feats_x_ulb_s}

            return self._post_process_logits(logits_x_lb, logits_x_ulb_w, logits_x_ulb_s, y_lb, idx_ulb, y_ulb, feat_dict=feat_dict)

    def get_logits(self, data, out_key):
        x = data['x_lb']
        if isinstance(x, dict):
            x = {k: v.cuda(self.gpu) for k, v in x.items()}
        else:
            x = x.cuda(self.gpu)

        logits = self.model(x)[out_key]

        # Use all heads for prediction
        return sum(logits) / self.num_heads

    def get_save_dict(self):
        save_dict = super().get_save_dict()

        # additional saving arguments
        for i in range(self.num_heads):
            if self.threshold_algo == 'freematch':
                save_dict[f'p_model{i}'] = self.hooks_dict[f'MaskingHook{i}'].p_model.cpu()
                save_dict[f'time_p{i}'] = self.hooks_dict[f'MaskingHook{i}'].time_p.cpu()
            elif self.threshold_algo == 'flexmatch':
                save_dict[f'classwise_acc{i}'] = self.hooks_dict[f'MaskingHook{i}'].classwise_acc.cpu()
                save_dict[f'selected_label{i}'] = self.hooks_dict[f'MaskingHook{i}'].selected_label.cpu()
            elif self.threshold_algo == 'none':
                pass
            else:
                raise NotImplementedError()

        return save_dict

    def load_model(self, load_path):
        checkpoint = super().load_model(load_path)

        for i in range(self.num_heads):
            if self.threshold_algo == 'freematch':
                self.hooks_dict[f'MaskingHook{i}'].p_model = checkpoint[f'p_model{i}'].cuda(self.gpu)
                self.hooks_dict[f'MaskingHook{i}'].time_p = checkpoint[f'time_p{i}'].cuda(self.gpu)
            elif self.threshold_algo == 'flexmatch':
                self.hooks_dict[f'MaskingHook{i}'].classwise_acc = checkpoint[f'classwise_acc{i}'].cuda(self.gpu)
                self.hooks_dict[f'MaskingHook{i}'].selected_label = checkpoint[f'selected_label{i}'].cuda(self.gpu)
            elif self.threshold_algo == 'none':
                pass
            else:
                raise NotImplementedError()

        self.print_fn("additional parameter loaded")
        return checkpoint

    @staticmethod
    def get_argument():
        return [
            SSL_Argument('--num_heads', int, 3),
            SSL_Argument('--no_low', str2bool, False),  # gamma_min -inf (True) or 0 (False), the lower limit for the apm threshold
            SSL_Argument('--apm_disagreement_weight', float, 3),
            SSL_Argument('--apm_percentile', float, 0.05),
            SSL_Argument('--smoothness', float, 0.997),
            SSL_Argument('--adjust_clf_size', str2bool, False),
            SSL_Argument('--num_recalibrate_iter', int, 0),  # if 0, it will be done every epoch
            SSL_Argument('--average_losses', str2bool, False),
            SSL_Argument('--threshold_algo', str, 'freematch'),
            # arguments used by the freematch/flexmatch thresholding
            SSL_Argument('--hard_label', str2bool, True),
            SSL_Argument('--T', float, 0.5),
            SSL_Argument('--ema_p', float, 0.999),
            SSL_Argument('--ent_loss_ratio', float, 0.01),
            SSL_Argument('--use_quantile', str2bool, False),
            SSL_Argument('--clip_thresh', str2bool, False),
            SSL_Argument('--p_cutoff', float, 0.95),
            SSL_Argument('--thresh_warmup', str2bool, True),
        ]
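
The core of get_head_unsupervised_loss above is a per-sample weighted cross-entropy on the strong-view logits: for each head, the APM hook supplies a pseudo-label and an agreement flag per unlabeled sample from the other two heads, samples flagged 0 are up-weighted by apm_disagreement_weight, samples flagged 1 keep weight 1, everything else is zeroed, and the result is further gated by the auxiliary FreeMatch/FlexMatch confidence mask. A small self-contained numeric sketch of just that weighting step, with made-up tensors standing in for the repository's hooks:

import torch
import torch.nn.functional as F

apm_disagreement_weight = 3.0                          # mirrors the config value above
num_ulb, num_classes = 4, 4

ulb_strong_logits = torch.randn(num_ulb, num_classes)  # strong-view logits for one head
multihead_labels = torch.tensor([2, 0, 1, 3])          # pseudo-labels from the other heads (illustrative)
agreement_types_mask = torch.tensor([0, 1, 1, -1])     # 0/1 = usable (weighted), -1 = dropped
auxiliary_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])    # confidence mask from FreeMatch/FlexMatch

samples_weights = (agreement_types_mask == 0) * apm_disagreement_weight + (agreement_types_mask == 1) * 1
final_weights = samples_weights * auxiliary_mask       # sample 2 is masked out, sample 3 dropped

loss = (F.cross_entropy(ulb_strong_logits, multihead_labels, reduction='none') * final_weights).mean()
print(final_weights, loss)

At inference time, by contrast, get_logits simply averages the logits of all heads (sum(logits) / self.num_heads), so the ensemble is used directly for prediction.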
