
Commit 8002871

alfekka and IINemo authored
Visual LLM support (#341)
* visuallm with scripts
* corrected focus bug
* refactored visual lm to store images in dataset only
* small fixes
* fix
* corrected attentionscore
* fix
* fix
* fix transformers and default calculators
* Fix visual
* turn off focus
* fix attn
* fix attn
* fix attn
* get_images moving
* linter issues
* downgraded transformers
* linter issues
* linter issues
* linter issues

---------

Co-authored-by: alfekka <[email protected]>
Co-authored-by: iinemo <[email protected]>
1 parent 6897a6d commit 8002871

28 files changed: +1582 -1307 lines changed

examples/basic_example_visual.ipynb

Lines changed: 292 additions & 0 deletions
Large diffs are not rendered by default.

examples/basic_visual_llm_example.ipynb

Lines changed: 0 additions & 941 deletions
This file was deleted.
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
- name: MaximumSequenceProbability
- name: Perplexity
- name: MeanTokenEntropy
- name: MeanPointwiseMutualInformation
- name: MeanConditionalPointwiseMutualInformation
- name: ClaimConditionedProbability
- name: PTrue
- name: PTrueSampling
- name: MonteCarloSequenceEntropy
- name: MonteCarloNormalizedSequenceEntropy
- name: LexicalSimilarity
  cfg:
    metric: "rouge1"
- name: LexicalSimilarity
  cfg:
    metric: "rouge2"
- name: LexicalSimilarity
  cfg:
    metric: "rougeL"
- name: LexicalSimilarity
  cfg:
    metric: "BLEU"
- name: NumSemSets
- name: EigValLaplacian
  cfg:
    similarity_score: "NLI_score"
    affinity: "entail"
- name: EigValLaplacian
  cfg:
    similarity_score: "NLI_score"
    affinity: "contra"
- name: EigValLaplacian
  cfg:
    similarity_score: "Jaccard_score"
- name: DegMat
  cfg:
    similarity_score: "NLI_score"
    affinity: "entail"
- name: DegMat
  cfg:
    similarity_score: "NLI_score"
    affinity: "contra"
- name: DegMat
  cfg:
    similarity_score: "Jaccard_score"
- name: Eccentricity
  cfg:
    similarity_score: "NLI_score"
    affinity: "entail"
- name: Eccentricity
  cfg:
    similarity_score: "NLI_score"
    affinity: "contra"
- name: Eccentricity
  cfg:
    similarity_score: "Jaccard_score"
- name: SemanticEntropy
- name: SAR
- name: TokenSAR
- name: SentenceSAR
- name: LUQ
- name: KernelLanguageEntropy
- name: EigenScore
- name: RenyiNeg
- name: FisherRao
- name: MahalanobisDistanceSeq
- name: RelativeMahalanobisDistanceSeq
- name: RDESeq
- name: PPLMDSeq
  cfg:
    md_type: "MD"
- name: PPLMDSeq
  cfg:
    md_type: "RMD"
- name: AttentionScore
  cfg:
    layer: 16
    gen_only: False
# - name: Focus
#   cfg:
#     model_name: '${model.path}'
#     path: "${cache_path}/focus/${model.path}/token_idf.pkl"
#     gamma: 0.9
#     p: 0.01
#     idf_dataset: "togethercomputer/RedPajama-Data-1T-Sample"
#     trust_remote_code: True
#     idf_seed: 42
#     idf_dataset_size: 1000
#     #idf_dataset_size: -1
#     spacy_path: "en_core_web_sm"
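
Each entry above names an estimator from lm_polygraph.estimators (the eval script below imports that package wholesale), with the optional cfg block feeding the constructor. A minimal sketch of what the estimator factory presumably does for two of the entries, assuming cfg keys map one-to-one onto constructor keyword arguments:

# Sketch only: assumes cfg keys become constructor kwargs, as the
# FactoryEstimator used by scripts/polygraph_eval appears to do.
from lm_polygraph.estimators import LexicalSimilarity, MaximumSequenceProbability

estimators = [
    MaximumSequenceProbability(),        # entry with no cfg block
    LexicalSimilarity(metric="rouge1"),  # entry with cfg: {metric: "rouge1"}
]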
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
from transformers import AutoModelForVision2Seq, AutoProcessor


def load_model(model_path: str, device_map: str):
    model = AutoModelForVision2Seq.from_pretrained(
        model_path, trust_remote_code=True, device_map=device_map
    )
    model.eval()

    return model


def load_tokenizer(model_path: str):
    processor_visual = AutoProcessor.from_pretrained(
        model_path,
        padding_side="left",
        add_bos_token=True,
    )
    if processor_visual.tokenizer.pad_token is None:
        processor_visual.tokenizer.pad_token = processor_visual.tokenizer.eos_token

    return processor_visual
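
A quick usage sketch for the two helpers above, assuming they are importable from this loading script and that the Kosmos checkpoint configured in kosmos.yaml below can be downloaded:

# Sketch: exercising the loading script by hand with the values from kosmos.yaml below.
model = load_model("microsoft/kosmos-2-patch14-224", device_map="auto")
processor = load_tokenizer("microsoft/kosmos-2-patch14-224")

print(type(model).__name__)                       # concrete Vision2Seq model class
print(processor.tokenizer.pad_token is not None)  # True: load_tokenizer guarantees a pad token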

examples/configs/model/kosmos.yaml

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
defaults:
  - default

path: microsoft/kosmos-2-patch14-224
type: VisualLM
path_to_load_script: model/default_visual.py

load_model_args:
  device_map: auto
load_tokenizer_args: {}
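
For orientation, this is roughly how the new get_visual_model in scripts/polygraph_eval (further below) turns these fields into calls to the loading script; plain dicts stand in here for the resolved Hydra config:

# Sketch with plain dicts standing in for the resolved Hydra config.
model_cfg = {
    "path": "microsoft/kosmos-2-patch14-224",
    "type": "VisualLM",
    "path_to_load_script": "model/default_visual.py",
    "load_model_args": {"device_map": "auto"},
    "load_tokenizer_args": {},
}

# get_visual_model builds the kwargs this way before calling the loading script:
load_model_args = {"model_path": model_cfg["path"], **model_cfg["load_model_args"]}
load_tokenizer_args = {"model_path": model_cfg["path"], **model_cfg["load_tokenizer_args"]}
print(load_model_args)      # -> load_module.load_model(**load_model_args)
print(load_tokenizer_args)  # -> load_module.load_tokenizer(**load_tokenizer_args)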
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
hydra:
  run:
    dir: ${cache_path}/${task}/${model}/${dataset}/${now:%Y-%m-%d}/${now:%H-%M-%S}

defaults:
  - model: kosmos
  - estimators: default_estimators_visual
  - stat_calculators: default_calculators_visual
  - _self_

cache_path: ./workdir/output
save_path: '${hydra:run.dir}'

task: qa

dataset: ['LM-Polygraph/vqa']
text_column: question
label_column: answer
im_column: image
train_split: train
eval_split: test
max_new_tokens: 3
load_from_disk: false
size: 100
generation_params:
  generate_until:
    - "\n"

subsample_eval_dataset: -1

generation_metrics: null

ignore_exceptions: false

batch_size: 1

seed:
  - 1
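
A hedged launch sketch: the updated scripts/polygraph_eval (further below) reads the config path from the HYDRA_CONFIG environment variable, so running this VQA setup could look roughly like the following, assuming polygraph_eval is installed on PATH; the config filename is hypothetical, since this excerpt does not show it:

import os
import subprocess

# Hypothetical filename for the eval config above; adjust to wherever it lives in your checkout.
config_path = os.path.abspath("examples/configs/polygraph_eval_vqa.yaml")

# scripts/polygraph_eval picks the config up from HYDRA_CONFIG (see the diff below).
subprocess.run(
    ["polygraph_eval"],
    env={**os.environ, "HYDRA_CONFIG": config_path},
    check=True,
)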
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
- auto
- name: TrainingStatisticExtractionCalculatorVisual
  builder: lm_polygraph.defaults.stat_calculator_builders.default_TrainingStatisticExtractionCalculatorVisual
  cfg:
    dataset: '${dataset}'
    text_column: '${text_column}'
    label_column: '${label_column}'
    im_column: '${im_column}'
    description: ''
    prompt: ''
    few_shot_split: "train"
    train_split: '${train_split}'
    load_from_disk: '${load_from_disk}'
    subsample_train_dataset: 50
    n_shot: 5
    background_train_dataset: LM-Polygraph/laion-1000-background
    background_train_dataset_text_column: txt
    background_train_dataset_label_column: __url__
    background_load_from_disk: false
    background_images: jpg
    background_train_dataset_data_files: data/train-00000-of-00001.parquet
    subsample_background_train_dataset: 100
    batch_size: '${batch_size}'
    seed: '${seed}'
    size: '${size}'
    bg_size: 1000
    output_attentions: True
  stats:
    - "train_embeddings"
    - "background_train_embeddings"
    - "train_greedy_log_likelihoods"
  dependencies:
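
The quoted '${...}' values above are OmegaConf/Hydra interpolations that resolve against the top-level eval config at composition time; a small self-contained illustration of the mechanism with trimmed-down stand-ins for both configs (not the library's own loader):

from omegaconf import OmegaConf

# Trimmed-down stand-ins for the eval config and the calculator entry above.
eval_cfg = OmegaConf.create({
    "dataset": ["LM-Polygraph/vqa"],
    "im_column": "image",
    "batch_size": 1,
})
calc_entry = OmegaConf.create({
    "cfg": {"dataset": "${dataset}", "im_column": "${im_column}", "batch_size": "${batch_size}"},
})

merged = OmegaConf.merge(eval_cfg, calc_entry)
print(OmegaConf.to_container(merged.cfg, resolve=True))
# {'dataset': ['LM-Polygraph/vqa'], 'im_column': 'image', 'batch_size': 1}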

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -5,9 +5,9 @@ scikit-learn>=1.5.1
 tqdm>=4.64.1
 matplotlib>=3.6
 pandas>=1.3.5
-torch>=1.13.0
+torch>=2.6.0
 bs4
-transformers>=4.48.0,<4.52.0
+transformers==4.50.0
 nltk>=3.6.5
 sacrebleu>=1.5.0
 sentencepiece>=0.1.97

scripts/polygraph_eval

Lines changed: 63 additions & 2 deletions
@@ -16,11 +16,12 @@ from lm_polygraph.utils.manager import UEManager
 from lm_polygraph.utils.dataset import Dataset
 from lm_polygraph.utils.model import WhiteboxModel, BlackboxModel
 from lm_polygraph.model_adapters import WhiteboxModelvLLM
+from lm_polygraph.model_adapters.visual_whitebox_model import VisualWhiteboxModel
 from lm_polygraph.utils.processor import Logger
 from lm_polygraph.generation_metrics import *
 from lm_polygraph.estimators import *
 from lm_polygraph.ue_metrics import *
-from lm_polygraph.utils.common import load_external_module
+from lm_polygraph.utils.common import load_external_module, load_processor, load_image
 from lm_polygraph.utils.generation_parameters import GenerationParameters, GenerationParametersFactory
 from lm_polygraph.defaults.register_default_stat_calculators import (
     register_default_stat_calculators,
@@ -30,6 +31,7 @@ from lm_polygraph.utils.builder_enviroment_stat_calculator import (
 )
 from lm_polygraph.utils.factory_estimator import FactoryEstimator
 from lm_polygraph.utils.factory_stat_calculator import StatCalculatorContainer
+#from transformers import AutoProcessor, AutoModelForVision2Seq

 hydra_config = Path(os.environ.get("HYDRA_CONFIG", ""))

@@ -115,12 +117,14 @@ def main(args):
         n_shot=getattr(args, "n_shot", 5),
         few_shot_split=getattr(args, "few_shot_split", "train"),
         few_shot_prompt=getattr(args, "few_shot_prompt", None),
+        im_column=getattr(args, "im_column", None),
         instruct=getattr(args, "instruct", None),
         split=args.eval_split,
         load_from_disk=args.load_from_disk,
         trust_remote_code=getattr(args, "trust_remote_code", False),
         **cache_kwargs,
     )
+    # images=dataset.images
     log.info("Done with loading eval data.")

     log.info("=" * 100)
@@ -191,7 +195,12 @@ def get_ue_metrics(args):


 def get_stat_calculator_names(config):
-    model_type = "Whitebox" if getattr(config.model, "type", "Whitebox") != "Blackbox" else "Blackbox"
+    model_type_raw = getattr(config.model, "type", "Whitebox")
+    model_type = (
+        "Blackbox" if model_type_raw == "Blackbox"
+        else "VisualLM" if model_type_raw == "VisualLM"
+        else "Whitebox"
+    )
     language = getattr(config, "language", "en")
     output_attentions = getattr(config, "output_attentions", True) and (getattr(config.model, "type", "Whitebox") != "vLLMCausalLM")
     output_hidden_states = False if getattr(config.model, "type", "Whitebox") == "vLLMCausalLM" else True
@@ -321,6 +330,12 @@ def get_generation_metrics(args):
 def get_model(args):
     if getattr(args.model, "type", "Whitebox") == "Blackbox":
         return get_blackbox_model(args)
+    elif getattr(args.model, "type", "Whitebox") == "VisualLM":
+        cache_kwargs = {
+            "cache_dir": getattr(args, "hf_cache", None),
+            "token": getattr(args, "hf_token", None),
+        }
+        return get_visual_model(args, cache_kwargs)
     elif getattr(args.model, "type", "Whitebox") == "vLLMCausalLM":
         return get_vllm_model(args)
     else:
@@ -404,6 +419,52 @@ def get_whitebox_model(args, cache_kwargs={}):

     return model

+
+def get_visual_model(args, cache_kwargs={}):
+    if not "path_to_load_script" in args.model or not args.model.path_to_load_script:
+        log.warning(
+            "Loading model by directly passing the path to the model is deprecated and will be removed in the next release. Please use loading script instead."
+        )
+        log.info(f"Loading model with cache_kwargs: {cache_kwargs}")
+        return VisualWhiteboxModel.from_pretrained(
+            args.model.path,
+            getattr(args, "generation_params", {}),
+            device_map=args.model.load_model_args.device_map,
+            add_bos_token=getattr(args.model, "add_bos_token", True),
+            **cache_kwargs
+        )
+
+    path_to_load_script = get_abs_path_from_hydra_config(
+        args.model.path_to_load_script
+    )
+    load_module = load_external_module(path_to_load_script)
+
+    load_model_args = {'model_path': args.model.path}
+    load_model_args.update(args.model.load_model_args)
+    base_model = load_module.load_model(**load_model_args)
+
+    load_tok_args = {'model_path': args.model.path}
+    load_tok_args.update(args.model.load_tokenizer_args)
+    tokenizer = load_module.load_tokenizer(**load_tok_args)
+
+    load_proc_args = {'model_path': args.model.path}
+    load_proc_args.update(getattr(args.model, "load_processor_args", {}))
+    processor = load_processor(**load_proc_args)
+
+    generation_params = GenerationParametersFactory.from_params(
+        yaml_config=getattr(args, "generation_params", {}),
+        native_config=base_model.generation_config.to_dict()
+    )
+
+    model = VisualWhiteboxModel(base_model,
+                                processor,
+                                args.model.path,
+                                args.model.type,
+                                generation_params)
+
+    return model
+
+
 def get_vllm_model(args):
     path_to_load_script = get_abs_path_from_hydra_config(
         args.model.path_to_load_script
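
Outside the Hydra pipeline, the same objects can presumably be wired up by hand, mirroring get_visual_model above; this sketch loads the base model and processor with plain transformers calls, standing in for the loading script and the load_processor helper used by the script:

# Hand-wired equivalent of get_visual_model above; a sketch, not the library's documented API.
from transformers import AutoModelForVision2Seq, AutoProcessor

from lm_polygraph.model_adapters.visual_whitebox_model import VisualWhiteboxModel
from lm_polygraph.utils.generation_parameters import GenerationParametersFactory

model_path = "microsoft/kosmos-2-patch14-224"

base_model = AutoModelForVision2Seq.from_pretrained(
    model_path, trust_remote_code=True, device_map="auto"
).eval()
processor = AutoProcessor.from_pretrained(model_path, padding_side="left")

generation_params = GenerationParametersFactory.from_params(
    yaml_config={},  # would normally come from the eval config's generation_params
    native_config=base_model.generation_config.to_dict(),
)

# Positional arguments follow the call in get_visual_model above:
# (base model, processor, model path, model type, generation parameters).
model = VisualWhiteboxModel(base_model, processor, model_path, "VisualLM", generation_params)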

src/lm_polygraph/defaults/register_default_stat_calculators.py

Lines changed: 3 additions & 2 deletions
@@ -153,7 +153,8 @@ def _register(
             GreedyProbsVisualCalculator,
             "lm_polygraph.defaults.stat_calculator_builders.default_GreedyProbsVisualCalculator",
             {
-                "output_attentions": True,
+                "output_attentions": output_attentions,
+                "output_hidden_states": output_hidden_states,
             },
         )
         _register(EntropyCalculator)
@@ -163,7 +164,6 @@
         _register(BartScoreCalculator)
         _register(ModelScoreCalculator)
         _register(EnsembleTokenLevelDataCalculator)
-        _register(PromptVisualCalculator)
         _register(SamplingPromptVisualCalculator)
         _register(ClaimPromptVisualCalculator)
         _register(
@@ -193,6 +193,7 @@
                 "language": language,
             },
         )
+        _register(AttentionForwardPassCalculatorVisual)

     else:
         raise NotImplementedError(f"Unknown model type: {model_type}")
