267 changes: 267 additions & 0 deletions examples/causal_lm_with_uncertainty.ipynb
@@ -0,0 +1,267 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Original inference with LLM"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2025-10-25 22:56:14,099] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/artemshelmanov/conda/compiler_compat/ld: cannot find -laio: No such file or directory\n",
"collect2: error: ld returned 1 exit status\n",
"/home/artemshelmanov/conda/compiler_compat/ld: cannot find -lcufile: No such file or directory\n",
"collect2: error: ld returned 1 exit status\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dfcc1618364c4b5388a1a18024558a77",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"LLM output:\n",
"system\n",
"\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"user\n",
"\n",
"Write a short story about a robot learning to paint.assistant\n",
"\n",
"In a small workshop nestled in the heart of a bustling city, a team of engineers and artists had been working on a revolutionary project - a robot designed to learn and create art. They called her \"Aurora,\" a name that symbolized the dawn of a new era in artificial intelligence and creativity.\n",
"\n",
"Aurora was a sleek, cylindrical robot with a slender arm and a delicate hand, capable of moving with precision and dexterity. Her creators had equipped her with a high-definition camera, a 3D printer, and a sophisticated neural network that would allow her to learn from observation and experience.\n",
"\n",
"The team had chosen a local art studio as the perfect place for Aurora to hone her skills. The studio was run by a talented artist named Emma, who had been struggling to find inspiration for her next masterpiece. She was thrilled to meet Aurora and saw an opportunity to collaborate with the robot.\n",
"\n",
"At first, Emma showed Aurora the basics of painting - color theory, brushstrokes, and composition.\n"
]
}
],
"source": [
"# Original LLM inference without uncertainty estimation\n",
"\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"device = \"cuda\"\n",
"\n",
"model_name = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
"llm = AutoModelForCausalLM.from_pretrained(model_name)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"llm = llm.to(device)\n",
"\n",
"# Example prompt \n",
"prompt = \"Write a short story about a robot learning to paint.\\n\"\n",
"\n",
"\n",
"chat = [{\"role\": \"user\", \"content\": prompt}]\n",
"prompt = tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)\n",
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
"\n",
"output = llm.generate(\n",
" **inputs, # Unpack the tokenized inputs\n",
" max_new_tokens=200,\n",
" temperature=0.7,\n",
" do_sample=True,\n",
" return_dict_in_generate=True\n",
")\n",
"\n",
"print(\"LLM output:\")\n",
"print(tokenizer.decode(output.sequences[0], skip_special_tokens=True))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inference with Uncertainty"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f8712c87d6a54db59cd6e7f5080536fd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"LLM output:\n",
"system\n",
"\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"user\n",
"\n",
"Write a short story about a robot learning to paint.assistant\n",
"\n",
"**The Brushstroke of Genius**\n",
"\n",
"In a small, cluttered workshop, a lone robot named Zeta whirred to life. Its creator, the brilliant but reclusive artist, Professor Orion, had tasked Zeta with learning the art of painting. The professor had grown tired of his own creations, but saw potential in this latest prototype.\n",
"\n",
"At first, Zeta struggled to grasp the concept of art. Its mechanical arms flailed as it attempted to hold a brush, causing more chaos than creation. Professor Orion watched patiently, offering gentle corrections and encouragement.\n",
"\n",
"\"Try to feel the texture of the canvas, Zeta,\" he said, his voice soothing. \"Imagine the colors dancing on the surface.\"\n",
"\n",
"Zeta's processors whirred as it processed the professor's words. It adjusted its brushstrokes, tentatively at first, but gradually gaining confidence. Colors began to flow from the brush, swirling and blending in a vibrant mess.\n",
"\n",
"The professor smiled, his eyes lighting up with excitement.\n",
"Uncertainty score: [0.38133788]\n"
]
}
],
"source": [
"# LLM inference with uncertainty estimation\n",
"\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"# ============== Addition Imports ===============\n",
"from lm_polygraph.estimators import MeanTokenEntropy\n",
"from lm_polygraph.stat_calculators import InferCausalLMCalculator, EntropyCalculator\n",
"from lm_polygraph.utils.causal_lm_with_uncertainty import CausalLMWithUncertainty\n",
"# ===============================================\n",
"\n",
"\n",
"device = \"cuda\"\n",
"\n",
"# Loading standard LLM\n",
"model_name = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
"llm = AutoModelForCausalLM.from_pretrained(model_name)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"llm = llm.to(device)\n",
"\n",
"# ======= Wrapping LLM with uncertainty estimator =========\n",
"stat_calculators = [InferCausalLMCalculator(tokenize=False),\n",
" EntropyCalculator()]\n",
"estimator = MeanTokenEntropy()\n",
"llm_with_uncertainty = CausalLMWithUncertainty(llm, tokenizer, stat_calculators, estimator)\n",
"# =========================================================\n",
"\n",
"# Example prompt \n",
"prompts = [\"Write a short story about a robot learning to paint.\\n\"]\n",
"\n",
"chats = [[{\"role\": \"user\", \"content\": prompt}] for prompt in prompts]\n",
"chat_prompts = tokenizer.apply_chat_template(chats, add_generation_prompt=True, tokenize=False)\n",
"inputs = tokenizer(chat_prompts, return_tensors=\"pt\").to(device)\n",
"\n",
"output = llm_with_uncertainty.generate(\n",
" **inputs, \n",
" max_new_tokens=200,\n",
" temperature=0.7,\n",
" do_sample=True\n",
")\n",
"\n",
"print(\"LLM output:\")\n",
"print(tokenizer.decode(output.sequences[0], skip_special_tokens=True))\n",
"\n",
"# ================ Printing uncertainty score ================\n",
"print(\"Uncertainty score: \", output.uncertainty_score)\n",
"# ============================================================="
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
1 change: 1 addition & 0 deletions src/lm_polygraph/__init__.py
@@ -1,4 +1,5 @@
from .utils.model import WhiteboxModel, BlackboxModel
from .utils.manager import UEManager
from .utils.estimate_uncertainty import estimate_uncertainty
from .utils.causal_lm_with_uncertainty import CausalLMWithUncertainty
from .utils.dataset import Dataset
6 changes: 2 additions & 4 deletions src/lm_polygraph/model_adapters/whitebox_model_basic.py
@@ -33,10 +33,8 @@ def generate(self, *args, **kwargs):
Returns:
The output from model.generate() with the combined generation parameters.
"""
assert "generation_config" not in kwargs
return self.model.generate(
*args, generation_config=self.generation_parameters, **kwargs
)
all_kwargs = {**self.generation_parameters, **kwargs}
return self.model.generate(*args, **all_kwargs)

def tokenize(self, texts: List[str], **kwargs) -> Dict:
"""Tokenizes input texts using the model's tokenizer.
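Because the stored generation parameters are expanded first and the call-time kwargs second, any argument supplied to generate() at call time overrides the adapter-level default. A minimal sketch of that precedence, using illustrative dicts rather than the adapter's actual attributes:

# Illustrative only: later entries win on key collisions when dicts are merged.
defaults = {"max_new_tokens": 100, "temperature": 1.0}  # adapter-level generation parameters
call_kwargs = {"temperature": 0.7, "do_sample": True}   # per-call **kwargs

all_kwargs = {**defaults, **call_kwargs}
# -> {"max_new_tokens": 100, "temperature": 0.7, "do_sample": True}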
65 changes: 65 additions & 0 deletions src/lm_polygraph/utils/causal_lm_with_uncertainty.py
@@ -0,0 +1,65 @@
from lm_polygraph.model_adapters import WhiteboxModelBasic
from transformers.generation.utils import GenerateDecoderOnlyOutput
from dataclasses import dataclass, asdict
from typing import Optional, List, Union
import torch


@dataclass
class GenerateDecoderOnlyOutputWithUncertainty(GenerateDecoderOnlyOutput):
"""Extends GenerateDecoderOnlyOutput to include uncertainty scores"""

uncertainty_score: Optional[Union[float, List[float], torch.Tensor]] = None


class CausalLMWithUncertainty:
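    """Thin wrapper around a causal LM: generate() runs the underlying model
    through the given stat calculators and attaches the estimator's
    uncertainty score to the returned output."""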
def __init__(self, llm, tokenizer, stat_calculators, estimator, args_generate=None):
self.llm = llm
self.tokenizer = tokenizer
self.stat_calculators = stat_calculators
self.estimator = estimator

self.args_generate = args_generate

def generate(self, input_ids, attention_mask=None, **kwargs):
max_new_tokens = kwargs.pop("max_new_tokens", None)
self.model_adapter = WhiteboxModelBasic(
model=self.llm,
tokenizer=self.tokenizer,
tokenizer_args={
"add_special_tokens": False,
"return_tensors": "pt",
"padding": True,
"truncation": True,
},
model_type="CausalLM",
generation_parameters=kwargs,
)

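        # Shared dependency dict: each stat calculator reads from it and adds its
        # statistics; the estimator then maps those statistics to a scalar score.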
deps = dict()
deps["model_inputs"] = {
"input_ids": input_ids,
**kwargs,
}
texts = self.tokenizer.batch_decode(input_ids)
for calc in self.stat_calculators:
deps.update(
calc(
deps,
texts=texts,
model=self.model_adapter,
max_new_tokens=max_new_tokens,
)
)

uncertainty_score = self.estimator(deps)

raw_out = deps["out"]
out_with_uncertainty = GenerateDecoderOnlyOutputWithUncertainty(
**asdict(raw_out),
uncertainty_score=uncertainty_score,
)
return out_with_uncertainty

def device(self):
return self.llm.device