Commit 50996bb

Update LLM server, fix bugs, and format with black (#236)
1 parent d165e22 commit 50996bb

9 files changed: +217 −39 lines changed


.github/workflows/test_lemonade.yml

Lines changed: 5 additions & 0 deletions

@@ -32,6 +32,11 @@ jobs:
           conda install pylint
           python -m pip check
           pip install -e .[llm]
+      - name: Lint with Black
+        uses: psf/black@stable
+        with:
+          options: "--check --verbose"
+          src: "./src"
       - name: Lint with PyLint
         shell: bash -el {0}
         run: |
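
Contributors can reproduce this new CI check before pushing. Below is a minimal sketch of the equivalent local run, assuming the black package is installed in the active environment and the command is issued from the repository root; the script itself is not part of the commit.

    # Hypothetical local equivalent of the "Lint with Black" CI step above:
    # check formatting under ./src without modifying any files.
    import subprocess

    result = subprocess.run(
        ["black", "--check", "--verbose", "./src"],
        check=False,  # a non-zero exit code just means files would be reformatted
    )
    print("Formatting OK" if result.returncode == 0 else "Run `black ./src` to fix")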

src/turnkeyml/common/build.py

Lines changed: 4 additions & 2 deletions

@@ -282,14 +282,16 @@ def get_wmic_info(command):
         try:
             output = subprocess.check_output(command, shell=True).decode()
             return output.split("\n")[1].strip()
-        except Exception as e: # pylint: disable=broad-except
+        except Exception as e:  # pylint: disable=broad-except
             return str(e)

     if os_type == "Windows":
         if shutil.which("wmic") is not None:
             info_dict["Processor"] = get_wmic_info("wmic cpu get name")
             info_dict["OEM System"] = get_wmic_info("wmic computersystem get model")
-            mem_info_bytes = get_wmic_info("wmic computersystem get TotalPhysicalMemory")
+            mem_info_bytes = get_wmic_info(
+                "wmic computersystem get TotalPhysicalMemory"
+            )
             try:
                 mem_info_gb = round(int(mem_info_bytes) / (1024**3), 2)
                 info_dict["Physical Memory"] = f"{mem_info_gb} GB"

src/turnkeyml/llm/cli.py

Lines changed: 0 additions & 4 deletions

@@ -54,10 +54,6 @@ def main():
     except ModuleNotFoundError:
         pass

-
-
-
-
     # Define the argument parser
     parser = cli.CustomArgumentParser(
         description="Turnkey analysis and benchmarking of GenAI models. "

src/turnkeyml/llm/tools/chat.py

Lines changed: 173 additions & 19 deletions

@@ -1,15 +1,27 @@
 import argparse
-from threading import Thread
+import time
+import statistics
+from threading import Thread, Event
 import asyncio
 from fastapi import FastAPI, WebSocket
 from fastapi.responses import HTMLResponse
+from starlette.websockets import WebSocketDisconnect
 from pydantic import BaseModel
-from transformers import TextIteratorStreamer
+from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
 import uvicorn
 from turnkeyml.state import State
 from turnkeyml.tools import Tool
 from turnkeyml.llm.tools.adapter import ModelAdapter, TokenizerAdapter

+DEFAULT_GENERATE_PARAMS = {
+    "do_sample": True,
+    "top_k": 50,
+    "top_p": 0.95,
+    "temperature": 0.7,
+}
+
+DEFAULT_SERVER_PORT = 8000
+

 class LLMPrompt(Tool):
     """

@@ -61,7 +73,9 @@ def run(
         tokenizer: TokenizerAdapter = state.tokenizer

         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-        response = model.generate(input_ids, max_new_tokens=max_new_tokens)
+        response = model.generate(
+            input_ids, max_new_tokens=max_new_tokens, **DEFAULT_GENERATE_PARAMS
+        )
         response_text = tokenizer.decode(response[0], skip_special_tokens=True).strip()

         state.response = response_text

@@ -70,16 +84,32 @@
         return state


+# Custom huggingface-style stopping criteria to allow
+# us to halt streaming in-progress generations
+class StopOnEvent(StoppingCriteria):
+    def __init__(self, stop_event: Event):
+        super().__init__()
+        self.stop_event = stop_event
+
+    def __call__(self, input_ids, scores, **kwargs):
+        return self.stop_event.is_set()
+
+
 class Serve(Tool):
     """
     Open a web server that apps can use to communicate with the LLM.

-    There are two ways interact with the server:
+    There are two ways to perform generations with the server:
     - Send an http request to "http://localhost:8000/generate" and
       receive back a response with the complete prompt.
     - Open a WebSocket with "ws://localhost:8000" and receive a
       streaming response to the prompt.

+    The server also exposes these helpful endpoints:
+    - /health: check whether a model is loaded and ready to serve.
+    - /stats: performance statistics for the generation.
+    - /halt: stop an in-progress generation from making more tokens.
+
     The WebSocket functionality is demonstrated by the webpage served at
     http://localhost:8000, which you can visit with a web browser after
     opening the server.

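The StopOnEvent criteria added above is what lets the /halt endpoint interrupt a stream. A minimal standalone sketch of the same mechanism follows; it assumes a small Hugging Face checkpoint ("distilgpt2", chosen purely for illustration) and fires the stop event from a timer instead of an HTTP request. It is a sketch of the technique, not code from the commit.

    # Sketch: halting a transformers generation with an Event-driven
    # StoppingCriteria, mirroring the StopOnEvent class added above.
    from threading import Event, Thread, Timer

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        StoppingCriteria,
        StoppingCriteriaList,
        TextIteratorStreamer,
    )


    class StopOnEvent(StoppingCriteria):
        def __init__(self, stop_event: Event):
            super().__init__()
            self.stop_event = stop_event

        def __call__(self, input_ids, scores, **kwargs):
            # Returning True tells generate() to stop producing tokens
            return self.stop_event.is_set()


    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")  # example checkpoint
    model = AutoModelForCausalLM.from_pretrained("distilgpt2")

    stop_event = Event()
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    inputs = tokenizer("Hello", return_tensors="pt")

    thread = Thread(
        target=model.generate,
        kwargs={
            "input_ids": inputs.input_ids,
            "streamer": streamer,
            "max_new_tokens": 256,
            "stopping_criteria": StoppingCriteriaList([StopOnEvent(stop_event)]),
        },
    )
    thread.start()

    # Ask the generator to stop after one second, as the /halt endpoint does
    Timer(1.0, stop_event.set).start()

    for new_text in streamer:
        print(new_text, end="", flush=True)
    thread.join()
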
@@ -89,6 +119,7 @@ class Serve(Tool):
       huggingface TextIteratorStreamer.
     - state.tokenizer: tokenizer instance used to generate inputs for the
       model. Must be compatible with the huggingface TextIteratorStreamer.
+    - state.checkpoint: name of the checkpoint used to load state.model.

     Output state produced: None
     """

@@ -102,6 +133,17 @@ def __init__(self):
             enable_logger=False,
         )

+        # Performance stats that are set during /ws and can be
+        # fetched in /stats
+        self.time_to_first_token = None
+        self.tokens_per_second = None
+        self.input_tokens = None
+        self.output_tokens = None
+        self.decode_token_times = None
+
+        # Flag that tells the LLM to stop generating text and end the response
+        self.stop_event = Event()
+
     @staticmethod
     def parser(add_help: bool = True) -> argparse.ArgumentParser:
         parser = __class__.helpful_parser(

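These attributes are what the new /stats endpoint reports, and stop_event is what /halt sets. A minimal sketch of polling the helper endpoints from a separate process is below, assuming the Serve tool is already running on the default port and the requests package is installed; it is not part of the commit.

    # Sketch of a client for the new helper endpoints.
    import requests

    BASE_URL = "http://localhost:8000"

    print(requests.get(f"{BASE_URL}/health").json())  # e.g. {"model_loaded": "<checkpoint>"}
    print(requests.get(f"{BASE_URL}/stats").json())   # stats from the most recent generation
    requests.get(f"{BASE_URL}/halt")                  # ask an in-progress generation to stop
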
@@ -151,10 +193,15 @@ class Message(BaseModel):
             <input type="text" id="messageText" autocomplete="off"/>
             <button type="submit">Send</button>
         </form>
+        <button onclick="showStats()">Show Stats</button>
+        <button onclick="halt()">Halt</button>
+        <button onclick="health()">Health</button>
         <p id="allMessages"></p> <!-- Use a <p> element to display all messages -->
+        <p id="statsMessage"></p> <!-- Use a <p> element to display stats message -->
         <script>
             const messageQueue = []; // Store incoming messages
             const allMessagesContainer = document.getElementById('allMessages'); // Get the container element
+            const statsMessageContainer = document.getElementById('statsMessage'); // Get the stats message container
             var ws = new WebSocket("ws://localhost:8000/ws");
             ws.onmessage = function(event) {
                 const message = event.data;

@@ -173,6 +220,36 @@ class Message(BaseModel):
                 input.value = ''
                 event.preventDefault()
             }
+            function showStats() {
+                fetch('/stats')
+                    .then(response => response.json())
+                    .then(data => {
+                        statsMessageContainer.textContent = JSON.stringify(data); // Display the stats message
+                    })
+                    .catch(error => {
+                        console.error('Error:', error);
+                    });
+            }
+            function halt() {
+                fetch('/halt')
+                    .then(response => response.json())
+                    .then(data => {
+                        statsMessageContainer.textContent = JSON.stringify(data); // Display the stats message
+                    })
+                    .catch(error => {
+                        console.error('Error:', error);
+                    });
+            }
+            function health() {
+                fetch('/health')
+                    .then(response => response.json())
+                    .then(data => {
+                        statsMessageContainer.textContent = JSON.stringify(data); // Display the stats message
+                    })
+                    .catch(error => {
+                        console.error('Error:', error);
+                    });
+            }
         </script>
     </body>
 </html>

@@ -188,11 +265,8 @@ async def generate_response(message: Message):
             response = model.generate(
                 input_ids,
                 max_new_tokens=max_new_tokens,
-                do_sample=True,
-                top_k=50,
-                top_p=0.95,
-                temperature=0.7,
                 pad_token_id=tokenizer.eos_token_id,
+                **DEFAULT_GENERATE_PARAMS,
             )
             generated_text = tokenizer.decode(response[0], skip_special_tokens=True)

@@ -203,13 +277,23 @@ async def generate_response(message: Message):

         @app.websocket("/ws")
         async def stream_response(websocket: WebSocket):
+            """
+            Receive a prompt string, and then stream the response back
+            over a websocket.
+            """
+
             await websocket.accept()
             while True:

-                message = await websocket.receive_text()
-
-                if message == "done":
+                try:
+                    message = await websocket.receive_text()
+                except WebSocketDisconnect:
+                    print("Client closed connection")
                     break
+
+                # Reset the early-exit flag before we start each generation
+                self.stop_event.clear()
+
                 input_ids = tokenizer(message, return_tensors="pt").input_ids

                 # Set up the generation parameters

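The try/except around receive_text keeps the connection open across prompts and only exits when the client disconnects. A minimal streaming-client sketch follows, assuming the server is up on the default port and the third-party websockets package is installed; the idle timeout is a hypothetical way to detect the end of a response, since the server leaves the socket open for the next prompt.

    # Sketch of a streaming client for the /ws endpoint (not part of the commit).
    import asyncio

    import websockets  # third-party package, assumed installed


    async def stream_prompt(prompt: str):
        async with websockets.connect("ws://localhost:8000/ws") as ws:
            await ws.send(prompt)
            while True:
                try:
                    # Each message is one chunk of newly generated text; treat a
                    # quiet period as the end of this response.
                    chunk = await asyncio.wait_for(ws.recv(), timeout=5)
                except asyncio.TimeoutError:
                    break
                print(chunk, end="", flush=True)


    asyncio.run(stream_prompt("Hello, my name is"))
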
@@ -219,39 +303,109 @@ async def stream_response(websocket: WebSocket):

                     streamer = oga.OrtGenaiStreamer(tokenizer)

+                    self.input_tokens = len(input_ids)
+
                 else:
                     # Huggingface-like models
                     streamer = TextIteratorStreamer(
                         tokenizer,
                         skip_prompt=True,
                     )
+
+                    self.input_tokens = len(input_ids[0])
+
+                # Enable sending a signal into the generator thread to stop
+                # the generation early
+                stopping_criteria = StoppingCriteriaList([StopOnEvent(self.stop_event)])
+
                 generation_kwargs = {
                     "input_ids": input_ids,
                     "streamer": streamer,
                     "max_new_tokens": max_new_tokens,
-                    "do_sample": True,
-                    "top_k": 50,
-                    "top_p": 0.95,
-                    "temperature": 0.7,
                     "pad_token_id": tokenizer.eos_token_id,
+                    "stopping_criteria": stopping_criteria,
+                    **DEFAULT_GENERATE_PARAMS,
                 }

+                # Initialize performance variables
+                generation_start_time = time.perf_counter()
+                first_token = True
+                self.decode_token_times = []
+                self.output_tokens = 0
+
+                # Begin generation
                 thread = Thread(target=model.generate, kwargs=generation_kwargs)
                 thread.start()

                 # Generate the response using streaming
                 for new_text in streamer:
+
+                    # Capture performance stats about this token
+                    self.output_tokens = self.output_tokens + 1
+                    if first_token:
+                        self.time_to_first_token = (
+                            time.perf_counter() - generation_start_time
+                        )
+                        first_token = False
+                    else:
+                        self.decode_token_times.append(
+                            time.perf_counter() - next_token_start_time
+                        )
+                    next_token_start_time = time.perf_counter()
+
+                    # Print the decoded value to the terminal for debugging purposes
                     print(new_text, end="", flush=True)

                     # Send the generated text to the client
-                    await asyncio.sleep(0.1)  # Add a small delay (adjust as needed)
+                    await asyncio.sleep(0.001)  # Add a small delay (adjust as needed)
                     await websocket.send_text(new_text)

+                    # Allow the user to finish the response early
+                    if self.stop_event.is_set():
+                        print("Stopping generation early.")
+                        break
+
+                self.tokens_per_second = 1 / statistics.mean(self.decode_token_times)
                 print("\n")
                 thread.join()

-            await websocket.close()
-
-        uvicorn.run(app, host="localhost", port=8000)
+        @app.get("/stats")
+        async def send_stats():
+            """
+            Send performance statistics to the client.
+            """
+            return {
+                "time_to_first_token": self.time_to_first_token,
+                "tokens_per_second": self.tokens_per_second,
+                "input_tokens": self.input_tokens,
+                "output_tokens": self.output_tokens,
+                "decode_token_times": self.decode_token_times,
+            }
+
+        @app.get("/halt")
+        async def halt_generation():
+            """
+            Allow the client to halt an in-progress generation.
+            """
+
+            self.stop_event.set()
+
+            return {
+                "terminated": True,
+            }
+
+        @app.get("/health")
+        async def health():
+            """
+            Report server health information to the client.
+            """
+
+            self.stop_event.set()
+
+            return {
+                "model_loaded": state.checkpoint,
+            }
+
+        uvicorn.run(app, host="localhost", port=DEFAULT_SERVER_PORT)

         return state
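
As a quick illustration of how the /stats numbers relate to the timings gathered in the streaming loop above: time_to_first_token is measured from the start of generate() to the first streamed chunk, and tokens_per_second is the reciprocal of the mean gap between the remaining chunks. The timing values in this sketch are made up.

    # Worked example of the statistics reported by /stats (values are invented).
    import statistics

    time_to_first_token = 0.48  # seconds from generate() start to the first token
    decode_token_times = [0.052, 0.049, 0.050, 0.051]  # gaps between later tokens

    tokens_per_second = 1 / statistics.mean(decode_token_times)
    print(f"{tokens_per_second:.1f} tokens/s during decode")  # 19.8 tokens/s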

src/turnkeyml/llm/tools/huggingface_load.py

Lines changed: 10 additions & 3 deletions

@@ -201,8 +201,15 @@ def __init__(self, model, dtype=torch.float32, device="cpu"):
         self.dtype = dtype
         self.device = device

-    def generate(self, input_ids, max_new_tokens=512, repetition_penalty=1.2,
-                 do_sample=True, temperature=0.1, **kwargs):
+    def generate(
+        self,
+        input_ids,
+        max_new_tokens=512,
+        repetition_penalty=1.2,
+        do_sample=True,
+        temperature=0.1,
+        **kwargs,
+    ):
         amp_enabled = (
             True
             if (self.dtype == torch.float16 or self.dtype == torch.bfloat16)

@@ -221,7 +228,7 @@ def generate(self, input_ids, max_new_tokens=512, repetition_penalty=1.2,
             repetition_penalty=repetition_penalty,
             do_sample=do_sample,
             temperature=temperature,
-            **kwargs
+            **kwargs,
         )