diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 901b25e89..50e4681e1 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -55,6 +55,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error { total, processed := len(tokens), 0 slog.Info("Prompt processing progress", "processed", processed, "total", total) for total-processed > 1 { + if err := request.Ctx.Err(); err != nil { + return err + } + n := min(2<<10, total-processed-1) r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches) mlx.Sweep() @@ -92,6 +96,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error { now := time.Now() final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1} for i := range request.Options.MaxTokens { + if err := request.Ctx.Err(); err != nil { + return err + } + nextSample, nextLogprobs = step(sample) if i == 0 { @@ -111,9 +119,13 @@ func (r *Runner) TextGenerationPipeline(request Request) error { break } - request.Responses <- Response{ + select { + case <-request.Ctx.Done(): + return request.Ctx.Err() + case request.Responses <- Response{ Text: r.Decode(output, &b), Token: int(output), + }: } mlx.Unpin(sample, logprobs) @@ -126,8 +138,12 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } final.CompletionTokensDuration = time.Since(now) - request.Responses <- final - return nil + select { + case <-request.Ctx.Done(): + return request.Ctx.Err() + case request.Responses <- final: + return nil + } } func (r Runner) Decode(sample int32, b *bytes.Buffer) string { diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 353b98d8d..7d538b02a 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -12,7 +12,6 @@ import ( "golang.org/x/sync/errgroup" - "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model/base" @@ -25,8 +24,9 @@ type Request struct { Responses chan Response Pipeline func(Request) error + Ctx context.Context + sample.Sampler - caches []cache.Cache } type TextCompletionsRequest struct { @@ -157,7 +157,7 @@ func (r *Runner) Run(host, port string, mux http.Handler) error { return nil case request := <-r.Requests: if err := request.Pipeline(request); err != nil { - break + slog.Info("Request terminated", "error", err) } close(request.Responses) diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go index 09b71f3c8..89688cfbc 100644 --- a/x/mlxrunner/server.go +++ b/x/mlxrunner/server.go @@ -5,6 +5,7 @@ package mlxrunner import ( "bytes" "cmp" + "context" "encoding/json" "flag" "fmt" @@ -98,19 +99,36 @@ func Execute(args []string) error { request.Options.TopK, ) - runner.Requests <- request + var cancel context.CancelFunc + request.Ctx, cancel = context.WithCancel(r.Context()) + defer cancel() + + select { + case <-r.Context().Done(): + return + case runner.Requests <- request: + } w.Header().Set("Content-Type", "application/jsonl") w.WriteHeader(http.StatusOK) enc := json.NewEncoder(w) - for response := range request.Responses { - if err := enc.Encode(response); err != nil { - slog.Error("Failed to encode response", "error", err) + for { + select { + case <-r.Context().Done(): return - } + case response, ok := <-request.Responses: + if !ok { + return + } - if f, ok := w.(http.Flusher); ok { - f.Flush() + if err := enc.Encode(response); err != nil { + slog.Error("Failed to encode response", "error", err) + return + } + + if f, ok := w.(http.Flusher); ok { + f.Flush() + } } } })