mlxrunner: Propagate pipeline errors to client via api.StatusError

Errors that occur during pipeline processing are currently only
logged but not sent back to the client. Rather than using HTTP
status codes as we have historically done, this serializes errors
as messages to allow sending them at any time during the stream.
This commit is contained in:
Jesse Gross
2026-02-26 12:23:06 -08:00
parent 638faeac54
commit 18ab09b431
4 changed files with 43 additions and 38 deletions

View File

@@ -22,6 +22,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen"
@@ -192,6 +193,20 @@ type completionOpts struct {
NumPredict int `json:"num_predict,omitempty"` NumPredict int `json:"num_predict,omitempty"`
} }
// CompletionResponse is the wire-format message streamed, one JSON object
// per line, from the mlxrunner subprocess back to the client. It replaces
// the old Response type and mirrors llm.CompletionResponse field-for-field
// so the client can copy values across directly.
//
// NOTE(review): no json struct tags — decoding of the old snake_case keys
// relies on encoding/json's case-insensitive field matching, and encoding
// now emits the exported field names. This works only because both the
// runner and the client use this same struct; verify no third consumer
// depends on the old snake_case keys.
type CompletionResponse struct {
// Generated text fragment for this chunk of the stream.
Content string
// True on the final message of a successful generation.
Done bool
// Numeric stop-reason code; converted to llm.DoneReason by the client.
DoneReason int
// Number of prompt tokens evaluated.
PromptEvalCount int
// Wall-clock time spent evaluating the prompt.
PromptEvalDuration time.Duration
// Number of completion tokens generated.
EvalCount int
// Wall-clock time spent generating completion tokens.
EvalDuration time.Duration
// Peak memory reported by mlx.PeakMemory(), in bytes.
PeakMemory uint64
// Non-nil when the pipeline failed: carries the api.StatusError to
// return to the caller. Sent by the runner loop so errors can be
// reported at any point mid-stream, not only via the HTTP status line.
Error *api.StatusError
}
// Close terminates the subprocess. // Close terminates the subprocess.
func (c *Client) Close() error { func (c *Client) Close() error {
c.mu.Lock() c.mu.Lock()
@@ -251,29 +266,24 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() { for scanner.Scan() {
var raw struct { var raw CompletionResponse
Content string `json:"content,omitempty"`
Done bool `json:"done"`
DoneReason int `json:"done_reason,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
PeakMemory uint64 `json:"peak_memory,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil { if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes())) slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue continue
} }
if raw.Error != nil {
return *raw.Error
}
cresp := llm.CompletionResponse{ cresp := llm.CompletionResponse{
Content: raw.Content, Content: raw.Content,
Done: raw.Done, Done: raw.Done,
DoneReason: llm.DoneReason(raw.DoneReason), DoneReason: llm.DoneReason(raw.DoneReason),
PromptEvalCount: raw.PromptEvalCount, PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: time.Duration(raw.PromptEvalDuration), PromptEvalDuration: raw.PromptEvalDuration,
EvalCount: raw.EvalCount, EvalCount: raw.EvalCount,
EvalDuration: time.Duration(raw.EvalDuration), EvalDuration: raw.EvalDuration,
PeakMemory: raw.PeakMemory, PeakMemory: raw.PeakMemory,
} }

View File

@@ -98,7 +98,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
var b bytes.Buffer var b bytes.Buffer
now := time.Now() now := time.Now()
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1} final := CompletionResponse{Done: true, PromptEvalCount: total, EvalCount: request.Options.MaxTokens, DoneReason: 1}
for i := range request.Options.MaxTokens { for i := range request.Options.MaxTokens {
if err := request.Ctx.Err(); err != nil { if err := request.Ctx.Err(); err != nil {
return err return err
@@ -108,7 +108,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
if i == 0 { if i == 0 {
mlx.Eval(sample) mlx.Eval(sample)
final.PromptTokensDuration = time.Since(now) final.PromptEvalDuration = time.Since(now)
now = time.Now() now = time.Now()
} }
@@ -116,18 +116,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
session.outputs = append(session.outputs, output) session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) { if r.Tokenizer.IsEOS(output) {
final.Token = int(output)
final.DoneReason = 0 final.DoneReason = 0
final.CompletionTokens = i final.EvalCount = i
break break
} }
select { select {
case <-request.Ctx.Done(): case <-request.Ctx.Done():
return request.Ctx.Err() return request.Ctx.Err()
case request.Responses <- Response{ case request.Responses <- CompletionResponse{
Text: r.Decode(output, &b), Content: r.Decode(output, &b),
Token: int(output),
}: }:
} }
@@ -140,7 +138,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} }
} }
final.CompletionTokensDuration = time.Since(now) final.EvalDuration = time.Since(now)
final.PeakMemory = uint64(mlx.PeakMemory()) final.PeakMemory = uint64(mlx.PeakMemory())
select { select {
case <-request.Ctx.Done(): case <-request.Ctx.Done():

View File

@@ -4,14 +4,15 @@ package mlxrunner
import ( import (
"context" "context"
"errors"
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
"strings" "strings"
"time"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -21,7 +22,7 @@ import (
type Request struct { type Request struct {
TextCompletionsRequest TextCompletionsRequest
Responses chan Response Responses chan CompletionResponse
Pipeline func(Request) error Pipeline func(Request) error
Ctx context.Context Ctx context.Context
@@ -43,21 +44,6 @@ type TextCompletionsRequest struct {
} `json:"options"` } `json:"options"`
} }
// Response is the legacy per-token streaming message emitted by the
// runner, serialized as one JSON object per line with snake_case keys.
// Superseded by CompletionResponse, which mirrors llm.CompletionResponse.
type Response struct {
// Decoded text for the sampled token.
Text string `json:"content,omitempty"`
// Raw token id of the sampled token.
Token int `json:"token,omitempty"`
// Per-token log probabilities, when requested.
Logprobs []float32 `json:"logprobs,omitempty"`
// True on the final message of the stream.
Done bool `json:"done,omitempty"`
// Numeric stop-reason code (meaningful when Done is true).
DoneReason int `json:"done_reason,omitempty"`
// Number of prompt tokens evaluated.
PromptTokens int `json:"prompt_eval_count,omitempty"`
// Wall-clock time spent evaluating the prompt.
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
// Number of completion tokens generated.
CompletionTokens int `json:"eval_count,omitempty"`
// Wall-clock time spent generating completion tokens.
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
// Peak memory usage in bytes.
PeakMemory uint64 `json:"peak_memory,omitempty"`
// Total tokens processed (prompt + completion).
TotalTokens int `json:"total_tokens,omitempty"`
}
type Runner struct { type Runner struct {
Model base.Model Model base.Model
Tokenizer *tokenizer.Tokenizer Tokenizer *tokenizer.Tokenizer
@@ -159,6 +145,17 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
case request := <-r.Requests: case request := <-r.Requests:
if err := request.Pipeline(request); err != nil { if err := request.Pipeline(request); err != nil {
slog.Info("Request terminated", "error", err) slog.Info("Request terminated", "error", err)
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
statusErr = api.StatusError{
StatusCode: http.StatusInternalServerError,
ErrorMessage: err.Error(),
}
}
select {
case request.Responses <- CompletionResponse{Error: &statusErr}:
case <-request.Ctx.Done():
}
} }
close(request.Responses) close(request.Responses)

View File

@@ -79,7 +79,7 @@ func Execute(args []string) error {
}) })
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan Response)} request := Request{Responses: make(chan CompletionResponse)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil { if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
slog.Error("Failed to decode request", "error", err) slog.Error("Failed to decode request", "error", err)