diff --git a/api/types.go b/api/types.go index 9c2031456..561b75ef3 100644 --- a/api/types.go +++ b/api/types.go @@ -15,7 +15,6 @@ import ( "github.com/google/uuid" "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/format" "github.com/ollama/ollama/internal/orderedmap" "github.com/ollama/ollama/types/model" ) @@ -570,7 +569,6 @@ type DebugInfo struct { type Metrics struct { TotalDuration time.Duration `json:"total_duration,omitempty"` - PeakMemory uint64 `json:"peak_memory,omitempty"` LoadDuration time.Duration `json:"load_duration,omitempty"` PromptEvalCount int `json:"prompt_eval_count,omitempty"` PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` @@ -936,10 +934,6 @@ func (m *Metrics) Summary() { fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration) } - if m.PeakMemory > 0 { - fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory)) - } - if m.LoadDuration > 0 { fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration) } @@ -963,14 +957,6 @@ func (m *Metrics) Summary() { } } -func formatPeakMemory(b uint64) string { - if b >= format.GibiByte { - return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte)) - } - - return format.HumanBytes2(b) -} - func (opts *Options) FromMap(m map[string]any) error { valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct diff --git a/llm/server.go b/llm/server.go index cd9014a29..291dd47fe 100644 --- a/llm/server.go +++ b/llm/server.go @@ -1518,7 +1518,6 @@ type CompletionResponse struct { PromptEvalDuration time.Duration `json:"prompt_eval_duration"` EvalCount int `json:"eval_count"` EvalDuration time.Duration `json:"eval_duration"` - PeakMemory uint64 `json:"peak_memory,omitempty"` // Logprobs contains log probability information if requested Logprobs []Logprob `json:"logprobs,omitempty"` diff --git a/server/routes.go b/server/routes.go index a5f6d3f88..a27ce3a96 100644 --- a/server/routes.go +++ b/server/routes.go @@ -558,7 +558,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { PromptEvalDuration: cr.PromptEvalDuration, EvalCount: cr.EvalCount, EvalDuration: cr.EvalDuration, - PeakMemory: cr.PeakMemory, }, Logprobs: toAPILogprobs(cr.Logprobs), } @@ -2317,7 +2316,6 @@ func (s *Server) ChatHandler(c *gin.Context) { PromptEvalDuration: r.PromptEvalDuration, EvalCount: r.EvalCount, EvalDuration: r.EvalDuration, - PeakMemory: r.PeakMemory, }, Logprobs: toAPILogprobs(r.Logprobs), } diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index b6ca5339f..c4f8c77ce 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -202,7 +202,6 @@ type CompletionResponse struct { PromptEvalDuration time.Duration EvalCount int EvalDuration time.Duration - PeakMemory uint64 Error *api.StatusError } @@ -284,7 +283,6 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f PromptEvalDuration: raw.PromptEvalDuration, EvalCount: raw.EvalCount, EvalDuration: raw.EvalDuration, - PeakMemory: raw.PeakMemory, } fn(cresp) diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 73b485358..9061029ae 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -21,6 +21,17 @@ func (r *Runner) TextGenerationPipeline(request Request) error { return errors.New("model not loaded") } + enableCompile := true + if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok { + enableCompile = modelCompile.EnableCompile() + } + if enableCompile { + mlx.EnableCompile() + } else { + mlx.DisableCompile() + } + mlx.ResetPeakMemory() + var ( sample, logprobs *mlx.Array nextSample, nextLogprobs *mlx.Array @@ -36,19 +47,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error { mlx.LogArrays() r.cache.log() } + slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory())) }() - enableCompile := true - if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok { - enableCompile = modelCompile.EnableCompile() - } - if enableCompile { - mlx.EnableCompile() - } else { - mlx.DisableCompile() - } - mlx.ResetPeakMemory() - inputs := r.Tokenizer.Encode(request.Prompt, true) if len(inputs) == 0 { return errors.New("empty prompt") @@ -156,7 +157,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } final.EvalDuration = time.Since(now) - final.PeakMemory = uint64(mlx.PeakMemory()) select { case <-request.Ctx.Done(): return request.Ctx.Err()