mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 07:16:38 -05:00
mlxrunner: Cancel in-flight requests when the client disconnects
Currently, a canceled request can result in computation continuing in the background to completion. It can also trigger a deadlock when there is nobody to read the output tokens and the pipeline cannot continue to the next request.
This commit is contained in:
@@ -55,6 +55,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
total, processed := len(tokens), 0
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
for total-processed > 1 {
|
||||
if err := request.Ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n := min(2<<10, total-processed-1)
|
||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||
mlx.Sweep()
|
||||
@@ -92,6 +96,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
now := time.Now()
|
||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
||||
for i := range request.Options.MaxTokens {
|
||||
if err := request.Ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nextSample, nextLogprobs = step(sample)
|
||||
|
||||
if i == 0 {
|
||||
@@ -111,9 +119,13 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
break
|
||||
}
|
||||
|
||||
request.Responses <- Response{
|
||||
select {
|
||||
case <-request.Ctx.Done():
|
||||
return request.Ctx.Err()
|
||||
case request.Responses <- Response{
|
||||
Text: r.Decode(output, &b),
|
||||
Token: int(output),
|
||||
}:
|
||||
}
|
||||
|
||||
mlx.Unpin(sample, logprobs)
|
||||
@@ -126,8 +138,12 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
|
||||
final.CompletionTokensDuration = time.Since(now)
|
||||
request.Responses <- final
|
||||
select {
|
||||
case <-request.Ctx.Done():
|
||||
return request.Ctx.Err()
|
||||
case request.Responses <- final:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
@@ -25,8 +24,9 @@ type Request struct {
|
||||
Responses chan Response
|
||||
Pipeline func(Request) error
|
||||
|
||||
Ctx context.Context
|
||||
|
||||
sample.Sampler
|
||||
caches []cache.Cache
|
||||
}
|
||||
|
||||
type TextCompletionsRequest struct {
|
||||
@@ -157,7 +157,7 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
|
||||
return nil
|
||||
case request := <-r.Requests:
|
||||
if err := request.Pipeline(request); err != nil {
|
||||
break
|
||||
slog.Info("Request terminated", "error", err)
|
||||
}
|
||||
|
||||
close(request.Responses)
|
||||
|
||||
@@ -5,6 +5,7 @@ package mlxrunner
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
@@ -98,12 +99,28 @@ func Execute(args []string) error {
|
||||
request.Options.TopK,
|
||||
)
|
||||
|
||||
runner.Requests <- request
|
||||
var cancel context.CancelFunc
|
||||
request.Ctx, cancel = context.WithCancel(r.Context())
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case runner.Requests <- request:
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/jsonl")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
enc := json.NewEncoder(w)
|
||||
for response := range request.Responses {
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case response, ok := <-request.Responses:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := enc.Encode(response); err != nil {
|
||||
slog.Error("Failed to encode response", "error", err)
|
||||
return
|
||||
@@ -113,6 +130,7 @@ func Execute(args []string) error {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
Reference in New Issue
Block a user