mlxrunner: Propagate pipeline errors to client via api.StatusError

Errors that occur during pipeline processing are currently only
logged, not sent back to the client. Rather than using HTTP
status codes, as we have historically done, this change serializes
errors as messages so they can be sent at any point during the
stream.
Jesse Gross
2026-02-26 12:23:06 -08:00
parent 638faeac54
commit 18ab09b431
4 changed files with 43 additions and 38 deletions
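The pattern this commit adopts can be sketched independently of the repository's types: the runner keeps returning a 200 streaming response and emits a final NDJSON message carrying the error, which the client decodes and surfaces as an error value. The following is a minimal, self-contained illustration; the struct names, JSON tags, and field set are stand-ins, not the runner's actual wire format.

package main

import (
	"encoding/json"
	"fmt"
)

// StatusError stands in for api.StatusError; field names and JSON tags here
// are illustrative only.
type StatusError struct {
	StatusCode   int    `json:"status_code"`
	ErrorMessage string `json:"error"`
}

func (e StatusError) Error() string { return e.ErrorMessage }

// streamMessage stands in for the CompletionResponse messages on the stream.
type streamMessage struct {
	Content string       `json:"content,omitempty"`
	Done    bool         `json:"done,omitempty"`
	Error   *StatusError `json:"error,omitempty"`
}

// consume reads NDJSON lines and returns the embedded error, if any, so the
// stream can fail cleanly even after the HTTP status has already been sent.
func consume(lines []string) error {
	for _, line := range lines {
		var msg streamMessage
		if err := json.Unmarshal([]byte(line), &msg); err != nil {
			continue // skip unparseable lines, as the real client does
		}
		if msg.Error != nil {
			return *msg.Error
		}
		fmt.Print(msg.Content)
	}
	return nil
}

func main() {
	// A normal token followed by an error emitted mid-stream.
	err := consume([]string{
		`{"content":"Hello"}`,
		`{"error":{"status_code":500,"error":"pipeline failed"}}`,
	})
	fmt.Println("\nstream ended with:", err)
}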

View File

@@ -22,6 +22,7 @@ import (
"sync/atomic"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen"
@@ -192,6 +193,20 @@ type completionOpts struct {
NumPredict int `json:"num_predict,omitempty"`
}
type CompletionResponse struct {
Content string
Done bool
DoneReason int
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
PeakMemory uint64
Error *api.StatusError
}
// Close terminates the subprocess.
func (c *Client) Close() error {
c.mu.Lock()
@@ -251,29 +266,24 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
var raw struct {
Content string `json:"content,omitempty"`
Done bool `json:"done"`
DoneReason int `json:"done_reason,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
PeakMemory uint64 `json:"peak_memory,omitempty"`
}
var raw CompletionResponse
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue
}
if raw.Error != nil {
return *raw.Error
}
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
DoneReason: llm.DoneReason(raw.DoneReason),
PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
PromptEvalDuration: raw.PromptEvalDuration,
EvalCount: raw.EvalCount,
EvalDuration: time.Duration(raw.EvalDuration),
EvalDuration: raw.EvalDuration,
PeakMemory: raw.PeakMemory,
}

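Aside (not part of the commit): with the client change above, Completion now returns the propagated api.StatusError to its caller, who can inspect it with errors.As. A rough caller-side sketch, assuming the usual callback signature and imports (context, errors, fmt, log/slog plus the api and llm packages); handleCompletion and its request handling are hypothetical:

// handleCompletion is a hypothetical caller; only the errors.As handling at
// the end reflects behavior introduced by this commit.
func handleCompletion(ctx context.Context, c *Client, req llm.CompletionRequest) {
	err := c.Completion(ctx, req, func(r llm.CompletionResponse) {
		fmt.Print(r.Content)
	})
	if err == nil {
		return
	}
	var statusErr api.StatusError
	if errors.As(err, &statusErr) {
		// Pipeline failures arrive with http.StatusInternalServerError unless
		// the pipeline already produced an api.StatusError of its own.
		slog.Error("completion failed", "status", statusErr.StatusCode, "error", statusErr.ErrorMessage)
		return
	}
	slog.Error("completion failed", "error", err)
}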
View File

@@ -98,7 +98,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
var b bytes.Buffer
now := time.Now()
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
final := CompletionResponse{Done: true, PromptEvalCount: total, EvalCount: request.Options.MaxTokens, DoneReason: 1}
for i := range request.Options.MaxTokens {
if err := request.Ctx.Err(); err != nil {
return err
@@ -108,7 +108,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
if i == 0 {
mlx.Eval(sample)
final.PromptTokensDuration = time.Since(now)
final.PromptEvalDuration = time.Since(now)
now = time.Now()
}
@@ -116,18 +116,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) {
final.Token = int(output)
final.DoneReason = 0
final.CompletionTokens = i
final.EvalCount = i
break
}
select {
case <-request.Ctx.Done():
return request.Ctx.Err()
case request.Responses <- Response{
Text: r.Decode(output, &b),
Token: int(output),
case request.Responses <- CompletionResponse{
Content: r.Decode(output, &b),
}:
}
@@ -140,7 +138,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
}
final.CompletionTokensDuration = time.Since(now)
final.EvalDuration = time.Since(now)
final.PeakMemory = uint64(mlx.PeakMemory())
select {
case <-request.Ctx.Done():

View File

@@ -4,14 +4,15 @@ package mlxrunner
import (
"context"
"errors"
"log/slog"
"net"
"net/http"
"strings"
"time"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -21,7 +22,7 @@ import (
type Request struct {
TextCompletionsRequest
Responses chan Response
Responses chan CompletionResponse
Pipeline func(Request) error
Ctx context.Context
@@ -43,21 +44,6 @@ type TextCompletionsRequest struct {
} `json:"options"`
}
type Response struct {
Text string `json:"content,omitempty"`
Token int `json:"token,omitempty"`
Logprobs []float32 `json:"logprobs,omitempty"`
Done bool `json:"done,omitempty"`
DoneReason int `json:"done_reason,omitempty"`
PromptTokens int `json:"prompt_eval_count,omitempty"`
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
CompletionTokens int `json:"eval_count,omitempty"`
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
PeakMemory uint64 `json:"peak_memory,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
type Runner struct {
Model base.Model
Tokenizer *tokenizer.Tokenizer
@@ -159,6 +145,17 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
case request := <-r.Requests:
if err := request.Pipeline(request); err != nil {
slog.Info("Request terminated", "error", err)
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
statusErr = api.StatusError{
StatusCode: http.StatusInternalServerError,
ErrorMessage: err.Error(),
}
}
select {
case request.Responses <- CompletionResponse{Error: &statusErr}:
case <-request.Ctx.Done():
}
}
close(request.Responses)

View File

@@ -79,7 +79,7 @@ func Execute(args []string) error {
})
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan Response)}
request := Request{Responses: make(chan CompletionResponse)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
slog.Error("Failed to decode request", "error", err)