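"""Task arithmetic by averaging task vectors.

Loads the task vector (fine-tuned minus pretrained weights) for each
evaluation dataset, averages them, scales the average by the scaling
hyperparameter, applies the result to the pretrained image encoder, and
evaluates the merged model on every dataset. Per-dataset results and the
average top-1 accuracy are written to a JSON file (args.results_db).
"""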
import torch
from src.task_vectors import TaskVector
from src.eval import eval_single_dataset
from src.args import parse_arguments
import torch.nn.functional as F
from tqdm import tqdm
import json
import numpy as np
import logging
import random
# Initialize logger
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Example of logging
logger.info("Logger initialized successfully.")
# Config
#datasets = ['MNIST', 'DTD', 'EuroSAT', 'GTSRB', 'SUN397', 'SVHN']
# currently no 'Cars' and 'RESISC45'
# datasets = ['SVHN']
#test_datasets = ['CIFAR10']
args = parse_arguments()
SEED = args.seed
logger.info(f"Setting random seed to {SEED} for reproducibility.")
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
model = args.model
datasets = args.eval_datasets
#args.data_location = 'datasets'
#args.model = model
#args.save = f'checkpoints/{model}'
pretrained_checkpoint = f'checkpoints/{model}/zeroshot.pt'
# Load the pretrained model and its state_dict
pretrained_model = torch.load(pretrained_checkpoint)
pretrained_state_dict = pretrained_model.state_dict()
# Simple average merge of the task vectors.
# This also works for full model parameters (not only task vectors).
def avg_merge(task_vectors):
    """
    Average-merge a list of task vectors.

    Args:
        task_vectors (list[dict]): A list of model state dicts.
    Returns:
        dict: A merged model state dict.
    """
    merged_vector = {}
    # Iterate over keys in the task vector
    for key in task_vectors[0].keys():
        # Initialize a tensor for the merged result
        merged_tensor = torch.zeros_like(task_vectors[0][key])
        # Sum the corresponding tensor from every task vector
        for task_vector in task_vectors:
            merged_tensor += task_vector[key]
        # Normalize the merged tensor by the number of task vectors
        merged_vector[key] = merged_tensor / len(task_vectors)
    return merged_vector
# Load Task Vectors
task_vectors = [
    TaskVector(pretrained_checkpoint, f'checkpoints/{model}/{dataset}/finetuned.pt').vector
    for dataset in datasets
]
# Perform task arithmetic
merged_task_vector = avg_merge(task_vectors)
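# The merged encoder built below corresponds (assuming TaskVector.apply_to
# adds the vector to the pretrained weights) to:
#   theta_final = theta_pretrained + lambda * mean_i(tau_i)
# where tau_i are the per-dataset task vectors and lambda is args.scaling.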
# Scale the merged task vector by the scaling hyperparameter
scaling_hyperparameter = args.scaling
final_parameters = {key: scaling_hyperparameter * merged_task_vector[key]
                    for key in merged_task_vector.keys()}
# Apply to Model and Evaluate
image_encoder = TaskVector(vector=final_parameters).apply_to(pretrained_checkpoint)
evaluation_results = {}
for dataset in datasets:
    print(f"Evaluating on {dataset}...")
    result = eval_single_dataset(image_encoder, dataset, args)
    evaluation_results[dataset] = result
# Calculate the average accuracy across all datasets
accuracies = [result['top1'] for result in evaluation_results.values()]
average_accuracy = sum(accuracies) / len(accuracies)
# Add the average accuracy to the evaluation results
evaluation_results['average_accuracy'] = average_accuracy
# Save the evaluation results to the specified JSON path
with open(args.results_db, 'w') as f:
    json.dump(evaluation_results, f, indent=4)
# Log the average accuracy
print(f"Average accuracy across datasets: {average_accuracy * 100:.2f}%")
print(f"Average accuracy across datasets: {average_accuracy * 100:.2f}%")
print("Evaluation results saved to:", args.results_db)