mirror of
https://github.com/ollama/ollama.git
synced 2026-04-30 17:58:49 -05:00
mlxrunner: Propagate pipeline errors to client via api.StatusError
Errors that occur during pipeline processing are currently logged but never sent back to the client. Rather than relying on HTTP status codes, as we have historically done, this change serializes errors as messages in the response stream so that they can be sent at any time during the stream.
This commit is contained in:
@@ -22,6 +22,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/x/imagegen"
|
"github.com/ollama/ollama/x/imagegen"
|
||||||
@@ -192,6 +193,20 @@ type completionOpts struct {
|
|||||||
NumPredict int `json:"num_predict,omitempty"`
|
NumPredict int `json:"num_predict,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CompletionResponse struct {
|
||||||
|
Content string
|
||||||
|
Done bool
|
||||||
|
DoneReason int
|
||||||
|
|
||||||
|
PromptEvalCount int
|
||||||
|
PromptEvalDuration time.Duration
|
||||||
|
EvalCount int
|
||||||
|
EvalDuration time.Duration
|
||||||
|
PeakMemory uint64
|
||||||
|
|
||||||
|
Error *api.StatusError
|
||||||
|
}
|
||||||
|
|
||||||
// Close terminates the subprocess.
|
// Close terminates the subprocess.
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
@@ -251,29 +266,24 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
var raw struct {
|
var raw CompletionResponse
|
||||||
Content string `json:"content,omitempty"`
|
|
||||||
Done bool `json:"done"`
|
|
||||||
DoneReason int `json:"done_reason,omitempty"`
|
|
||||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
|
||||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
|
||||||
EvalCount int `json:"eval_count,omitempty"`
|
|
||||||
EvalDuration int `json:"eval_duration,omitempty"`
|
|
||||||
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if raw.Error != nil {
|
||||||
|
return *raw.Error
|
||||||
|
}
|
||||||
|
|
||||||
cresp := llm.CompletionResponse{
|
cresp := llm.CompletionResponse{
|
||||||
Content: raw.Content,
|
Content: raw.Content,
|
||||||
Done: raw.Done,
|
Done: raw.Done,
|
||||||
DoneReason: llm.DoneReason(raw.DoneReason),
|
DoneReason: llm.DoneReason(raw.DoneReason),
|
||||||
PromptEvalCount: raw.PromptEvalCount,
|
PromptEvalCount: raw.PromptEvalCount,
|
||||||
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
|
PromptEvalDuration: raw.PromptEvalDuration,
|
||||||
EvalCount: raw.EvalCount,
|
EvalCount: raw.EvalCount,
|
||||||
EvalDuration: time.Duration(raw.EvalDuration),
|
EvalDuration: raw.EvalDuration,
|
||||||
PeakMemory: raw.PeakMemory,
|
PeakMemory: raw.PeakMemory,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
final := CompletionResponse{Done: true, PromptEvalCount: total, EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
||||||
for i := range request.Options.MaxTokens {
|
for i := range request.Options.MaxTokens {
|
||||||
if err := request.Ctx.Err(); err != nil {
|
if err := request.Ctx.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -108,7 +108,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
|
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
mlx.Eval(sample)
|
mlx.Eval(sample)
|
||||||
final.PromptTokensDuration = time.Since(now)
|
final.PromptEvalDuration = time.Since(now)
|
||||||
now = time.Now()
|
now = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,18 +116,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
session.outputs = append(session.outputs, output)
|
session.outputs = append(session.outputs, output)
|
||||||
|
|
||||||
if r.Tokenizer.IsEOS(output) {
|
if r.Tokenizer.IsEOS(output) {
|
||||||
final.Token = int(output)
|
|
||||||
final.DoneReason = 0
|
final.DoneReason = 0
|
||||||
final.CompletionTokens = i
|
final.EvalCount = i
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-request.Ctx.Done():
|
case <-request.Ctx.Done():
|
||||||
return request.Ctx.Err()
|
return request.Ctx.Err()
|
||||||
case request.Responses <- Response{
|
case request.Responses <- CompletionResponse{
|
||||||
Text: r.Decode(output, &b),
|
Content: r.Decode(output, &b),
|
||||||
Token: int(output),
|
|
||||||
}:
|
}:
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,7 +138,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
final.CompletionTokensDuration = time.Since(now)
|
final.EvalDuration = time.Since(now)
|
||||||
final.PeakMemory = uint64(mlx.PeakMemory())
|
final.PeakMemory = uint64(mlx.PeakMemory())
|
||||||
select {
|
select {
|
||||||
case <-request.Ctx.Done():
|
case <-request.Ctx.Done():
|
||||||
|
|||||||
@@ -4,14 +4,15 @@ package mlxrunner
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
@@ -21,7 +22,7 @@ import (
|
|||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
TextCompletionsRequest
|
TextCompletionsRequest
|
||||||
Responses chan Response
|
Responses chan CompletionResponse
|
||||||
Pipeline func(Request) error
|
Pipeline func(Request) error
|
||||||
|
|
||||||
Ctx context.Context
|
Ctx context.Context
|
||||||
@@ -43,21 +44,6 @@ type TextCompletionsRequest struct {
|
|||||||
} `json:"options"`
|
} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Response struct {
|
|
||||||
Text string `json:"content,omitempty"`
|
|
||||||
Token int `json:"token,omitempty"`
|
|
||||||
Logprobs []float32 `json:"logprobs,omitempty"`
|
|
||||||
Done bool `json:"done,omitempty"`
|
|
||||||
DoneReason int `json:"done_reason,omitempty"`
|
|
||||||
|
|
||||||
PromptTokens int `json:"prompt_eval_count,omitempty"`
|
|
||||||
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
|
||||||
CompletionTokens int `json:"eval_count,omitempty"`
|
|
||||||
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
|
|
||||||
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
|
||||||
TotalTokens int `json:"total_tokens,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Runner struct {
|
type Runner struct {
|
||||||
Model base.Model
|
Model base.Model
|
||||||
Tokenizer *tokenizer.Tokenizer
|
Tokenizer *tokenizer.Tokenizer
|
||||||
@@ -159,6 +145,17 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
|
|||||||
case request := <-r.Requests:
|
case request := <-r.Requests:
|
||||||
if err := request.Pipeline(request); err != nil {
|
if err := request.Pipeline(request); err != nil {
|
||||||
slog.Info("Request terminated", "error", err)
|
slog.Info("Request terminated", "error", err)
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if !errors.As(err, &statusErr) {
|
||||||
|
statusErr = api.StatusError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
ErrorMessage: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case request.Responses <- CompletionResponse{Error: &statusErr}:
|
||||||
|
case <-request.Ctx.Done():
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
close(request.Responses)
|
close(request.Responses)
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ func Execute(args []string) error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
request := Request{Responses: make(chan Response)}
|
request := Request{Responses: make(chan CompletionResponse)}
|
||||||
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
||||||
slog.Error("Failed to decode request", "error", err)
|
slog.Error("Failed to decode request", "error", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user