 import argparse
-from threading import Thread
+import time
+import statistics
+from threading import Thread, Event
 import asyncio
 from fastapi import FastAPI, WebSocket
 from fastapi.responses import HTMLResponse
+from starlette.websockets import WebSocketDisconnect
 from pydantic import BaseModel
-from transformers import TextIteratorStreamer
+from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
 import uvicorn
 from turnkeyml.state import State
 from turnkeyml.tools import Tool
 from turnkeyml.llm.tools.adapter import ModelAdapter, TokenizerAdapter
 
+DEFAULT_GENERATE_PARAMS = {
+    "do_sample": True,
+    "top_k": 50,
+    "top_p": 0.95,
+    "temperature": 0.7,
+}
+
+DEFAULT_SERVER_PORT = 8000
+
 
 class LLMPrompt(Tool):
     """
@@ -61,7 +73,9 @@ def run(
         tokenizer: TokenizerAdapter = state.tokenizer
 
         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-        response = model.generate(input_ids, max_new_tokens=max_new_tokens)
+        response = model.generate(
+            input_ids, max_new_tokens=max_new_tokens, **DEFAULT_GENERATE_PARAMS
+        )
         response_text = tokenizer.decode(response[0], skip_special_tokens=True).strip()
 
         state.response = response_text
@@ -70,16 +84,32 @@ def run(
         return state
 
 
+# Custom huggingface-style stopping criteria to allow
+# us to halt streaming in-progress generations
+class StopOnEvent(StoppingCriteria):
+    def __init__(self, stop_event: Event):
+        super().__init__()
+        self.stop_event = stop_event
+
+    def __call__(self, input_ids, scores, **kwargs):
+        return self.stop_event.is_set()
+
+
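For context, Hugging Face `generate()` polls each entry in its `stopping_criteria` list after every new token and stops as soon as one returns True, which is what lets a `threading.Event` set from another thread end a streaming generation mid-flight. A minimal standalone sketch of the same pattern, reusing the `StopOnEvent` class from the hunk above (the checkpoint name and prompt are illustrative assumptions, not part of this diff):

    from threading import Event, Thread
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        StoppingCriteriaList,
        TextIteratorStreamer,
    )

    # Any small Hugging Face causal LM works for this demo
    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    model = AutoModelForCausalLM.from_pretrained("distilgpt2")

    stop_event = Event()
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    kwargs = {
        "input_ids": tokenizer("Hello, my name is", return_tensors="pt").input_ids,
        "max_new_tokens": 200,
        "streamer": streamer,
        # StopOnEvent is the class defined in the diff above
        "stopping_criteria": StoppingCriteriaList([StopOnEvent(stop_event)]),
    }
    Thread(target=model.generate, kwargs=kwargs).start()

    for i, fragment in enumerate(streamer):
        print(fragment, end="", flush=True)
        if i >= 5:
            stop_event.set()  # generate() notices at its next criteria check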
 class Serve(Tool):
     """
     Open a web server that apps can use to communicate with the LLM.
 
-    There are two ways interact with the server:
+    There are two ways to perform generations with the server:
     - Send an http request to "http://localhost:8000/generate" and
       receive back a response with the complete prompt.
     - Open a WebSocket with "ws://localhost:8000" and receive a
       streaming response to the prompt.
 
+    The server also exposes these helpful endpoints:
+    - /health: check whether a model is loaded and ready to serve.
+    - /stats: performance statistics for the generation.
+    - /halt: stop an in-progress generation from making more tokens.
+
     The WebSocket functionality is demonstrated by the webpage served at
     http://localhost:8000, which you can visit with a web browser after
     opening the server.
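To make the endpoint list above concrete, here is a hedged client-side sketch using the `requests` package (the JSON fields shown come from the handlers added later in this diff; note that, as this diff is written, /health also sets the halt flag, so polling it will end an in-progress generation):

    import requests

    BASE = "http://localhost:8000"

    # Confirm a model is loaded (returns the checkpoint name)
    print(requests.get(f"{BASE}/health").json())  # {"model_loaded": ...}

    # Ask the server to stop an in-progress generation
    print(requests.get(f"{BASE}/halt").json())    # {"terminated": True}

    # Fetch performance stats recorded during the last /ws generation
    print(requests.get(f"{BASE}/stats").json())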
@@ -89,6 +119,7 @@ class Serve(Tool):
       huggingface TextIteratorStreamer.
     - state.tokenizer: tokenizer instance used to generate inputs for the
       model. Must be compatible with the huggingface TextIteratorStreamer.
+    - state.checkpoint: name of the checkpoint used to load state.model.
 
     Output state produced: None
     """
@@ -102,6 +133,17 @@ def __init__(self):
             enable_logger=False,
         )
 
+        # Performance stats that are set during /ws and can be
+        # fetched in /stats
+        self.time_to_first_token = None
+        self.tokens_per_second = None
+        self.input_tokens = None
+        self.output_tokens = None
+        self.decode_token_times = None
+
+        # Flag that tells the LLM to stop generating text and end the response
+        self.stop_event = Event()
+
     @staticmethod
     def parser(add_help: bool = True) -> argparse.ArgumentParser:
         parser = __class__.helpful_parser(
@@ -151,10 +193,15 @@ class Message(BaseModel):
                 <input type="text" id="messageText" autocomplete="off"/>
                 <button type="submit">Send</button>
             </form>
+            <button onclick="showStats()">Show Stats</button>
+            <button onclick="halt()">Halt</button>
+            <button onclick="health()">Health</button>
             <p id="allMessages"></p> <!-- Use a <p> element to display all messages -->
+            <p id="statsMessage"></p> <!-- Use a <p> element to display stats message -->
             <script>
                 const messageQueue = []; // Store incoming messages
                 const allMessagesContainer = document.getElementById('allMessages'); // Get the container element
+                const statsMessageContainer = document.getElementById('statsMessage'); // Get the stats message container
                 var ws = new WebSocket("ws://localhost:8000/ws");
                 ws.onmessage = function(event) {
                     const message = event.data;
@@ -173,6 +220,36 @@ class Message(BaseModel):
                     input.value = ''
                     event.preventDefault()
                 }
+                function showStats() {
+                    fetch('/stats')
+                        .then(response => response.json())
+                        .then(data => {
+                            statsMessageContainer.textContent = JSON.stringify(data); // Display the stats message
+                        })
+                        .catch(error => {
+                            console.error('Error:', error);
+                        });
+                }
+                function halt() {
+                    fetch('/halt')
+                        .then(response => response.json())
+                        .then(data => {
+                            statsMessageContainer.textContent = JSON.stringify(data); // Display the stats message
+                        })
+                        .catch(error => {
+                            console.error('Error:', error);
+                        });
+                }
+                function health() {
+                    fetch('/health')
+                        .then(response => response.json())
+                        .then(data => {
+                            statsMessageContainer.textContent = JSON.stringify(data); // Display the stats message
+                        })
+                        .catch(error => {
+                            console.error('Error:', error);
+                        });
+                }
             </script>
         </body>
     </html>
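The HTML page above doubles as a browser-based client; a headless client can consume the same stream. A minimal sketch, assuming the third-party `websockets` package is installed and the server is already running on localhost:8000:

    import asyncio
    import websockets  # assumption: pip install websockets

    async def main():
        async with websockets.connect("ws://localhost:8000/ws") as ws:
            await ws.send("Hello, my name is")
            try:
                # The server streams text fragments with no explicit
                # end-of-stream marker, so stop once it goes quiet
                while True:
                    fragment = await asyncio.wait_for(ws.recv(), timeout=10)
                    print(fragment, end="", flush=True)
            except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
                pass

    asyncio.run(main())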
@@ -188,11 +265,8 @@ async def generate_response(message: Message):
             response = model.generate(
                 input_ids,
                 max_new_tokens=max_new_tokens,
-                do_sample=True,
-                top_k=50,
-                top_p=0.95,
-                temperature=0.7,
                 pad_token_id=tokenizer.eos_token_id,
+                **DEFAULT_GENERATE_PARAMS,
             )
             generated_text = tokenizer.decode(response[0], skip_special_tokens=True)
 
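For the non-streaming path, a hedged example request; the body of the Message model is elided from this diff, so the field name `text` is an assumption:

    import requests

    reply = requests.post(
        "http://localhost:8000/generate",
        # Hypothetical body: match it to the actual Message model fields
        json={"text": "Hello, my name is"},
    )
    print(reply.json())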
@@ -203,13 +277,23 @@ async def generate_response(message: Message):
 
         @app.websocket("/ws")
         async def stream_response(websocket: WebSocket):
+            """
+            Receive a prompt string, and then stream the response back
+            over a websocket.
+            """
+
             await websocket.accept()
             while True:
 
-                message = await websocket.receive_text()
-
-                if message == "done":
+                try:
+                    message = await websocket.receive_text()
+                except WebSocketDisconnect:
+                    print("Client closed connection")
                     break
+
+                # Reset the early-exit flag before we start each generation
+                self.stop_event.clear()
+
                 input_ids = tokenizer(message, return_tensors="pt").input_ids
 
                 # Set up the generation parameters
@@ -219,39 +303,109 @@ async def stream_response(websocket: WebSocket):
 
                     streamer = oga.OrtGenaiStreamer(tokenizer)
 
+                    self.input_tokens = len(input_ids)
+
                 else:
                     # Huggingface-like models
                     streamer = TextIteratorStreamer(
                         tokenizer,
                         skip_prompt=True,
                     )
+
+                    self.input_tokens = len(input_ids[0])
+
+                # Enable sending a signal into the generator thread to stop
+                # the generation early
+                stopping_criteria = StoppingCriteriaList([StopOnEvent(self.stop_event)])
+
                 generation_kwargs = {
                     "input_ids": input_ids,
                     "streamer": streamer,
                     "max_new_tokens": max_new_tokens,
-                    "do_sample": True,
-                    "top_k": 50,
-                    "top_p": 0.95,
-                    "temperature": 0.7,
                     "pad_token_id": tokenizer.eos_token_id,
+                    "stopping_criteria": stopping_criteria,
+                    **DEFAULT_GENERATE_PARAMS,
                 }
 
+                # Initialize performance variables
+                generation_start_time = time.perf_counter()
+                first_token = True
+                self.decode_token_times = []
+                self.output_tokens = 0
+
+                # Begin generation
                 thread = Thread(target=model.generate, kwargs=generation_kwargs)
                 thread.start()
 
                 # Generate the response using streaming
                 for new_text in streamer:
+
+                    # Capture performance stats about this token
+                    self.output_tokens = self.output_tokens + 1
+                    if first_token:
+                        self.time_to_first_token = (
+                            time.perf_counter() - generation_start_time
+                        )
+                        first_token = False
+                    else:
+                        self.decode_token_times.append(
+                            time.perf_counter() - next_token_start_time
+                        )
+                    next_token_start_time = time.perf_counter()
+
+                    # Print the decoded value to the terminal for debugging purposes
                     print(new_text, end="", flush=True)
 
                     # Send the generated text to the client
-                    await asyncio.sleep(0.1)  # Add a small delay (adjust as needed)
+                    await asyncio.sleep(0.001)  # Add a small delay (adjust as needed)
                     await websocket.send_text(new_text)
 
+                    # Allow the user to finish the response early
+                    if self.stop_event.is_set():
+                        print("Stopping generation early.")
+                        break
+
+                self.tokens_per_second = 1 / statistics.mean(self.decode_token_times)
                 print("\n")
                 thread.join()
 
-            await websocket.close()
-
-        uvicorn.run(app, host="localhost", port=8000)
+        @app.get("/stats")
+        async def send_stats():
+            """
+            Send performance statistics to the client.
+            """
+            return {
+                "time_to_first_token": self.time_to_first_token,
+                "tokens_per_second": self.tokens_per_second,
+                "input_tokens": self.input_tokens,
+                "output_tokens": self.output_tokens,
+                "decode_token_times": self.decode_token_times,
+            }
+
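For orientation, tokens_per_second is derived as 1 / mean(decode_token_times), so a 25 ms average decode step yields 1 / 0.025 = 40 tokens/second, while time_to_first_token is reported separately because it also includes prefill. An illustrative payload (numbers invented, not from a real run):

    # Example /stats response, written as a Python literal
    {
        "time_to_first_token": 0.31,  # seconds from generate() start to first token
        "tokens_per_second": 40.0,    # 1 / 0.025
        "input_tokens": 5,
        "output_tokens": 64,
        "decode_token_times": [0.024, 0.026, 0.025, ...],  # per-token latencies
    }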
+        @app.get("/halt")
+        async def halt_generation():
+            """
+            Allow the client to halt an in-progress generation.
+            """
+
+            self.stop_event.set()
+
+            return {
+                "terminated": True,
+            }
+
+        @app.get("/health")
+        async def health():
+            """
+            Report server health information to the client.
+            """
+
+            self.stop_event.set()
+
+            return {
+                "model_loaded": state.checkpoint,
+            }
+
+        uvicorn.run(app, host="localhost", port=DEFAULT_SERVER_PORT)
 
         return state