logit_bias parameter in chat completion api seems to be ignored. #501

@damienlancry

Description

I deployed a mistral-7b-instruct-v0.1 model as an endpoint on SageMaker following this tutorial.

In my particular use case, I want the LLM to output only one token: "0" or "1". Therefore, I am using the logit_bias and max_tokens=1 parameters of the chat completion API.

For logit_bias, I looked up the token IDs of "0" and "1" in the tokenizer.json of mistral-7b-instruct-v0.1, so I am using logit_bias={"28734": 100, "28740": 100}.
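
For reference, the same IDs can be checked programmatically. A minimal sketch, assuming the Hugging Face transformers library is installed and the tokenizer for mistralai/Mistral-7B-Instruct-v0.1 is accessible locally (the expected values are the ones read from tokenizer.json above):

from transformers import AutoTokenizer

# Look the digit tokens up in the vocabulary directly.
tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
print(tok.convert_tokens_to_ids("0"))  # expected: 28734
print(tok.convert_tokens_to_ids("1"))  # expected: 28740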

Then I send a request to this endpoint:

import json
from pprint import pprint

import boto3
import sagemaker

sagemaker_session = sagemaker.Session(
    boto_session=boto3.Session(
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
        aws_session_token=session_token,
    )
)

sagemaker_client = sagemaker_session.sagemaker_runtime_client

messages = [
    {"role": "system", "content": "Reply only by 0 or 1."},
    {"role": "user", "content": "what is 0+0?"},
    {"role": "assistant", "content": "0"},
    {"role": "user", "content": "what is 1/1?"},
]

content_type = "application/json"
request_body = {
    "messages": messages,
    "model": "mistralai/Mistral-7B-v0.1",
    "logit_bias": {"28734": 100, "28740": 100},  # bias of +100 for "0" and "1" 
    "max_tokens": 1,
}
payload = json.dumps(request_body).encode("utf-8")

endpoint_name = "lmi-mistral-7b-instruct-v01-xxxx-xx-xx-xx-xx-xx-xxx"
response = sagemaker_client.invoke_endpoint(EndpointName=endpoint_name, ContentType=content_type, Body=payload)
result = response["Body"].read()
ans = json.loads(result)
pprint(ans)

The output is:

{'choices': [{'finish_reason': 'length',
              'index': 0,
              'logprobs': None,
              'message': {'content': ' ', 'role': 'assistant'}}],
 'created': 1724127226,
 'id': 'chatcmpl-<built-in function id>',
 'object': 'chat.completion',
 'usage': {'completion_tokens': 1, 'prompt_tokens': 45, 'total_tokens': 46}}

So the output is " ", which suggests that the logit_bias parameter is being ignored.
Note that even if I also penalize the whitespace token with logit_bias={"28734": 100, "28740": 100, "29000": -100}, the output token is still " ".
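
For comparison, below is a rough local sanity check (a sketch of what I understand logit_bias to mean, not the endpoint's actual implementation): adding +100 to the logits of "0" and "1" before choosing the next token should make one of those two tokens win, so a " " output should not happen. It assumes torch and transformers are installed, that the model fits in local memory, and it uses a hand-built approximation of the instruct prompt.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-Instruct-v0.1"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Hand-built instruct prompt (approximation; the endpoint builds its own).
prompt = "[INST] Reply only by 0 or 1. what is 1/1? [/INST]"
inputs = tok(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    next_token_logits = model(**inputs).logits[0, -1]

# Apply the same biases the request asks for.
logit_bias = {28734: 100.0, 28740: 100.0}
for token_id, bias in logit_bias.items():
    next_token_logits[token_id] += bias

# Greedy pick of a single next token, mirroring max_tokens=1.
print(repr(tok.decode([int(next_token_logits.argmax())])))  # expected: '0' or '1'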

Am I doing anything wrong? Or is the logit_bias parameter not actually supported?
Thanks for any help.
