show peak memory usage (#14485)

This commit is contained in:
Patrick Devine
2026-02-26 18:38:27 -08:00
committed by GitHub
parent cc90a035a0
commit 79917cf80b
7 changed files with 26 additions and 0 deletions

View File

@@ -15,6 +15,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/internal/orderedmap" "github.com/ollama/ollama/internal/orderedmap"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
@@ -569,6 +570,7 @@ type DebugInfo struct {
type Metrics struct { type Metrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"` TotalDuration time.Duration `json:"total_duration,omitempty"`
PeakMemory uint64 `json:"peak_memory,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"` LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"` PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
@@ -934,6 +936,10 @@ func (m *Metrics) Summary() {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration) fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
} }
if m.PeakMemory > 0 {
fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory))
}
if m.LoadDuration > 0 { if m.LoadDuration > 0 {
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration) fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
} }
@@ -957,6 +963,14 @@ func (m *Metrics) Summary() {
} }
} }
// formatPeakMemory renders a byte count for the CLI summary line.
// Values of a GiB or more are printed with three decimal places
// (e.g. "1.234 GiB"); smaller values fall back to format.HumanBytes2.
func formatPeakMemory(b uint64) string {
	if b >= format.GibiByte {
		return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte))
	}
	return format.HumanBytes2(b)
}
func (opts *Options) FromMap(m map[string]any) error { func (opts *Options) FromMap(m map[string]any) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct

View File

@@ -1518,6 +1518,7 @@ type CompletionResponse struct {
PromptEvalDuration time.Duration `json:"prompt_eval_duration"` PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"` EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"` EvalDuration time.Duration `json:"eval_duration"`
PeakMemory uint64 `json:"peak_memory,omitempty"`
// Logprobs contains log probability information if requested // Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"` Logprobs []Logprob `json:"logprobs,omitempty"`

View File

@@ -557,6 +557,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
PromptEvalDuration: cr.PromptEvalDuration, PromptEvalDuration: cr.PromptEvalDuration,
EvalCount: cr.EvalCount, EvalCount: cr.EvalCount,
EvalDuration: cr.EvalDuration, EvalDuration: cr.EvalDuration,
PeakMemory: cr.PeakMemory,
}, },
Logprobs: toAPILogprobs(cr.Logprobs), Logprobs: toAPILogprobs(cr.Logprobs),
} }
@@ -2309,6 +2310,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
PromptEvalDuration: r.PromptEvalDuration, PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount, EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration, EvalDuration: r.EvalDuration,
PeakMemory: r.PeakMemory,
}, },
Logprobs: toAPILogprobs(r.Logprobs), Logprobs: toAPILogprobs(r.Logprobs),
} }

View File

@@ -268,6 +268,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"` PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"` EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"` EvalDuration int `json:"eval_duration,omitempty"`
PeakMemory uint64 `json:"peak_memory,omitempty"`
} }
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil { if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes())) slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
@@ -282,6 +283,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
PromptEvalDuration: time.Duration(raw.PromptEvalDuration), PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
EvalCount: raw.EvalCount, EvalCount: raw.EvalCount,
EvalDuration: time.Duration(raw.EvalDuration), EvalDuration: time.Duration(raw.EvalDuration),
PeakMemory: raw.PeakMemory,
} }
fn(cresp) fn(cresp)

View File

@@ -64,6 +64,10 @@ func PeakMemory() int {
return int(peak) return int(peak)
} }
// ResetPeakMemory resets MLX's peak-memory counter by calling
// mlx_reset_peak_memory via cgo, so a subsequent PeakMemory reading
// reflects only allocations made after this call.
func ResetPeakMemory() {
	C.mlx_reset_peak_memory()
}
type Memory struct{} type Memory struct{}
func (Memory) LogValue() slog.Value { func (Memory) LogValue() slog.Value {

View File

@@ -44,6 +44,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} else { } else {
mlx.DisableCompile() mlx.DisableCompile()
} }
mlx.ResetPeakMemory()
inputs := r.Tokenizer.Encode(request.Prompt, true) inputs := r.Tokenizer.Encode(request.Prompt, true)
session := r.cache.begin(r.Model, inputs) session := r.cache.begin(r.Model, inputs)
@@ -138,6 +139,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} }
final.CompletionTokensDuration = time.Since(now) final.CompletionTokensDuration = time.Since(now)
final.PeakMemory = uint64(mlx.PeakMemory())
select { select {
case <-request.Ctx.Done(): case <-request.Ctx.Done():
return request.Ctx.Err() return request.Ctx.Err()

View File

@@ -54,6 +54,7 @@ type Response struct {
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"` PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
CompletionTokens int `json:"eval_count,omitempty"` CompletionTokens int `json:"eval_count,omitempty"`
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"` CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
PeakMemory uint64 `json:"peak_memory,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"` TotalTokens int `json:"total_tokens,omitempty"`
} }