Skip to content

Commit 98db134

Browse files
authored
fix: The tokenizer service request context could be canceled before the response body was read. (#201)
* fix: The call to the tokenizer service may be canceled in advance.

  Signed-off-by: Hang Yin <[email protected]>

* fix nested if

  Signed-off-by: Hang Yin <[email protected]>

* go dot fix

  Signed-off-by: Hang Yin <[email protected]>

---------

Signed-off-by: Hang Yin <[email protected]>
1 parent 070bc0f commit 98db134

File tree

1 file changed

+50
-45
lines changed

1 file changed

+50
-45
lines changed

pkg/tokenization/uds_tokenizer.go

Lines changed: 50 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -116,23 +116,13 @@ func (u *UdsTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.O
116116
return nil, nil, fmt.Errorf("failed to create request: %w", err)
117117
}
118118

119-
resp, err := u.executeRequest(req, defaultTimeout, defaultMaxRetries)
119+
respBody, err := u.executeRequest(req, defaultTimeout, defaultMaxRetries)
120120
if err != nil {
121-
return nil, nil, fmt.Errorf("failed to execute request: %w", err)
122-
}
123-
defer resp.Body.Close()
124-
125-
body, err := io.ReadAll(resp.Body)
126-
if err != nil {
127-
return nil, nil, fmt.Errorf("failed to read response body: %w", err)
128-
}
129-
130-
if resp.StatusCode != http.StatusOK {
131-
return nil, nil, fmt.Errorf("server returned status %d: %s", resp.StatusCode, string(body))
121+
return nil, nil, fmt.Errorf("tokenize request failed: %w", err)
132122
}
133123

134124
var tokenized TokenizedInput
135-
if err := json.Unmarshal(body, &tokenized); err != nil {
125+
if err := json.Unmarshal(respBody, &tokenized); err != nil {
136126
return nil, nil, fmt.Errorf("failed to unmarshal response: %w", err)
137127
}
138128

@@ -159,32 +149,23 @@ func (u *UdsTokenizer) RenderChatTemplate(
159149
}
160150
req.Header.Set("Content-Type", "application/json")
161151

162-
resp, err := u.executeRequest(req, defaultTimeout, defaultMaxRetries)
152+
respBody, err := u.executeRequest(req, defaultTimeout, defaultMaxRetries)
163153
if err != nil {
164-
return "", fmt.Errorf("failed to execute request: %w", err)
165-
}
166-
defer resp.Body.Close()
167-
168-
body, err := io.ReadAll(resp.Body)
169-
if err != nil {
170-
return "", fmt.Errorf("failed to read response body: %w", err)
171-
}
172-
173-
if resp.StatusCode != http.StatusOK {
174-
return "", fmt.Errorf("server returned status %d: %s", resp.StatusCode, string(body))
154+
return "", fmt.Errorf("chat-template request failed: %w", err)
175155
}
176-
177-
return string(body), nil
156+
return string(respBody), nil
178157
}
179158

180159
func (u *UdsTokenizer) Type() string {
181160
return "external-uds"
182161
}
183162

184163
// executeRequest executes an HTTP request with timeout and retry logic.
185-
func (u *UdsTokenizer) executeRequest(req *http.Request,
186-
timeout time.Duration, maxRetries int,
187-
) (*http.Response, error) {
164+
func (u *UdsTokenizer) executeRequest(
165+
req *http.Request,
166+
timeout time.Duration,
167+
maxRetries int,
168+
) ([]byte, error) {
188169
if timeout == 0 {
189170
timeout = defaultTimeout
190171
}
@@ -193,28 +174,27 @@ func (u *UdsTokenizer) executeRequest(req *http.Request,
193174
}
194175

195176
// Try the request up to maxRetries+1 times
196-
var lastErr error
197177
delay := initialRetryDelay
178+
var lastErr error
198179

199180
for attempt := 0; attempt <= maxRetries; attempt++ {
200181
// Create a context with timeout
201182
ctx, cancel := context.WithTimeout(req.Context(), timeout)
202-
req = req.WithContext(ctx)
183+
reqWithCtx := req.WithContext(ctx)
203184

204185
// Execute the request
205-
resp, err := u.httpClient.Do(req)
206-
lastErr = err
186+
resp, err := u.httpClient.Do(reqWithCtx)
207187

208-
cancel()
209-
210-
// If no error, check status code
188+
// Process the response if no error occurred
211189
if err == nil {
212-
// For non-5xx status codes, don't retry
213-
if resp.StatusCode < 500 {
214-
return resp, nil
190+
body, processErr := u.processResponse(resp, cancel)
191+
if processErr == nil {
192+
return body, nil
215193
}
216-
// Close the response body before retrying
217-
resp.Body.Close()
194+
lastErr = processErr
195+
} else {
196+
lastErr = err
197+
cancel()
218198
}
219199

220200
// If this was the last attempt, break
@@ -227,13 +207,38 @@ func (u *UdsTokenizer) executeRequest(req *http.Request,
227207
delay *= 2 // Exponential backoff
228208

229209
// Add some jitter to prevent thundering herd
230-
jitter, err := rand.Int(rand.Reader, big.NewInt(int64(delay/2)))
231-
if err != nil {
210+
jitter, randErr := rand.Int(rand.Reader, big.NewInt(int64(delay/2)))
211+
if randErr != nil {
232212
// Fallback to using the full delay without jitter
233213
jitter = big.NewInt(int64(delay / 2))
234214
}
235215
delay += time.Duration(jitter.Int64())
236216
}
217+
if lastErr == nil {
218+
lastErr = fmt.Errorf("request failed after %d attempts", maxRetries+1)
219+
} else {
220+
lastErr = fmt.Errorf("request failed after %d attempts: %w", maxRetries+1, lastErr)
221+
}
222+
return nil, lastErr
223+
}
224+
225+
// processResponse returns the body of a 200 response or an error for any other status; it always closes the response body and cancels the request context before returning.
226+
func (u *UdsTokenizer) processResponse(resp *http.Response, cancel context.CancelFunc) ([]byte, error) {
227+
defer cancel()
228+
229+
// For 200 status codes, don't retry.
230+
if resp.StatusCode == http.StatusOK {
231+
// Read the response before canceling the context
232+
body, readErr := io.ReadAll(resp.Body)
233+
resp.Body.Close()
234+
if readErr == nil {
235+
return body, nil
236+
}
237+
return nil, fmt.Errorf("failed to read response body: %w", readErr)
238+
}
237239

238-
return nil, fmt.Errorf("request failed after %d retries: %w", maxRetries, lastErr)
240+
// For non-200 status codes, close the response body and return an error
241+
errorMsg := fmt.Errorf("server returned status %d", resp.StatusCode)
242+
resp.Body.Close()
243+
return nil, errorMsg
239244
}

0 commit comments

Comments
 (0)