From 79917cf80bf74538a4ae694e6b61adb908b0f8df Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Thu, 26 Feb 2026 18:38:27 -0800 Subject: [PATCH] show peak memory usage (#14485) --- api/types.go | 14 ++++++++++++++ llm/server.go | 1 + server/routes.go | 2 ++ x/mlxrunner/client.go | 2 ++ x/mlxrunner/mlx/memory.go | 4 ++++ x/mlxrunner/pipeline.go | 2 ++ x/mlxrunner/runner.go | 1 + 7 files changed, 26 insertions(+) diff --git a/api/types.go b/api/types.go index 82caf17dc..f891a043b 100644 --- a/api/types.go +++ b/api/types.go @@ -15,6 +15,7 @@ import ( "github.com/google/uuid" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/format" "github.com/ollama/ollama/internal/orderedmap" "github.com/ollama/ollama/types/model" ) @@ -569,6 +570,7 @@ type DebugInfo struct { type Metrics struct { TotalDuration time.Duration `json:"total_duration,omitempty"` + PeakMemory uint64 `json:"peak_memory,omitempty"` LoadDuration time.Duration `json:"load_duration,omitempty"` PromptEvalCount int `json:"prompt_eval_count,omitempty"` PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` @@ -934,6 +936,10 @@ func (m *Metrics) Summary() { fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration) } + if m.PeakMemory > 0 { + fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory)) + } + if m.LoadDuration > 0 { fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration) } @@ -957,6 +963,14 @@ func (m *Metrics) Summary() { } } +func formatPeakMemory(b uint64) string { + if b >= format.GibiByte { + return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte)) + } + + return format.HumanBytes2(b) +} + func (opts *Options) FromMap(m map[string]any) error { valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct diff --git a/llm/server.go b/llm/server.go index de8ad0f75..b6b4f4442 100644 --- a/llm/server.go +++ b/llm/server.go @@ -1518,6 +1518,7 @@ type CompletionResponse struct { PromptEvalDuration time.Duration `json:"prompt_eval_duration"` EvalCount int `json:"eval_count"` EvalDuration time.Duration `json:"eval_duration"` + PeakMemory uint64 `json:"peak_memory,omitempty"` // Logprobs contains log probability information if requested Logprobs []Logprob `json:"logprobs,omitempty"` diff --git a/server/routes.go b/server/routes.go index cbe771d9f..1813cdc98 100644 --- a/server/routes.go +++ b/server/routes.go @@ -557,6 +557,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { PromptEvalDuration: cr.PromptEvalDuration, EvalCount: cr.EvalCount, EvalDuration: cr.EvalDuration, + PeakMemory: cr.PeakMemory, }, Logprobs: toAPILogprobs(cr.Logprobs), } @@ -2309,6 +2310,7 @@ func (s *Server) ChatHandler(c *gin.Context) { PromptEvalDuration: r.PromptEvalDuration, EvalCount: r.EvalCount, EvalDuration: r.EvalDuration, + PeakMemory: r.PeakMemory, }, Logprobs: toAPILogprobs(r.Logprobs), } diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index 2152c382f..d2cd239d8 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -268,6 +268,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f PromptEvalDuration int `json:"prompt_eval_duration,omitempty"` EvalCount int `json:"eval_count,omitempty"` EvalDuration int `json:"eval_duration,omitempty"` + PeakMemory uint64 `json:"peak_memory,omitempty"` } if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil { slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes())) @@ -282,6 +283,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f PromptEvalDuration: time.Duration(raw.PromptEvalDuration), EvalCount: raw.EvalCount, EvalDuration: time.Duration(raw.EvalDuration), + PeakMemory: raw.PeakMemory, } fn(cresp) diff --git a/x/mlxrunner/mlx/memory.go b/x/mlxrunner/mlx/memory.go index e9a174b1e..cf36c304c 100644 --- a/x/mlxrunner/mlx/memory.go +++ b/x/mlxrunner/mlx/memory.go @@ -64,6 +64,10 @@ func PeakMemory() int { return int(peak) } +func ResetPeakMemory() { + C.mlx_reset_peak_memory() +} + type Memory struct{} func (Memory) LogValue() slog.Value { diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 50e4681e1..945f94755 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -44,6 +44,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } else { mlx.DisableCompile() } + mlx.ResetPeakMemory() inputs := r.Tokenizer.Encode(request.Prompt, true) session := r.cache.begin(r.Model, inputs) @@ -138,6 +139,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } final.CompletionTokensDuration = time.Since(now) + final.PeakMemory = uint64(mlx.PeakMemory()) select { case <-request.Ctx.Done(): return request.Ctx.Err() diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 7d538b02a..f05ee7eef 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -54,6 +54,7 @@ type Response struct { PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"` CompletionTokens int `json:"eval_count,omitempty"` CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"` + PeakMemory uint64 `json:"peak_memory,omitempty"` TotalTokens int `json:"total_tokens,omitempty"` }