Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
262 changes: 262 additions & 0 deletions examples/calibrated_confidence_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ff10d9c1-be8d-4ef7-844f-fe4731f982b2",
"metadata": {},
"outputs": [],
"source": [
"from IPython.core.debugger import set_trace"
]
},
{
"cell_type": "markdown",
"id": "a9981854-40e5-4b69-9efa-2f3b6c037ec3",
"metadata": {},
"source": [
"### Define lists of datasets, methods and metrics to consider"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cc85b82b-56e0-4b27-8074-1a3e06df81e5",
"metadata": {},
"outputs": [],
"source": [
"pwd = !pwd\n",
"pwd = pwd[0]\n",
"\n",
"# Absolute path to default Hydra config for normalization script. Take a look at this script to see how normalization is parametrized.\n",
"config_path = '/'.join(pwd.split('/')[:-1]) + '/examples/configs/normalization/fit/default.yaml'\n",
"\n",
"EVAL_MAN_PATH = pwd + '/polygraph_tacl_stablelm12b_wmt19.man'\n",
"TRAIN_MAN_PATH = pwd + '/polygraph_tacl_stablelm12b_wmt19_train.man'\n",
"UE_METHOD = 'MaximumSequenceProbability'\n",
"\n",
"# A quality metric for model's outputs that is naturally bounded on [0, 1]\n",
"GEN_METRIC_NAME = 'Comet'"
]
},
{
"cell_type": "markdown",
"id": "dffc34fe-319b-45a6-b0f5-5e4db2ec96da",
"metadata": {},
"source": [
"### Fit normalizers"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "260d8fa3-c70d-451b-8de5-6f68e3870293",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2024-12-03 17:02:31-- http://209.38.249.180:8000/polygraph_data/mans/\n",
"Connecting to 209.38.249.180:8000... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 4500 (4.4K) [text/html]\n",
"Saving to: ‘index.html.tmp’\n",
"\n",
"index.html.tmp 100%[===================>] 4.39K 4.19KB/s in 1.0s \n",
"\n",
"2024-12-03 17:02:33 (4.19 KB/s) - ‘index.html.tmp’ saved [4500/4500]\n",
"\n",
"Loading robots.txt; please ignore errors.\n",
"--2024-12-03 17:02:33-- http://209.38.249.180:8000/robots.txt\n",
"Connecting to 209.38.249.180:8000... connected.\n",
"HTTP request sent, awaiting response... 404 File not found\n",
"2024-12-03 17:02:33 ERROR 404: File not found.\n",
"\n",
"Removing index.html.tmp since it should be rejected.\n",
"\n",
"--2024-12-03 17:02:33-- http://209.38.249.180:8000/polygraph_data/mans/polygraph_tacl_stablelm12b_wmt19.man\n",
"Connecting to 209.38.249.180:8000... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 2243332 (2.1M) [application/x-troff-man]\n",
"Saving to: ‘polygraph_tacl_stablelm12b_wmt19.man’\n",
"\n",
"polygraph_tacl_stab 100%[===================>] 2.14M 59.4KB/s in 42s \n",
"\n",
"2024-12-03 17:03:16 (51.6 KB/s) - ‘polygraph_tacl_stablelm12b_wmt19.man’ saved [2243332/2243332]\n",
"\n",
"--2024-12-03 17:03:16-- http://209.38.249.180:8000/polygraph_data/mans/polygraph_tacl_stablelm12b_wmt19_train.man\n",
"Connecting to 209.38.249.180:8000... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 2012164 (1.9M) [application/x-troff-man]\n",
"Saving to: ‘polygraph_tacl_stablelm12b_wmt19_train.man’\n",
"\n",
"polygraph_tacl_stab 100%[===================>] 1.92M 78.2KB/s in 27s \n",
"\n",
"2024-12-03 17:03:43 (73.5 KB/s) - ‘polygraph_tacl_stablelm12b_wmt19_train.man’ saved [2012164/2012164]\n",
"\n",
"FINISHED --2024-12-03 17:03:43--\n",
"Total wall clock time: 1m 12s\n",
"Downloaded: 3 files, 4.1M in 1m 10s (59.2 KB/s)\n"
]
}
],
"source": [
"# Download all managers to current directory\n",
"!wget -r --cut-dirs=2 -nH --no-parent -A 'polygraph_tacl_stablelm12b_wmt19*man' http://209.38.249.180:8000/polygraph_data/mans/"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6d90b943-ff6a-480a-bf5d-d3bd22e258d4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2024-12-03 17:03:57,716][lm_polygraph][INFO] - ====================================================================================================\n",
"[2024-12-03 17:03:57,716][lm_polygraph][INFO] - Initializing stat calculators...\n",
"[2024-12-03 17:03:57,717][lm_polygraph][INFO] - Initializing GreedyProbsCalculator\n",
"[2024-12-03 17:03:57,717][lm_polygraph][INFO] - Stat calculators: [<lm_polygraph.stat_calculators.greedy_probs.GreedyProbsCalculator object at 0x31c5f5000>]\n",
"[2024-12-03 17:03:57,717][lm_polygraph][INFO] - Done intitializing stat calculators...\n"
]
},
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"\n",
"def format_for_hydra(param):\n",
" return f'\\'[\"{param}\"]\\''\n",
"\n",
"# Run polygraph_normalize to fit normalizer using train dataset.\n",
"# Format path to manager so that Hydra correctly recognizes it as override with list of paths.\n",
"os.system(f'HYDRA_CONFIG={config_path} polygraph_normalize save_path=\"./\" man_paths={format_for_hydra(TRAIN_MAN_PATH)} gen_metric_names={format_for_hydra(GEN_METRIC_NAME)} ue_method_names={format_for_hydra(UE_METHOD)}')"
]
},
{
"cell_type": "markdown",
"id": "def027cc-e031-4ee0-a396-2aad347c4eb3",
"metadata": {},
"source": [
"### Normalize UE from test sets"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "31d2d6d7-39f9-4bf4-842e-24f8f464d6a0",
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import pprint\n",
"\n",
"# Load saved fitted normalizer.\n",
"with open('fitted_normalizers.pickle', 'rb') as f:\n",
" fitted_normalizers = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fc3d056e-e44a-4b17-a2b1-0dcadfdf94e2",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniconda/base/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from lm_polygraph.normalizers.isotonic_pcc import IsotonicPCCNormalizer\n",
"\n",
"# Restore saved normalizer\n",
"normalizer = IsotonicPCCNormalizer.loads(fitted_normalizers[GEN_METRIC_NAME, UE_METHOD, 'isotonic_pcc'])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a7ae19fa-3d56-4ad9-92dc-575836c143d4",
"metadata": {},
"outputs": [],
"source": [
"from lm_polygraph.utils.manager import UEManager\n",
"\n",
"test_man = UEManager.load(EVAL_MAN_PATH)\n",
"\n",
"de_sentence = test_man.stats['input_texts'][0]\n",
"# Remove prompt\n",
"de_sentence = de_sentence.split('\\n')[-3]\n",
"\n",
"translation = test_man.stats['greedy_texts'][0]\n",
"\n",
"ue = test_man.estimations[('sequence', 'MaximumSequenceProbability')][0]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "eb8eddb0-98fd-4255-9f48-e74ea665c594",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Zwar werde es noch dauern, bis die bislang der Kirchengemeinde gehörende Fläche im Gemeindebesitz ist, doch gibt es jetzt kein Planungshindernis für das von einem privaten Investor zu stemmende Projekt mehr.\n",
"====================================================================================================\n",
"Although it will take some time before the land currently belonging to the church community is in the possession of the municipality, there is no longer any planning obstacle for the project to be financed by a private investor.\n",
"====================================================================================================\n",
"Confidence: 0.7774244338428103\n"
]
}
],
"source": [
"calibrated_confidence = normalizer.transform([ue])[0]\n",
"print(de_sentence)\n",
"print('=' * 100)\n",
"print(translation)\n",
"print('=' * 100)\n",
"print('Confidence: ', calibrated_confidence)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading
Loading