mlxrunner: Cancel in-flight requests when the client disconnects

Currently, a canceled request can leave computation running in the
background to completion. It can also deadlock the runner: with
nobody left to read the output tokens, the pipeline cannot continue
to the next request.
commit 0f23b7bff5
parent 4e57d2094e
Author: Jesse Gross
Date: 2026-02-24 14:19:33 -08:00

3 changed files with 47 additions and 13 deletions
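The deadlock described above is the standard unbuffered-channel hang. A minimal standalone sketch (illustrative, not code from this repo): once the handler stops reading the response channel, a plain send blocks forever and the runner can never move on to the next request.

package main

import (
	"fmt"
	"time"
)

func main() {
	responses := make(chan int) // unbuffered, like a per-request stream
	done := make(chan struct{})

	go func() {
		responses <- 42 // no reader left: this send blocks forever
		close(done)
	}()

	select {
	case <-done:
		fmt.Println("pipeline finished") // never reached
	case <-time.After(100 * time.Millisecond):
		fmt.Println("pipeline wedged on its send") // the bug being fixed
	}
}

The diffs below fix this by pairing every channel send, and every compute loop, with a context check.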


@@ -55,6 +55,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 	total, processed := len(tokens), 0
 	slog.Info("Prompt processing progress", "processed", processed, "total", total)
 	for total-processed > 1 {
+		if err := request.Ctx.Err(); err != nil {
+			return err
+		}
+
 		n := min(2<<10, total-processed-1)
 		r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
 		mlx.Sweep()
@@ -92,6 +96,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 	now := time.Now()
 	final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
 	for i := range request.Options.MaxTokens {
+		if err := request.Ctx.Err(); err != nil {
+			return err
+		}
+
 		nextSample, nextLogprobs = step(sample)
 
 		if i == 0 {
@@ -111,9 +119,13 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 			break
 		}
 
-		request.Responses <- Response{
+		select {
+		case <-request.Ctx.Done():
+			return request.Ctx.Err()
+		case request.Responses <- Response{
 			Text:  r.Decode(output, &b),
 			Token: int(output),
+		}:
 		}
 
 		mlx.Unpin(sample, logprobs)
@@ -126,8 +138,12 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 	}
 
 	final.CompletionTokensDuration = time.Since(now)
-	request.Responses <- final
-	return nil
+	select {
+	case <-request.Ctx.Done():
+		return request.Ctx.Err()
+	case request.Responses <- final:
+		return nil
+	}
 }
 
 func (r Runner) Decode(sample int32, b *bytes.Buffer) string {

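The two checks added above play different roles: request.Ctx.Err() is a cheap non-blocking poll for the compute-bound prompt-processing and decode loops, while the select guards the sends on request.Responses that could otherwise block. A standalone sketch with hypothetical names (generate, out) combining both:

package main

import (
	"context"
	"fmt"
)

// generate mimics the pipeline's two cancellation checks: a non-blocking
// poll between compute steps, and a select guarding the channel send.
func generate(ctx context.Context, out chan<- string) error {
	defer close(out)
	for i := 0; i < 3; i++ {
		// Non-blocking poll: cheap enough to run every iteration.
		if err := ctx.Err(); err != nil {
			return err
		}
		token := fmt.Sprintf("tok%d", i) // stand-in for a forward pass

		// Guarded send: a vanished reader cancels us instead of wedging us.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case out <- token:
		}
	}
	return nil
}

func main() {
	out := make(chan string)
	go generate(context.Background(), out)
	for tok := range out {
		fmt.Println(tok)
	}
}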

@@ -12,7 +12,6 @@ import (
 	"golang.org/x/sync/errgroup"
 
 	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
-	"github.com/ollama/ollama/x/mlxrunner/model"
 	"github.com/ollama/ollama/x/mlxrunner/model/base"
 	"github.com/ollama/ollama/x/mlxrunner/sample"
@@ -25,8 +24,9 @@ type Request struct {
 
 	Responses chan Response
 	Pipeline  func(Request) error
+	Ctx       context.Context
 
 	sample.Sampler
 
 	caches []cache.Cache
 }
type TextCompletionsRequest struct {
@@ -157,7 +157,7 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
 			return nil
 		case request := <-r.Requests:
 			if err := request.Pipeline(request); err != nil {
-				break
+				slog.Info("Request terminated", "error", err)
 			}
 
 			close(request.Responses)

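A note on the replaced break: in Go, break inside a select case exits the select, not the enclosing for loop, so the old error path appears to have skipped close(request.Responses) and left the client reading forever, one source of the deadlock the commit message describes. A standalone sketch of the gotcha:

package main

import "fmt"

func main() {
	ch := make(chan int, 1)
	ch <- 1
	for i := 0; i < 2; i++ {
		select {
		case v := <-ch:
			if v == 1 {
				break // exits the select only; the loop keeps going
			}
			fmt.Println("cleanup that break was meant to reach") // skipped
		default:
			fmt.Println("next iteration ran anyway:", i)
		}
	}
}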

@@ -5,6 +5,7 @@ package mlxrunner
 import (
 	"bytes"
 	"cmp"
+	"context"
 	"encoding/json"
 	"flag"
 	"fmt"
@@ -98,19 +99,36 @@ func Execute(args []string) error {
 			request.Options.TopK,
 		)
-		runner.Requests <- request
+		var cancel context.CancelFunc
+		request.Ctx, cancel = context.WithCancel(r.Context())
+		defer cancel()
+
+		select {
+		case <-r.Context().Done():
+			return
+		case runner.Requests <- request:
+		}
 
 		w.Header().Set("Content-Type", "application/jsonl")
 		w.WriteHeader(http.StatusOK)
 
 		enc := json.NewEncoder(w)
-		for response := range request.Responses {
-			if err := enc.Encode(response); err != nil {
-				slog.Error("Failed to encode response", "error", err)
+		for {
+			select {
+			case <-r.Context().Done():
 				return
-			}
+			case response, ok := <-request.Responses:
+				if !ok {
+					return
+				}
 
-			if f, ok := w.(http.Flusher); ok {
-				f.Flush()
+				if err := enc.Encode(response); err != nil {
+					slog.Error("Failed to encode response", "error", err)
+					return
+				}
+
+				if f, ok := w.(http.Flusher); ok {
+					f.Flush()
+				}
 			}
 		}
 	})
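
The handler changes lean on documented net/http behavior: for incoming server requests, r.Context() is canceled when the client's connection closes, so deriving the per-request context with context.WithCancel propagates the disconnect into the pipeline. A minimal streaming-handler sketch (hypothetical route and tick payload):

package main

import (
	"fmt"
	"net/http"
	"time"
)

func stream(w http.ResponseWriter, r *http.Request) {
	for i := 0; ; i++ {
		select {
		case <-r.Context().Done():
			// Fires when the client disconnects (or the request ends).
			fmt.Println("client gone:", r.Context().Err())
			return
		case <-time.After(500 * time.Millisecond):
			fmt.Fprintf(w, "tick %d\n", i)
			if f, ok := w.(http.Flusher); ok {
				f.Flush() // push each chunk out immediately, as above
			}
		}
	}
}

func main() {
	http.HandleFunc("/stream", stream) // hypothetical route
	http.ListenAndServe(":8080", nil)
}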