mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 03:12:11 -05:00
mlxrunner: Propagate pipeline errors to client via api.StatusError
Errors that occur during pipeline processing are currently only logged and never sent back to the client. Rather than relying on HTTP status codes as we have historically done, this change serializes errors as messages in the response stream, so they can be sent at any point during the stream.
This commit is contained in:
@@ -22,6 +22,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
@@ -192,6 +193,20 @@ type completionOpts struct {
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string
|
||||
Done bool
|
||||
DoneReason int
|
||||
|
||||
PromptEvalCount int
|
||||
PromptEvalDuration time.Duration
|
||||
EvalCount int
|
||||
EvalDuration time.Duration
|
||||
PeakMemory uint64
|
||||
|
||||
Error *api.StatusError
|
||||
}
|
||||
|
||||
// Close terminates the subprocess.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
@@ -251,29 +266,24 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
var raw struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason int `json:"done_reason,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration int `json:"eval_duration,omitempty"`
|
||||
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
||||
}
|
||||
var raw CompletionResponse
|
||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||
continue
|
||||
}
|
||||
|
||||
if raw.Error != nil {
|
||||
return *raw.Error
|
||||
}
|
||||
|
||||
cresp := llm.CompletionResponse{
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
DoneReason: llm.DoneReason(raw.DoneReason),
|
||||
PromptEvalCount: raw.PromptEvalCount,
|
||||
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
|
||||
PromptEvalDuration: raw.PromptEvalDuration,
|
||||
EvalCount: raw.EvalCount,
|
||||
EvalDuration: time.Duration(raw.EvalDuration),
|
||||
EvalDuration: raw.EvalDuration,
|
||||
PeakMemory: raw.PeakMemory,
|
||||
}
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
var b bytes.Buffer
|
||||
|
||||
now := time.Now()
|
||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: total, EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
||||
for i := range request.Options.MaxTokens {
|
||||
if err := request.Ctx.Err(); err != nil {
|
||||
return err
|
||||
@@ -108,7 +108,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
if i == 0 {
|
||||
mlx.Eval(sample)
|
||||
final.PromptTokensDuration = time.Since(now)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
@@ -116,18 +116,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
session.outputs = append(session.outputs, output)
|
||||
|
||||
if r.Tokenizer.IsEOS(output) {
|
||||
final.Token = int(output)
|
||||
final.DoneReason = 0
|
||||
final.CompletionTokens = i
|
||||
final.EvalCount = i
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-request.Ctx.Done():
|
||||
return request.Ctx.Err()
|
||||
case request.Responses <- Response{
|
||||
Text: r.Decode(output, &b),
|
||||
Token: int(output),
|
||||
case request.Responses <- CompletionResponse{
|
||||
Content: r.Decode(output, &b),
|
||||
}:
|
||||
}
|
||||
|
||||
@@ -140,7 +138,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
final.CompletionTokensDuration = time.Since(now)
|
||||
final.EvalDuration = time.Since(now)
|
||||
final.PeakMemory = uint64(mlx.PeakMemory())
|
||||
select {
|
||||
case <-request.Ctx.Done():
|
||||
|
||||
@@ -4,14 +4,15 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
@@ -21,7 +22,7 @@ import (
|
||||
|
||||
type Request struct {
|
||||
TextCompletionsRequest
|
||||
Responses chan Response
|
||||
Responses chan CompletionResponse
|
||||
Pipeline func(Request) error
|
||||
|
||||
Ctx context.Context
|
||||
@@ -43,21 +44,6 @@ type TextCompletionsRequest struct {
|
||||
} `json:"options"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Text string `json:"content,omitempty"`
|
||||
Token int `json:"token,omitempty"`
|
||||
Logprobs []float32 `json:"logprobs,omitempty"`
|
||||
Done bool `json:"done,omitempty"`
|
||||
DoneReason int `json:"done_reason,omitempty"`
|
||||
|
||||
PromptTokens int `json:"prompt_eval_count,omitempty"`
|
||||
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||
CompletionTokens int `json:"eval_count,omitempty"`
|
||||
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
|
||||
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
||||
TotalTokens int `json:"total_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
@@ -159,6 +145,17 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
|
||||
case request := <-r.Requests:
|
||||
if err := request.Pipeline(request); err != nil {
|
||||
slog.Info("Request terminated", "error", err)
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
statusErr = api.StatusError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
ErrorMessage: err.Error(),
|
||||
}
|
||||
}
|
||||
select {
|
||||
case request.Responses <- CompletionResponse{Error: &statusErr}:
|
||||
case <-request.Ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
close(request.Responses)
|
||||
|
||||
@@ -79,7 +79,7 @@ func Execute(args []string) error {
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
request := Request{Responses: make(chan Response)}
|
||||
request := Request{Responses: make(chan CompletionResponse)}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
||||
slog.Error("Failed to decode request", "error", err)
|
||||
|
||||
Reference in New Issue
Block a user