diff --git a/server/images.go b/server/images.go
index 05dfe1468..fad4e102a 100644
--- a/server/images.go
+++ b/server/images.go
@@ -71,6 +71,10 @@ type Model struct {
 	Template *template.Template
 }
 
+func (m *Model) IsMLX() bool {
+	return m.Config.ModelFormat == "safetensors"
+}
+
 // Capabilities returns the capabilities that the model supports
 func (m *Model) Capabilities() []model.Capability {
 	capabilities := []model.Capability{}
diff --git a/server/prompt.go b/server/prompt.go
index c9a7702b8..0737fa215 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -30,42 +30,44 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	lastMsgIdx := len(msgs) - 1
 	currMsgIdx := 0
 
-	// Start with all messages and remove from the front until it fits in context
-	for i := 0; i <= lastMsgIdx; i++ {
-		// Collect system messages from the portion we're about to skip
-		system = make([]api.Message, 0)
-		for j := range i {
-			if msgs[j].Role == "system" {
-				system = append(system, msgs[j])
+	if truncate {
+		// Start with all messages and remove from the front until it fits in context
+		for i := 0; i <= lastMsgIdx; i++ {
+			// Collect system messages from the portion we're about to skip
+			system = make([]api.Message, 0)
+			for j := range i {
+				if msgs[j].Role == "system" {
+					system = append(system, msgs[j])
+				}
 			}
-		}
 
-		p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
-		if err != nil {
-			return "", nil, err
-		}
-
-		s, err := tokenize(ctx, p)
-		if err != nil {
-			return "", nil, err
-		}
-
-		ctxLen := len(s)
-		if m.ProjectorPaths != nil {
-			for _, msg := range msgs[i:] {
-				ctxLen += imageNumTokens * len(msg.Images)
+			p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
+			if err != nil {
+				return "", nil, err
 			}
-		}
 
-		if !truncate || ctxLen <= opts.NumCtx {
-			currMsgIdx = i
-			break
-		}
+			s, err := tokenize(ctx, p)
+			if err != nil {
+				return "", nil, err
+			}
 
-		// Must always include at least the last message
-		if i == lastMsgIdx {
-			currMsgIdx = lastMsgIdx
-			break
+			ctxLen := len(s)
+			if m.ProjectorPaths != nil {
+				for _, msg := range msgs[i:] {
+					ctxLen += imageNumTokens * len(msg.Images)
+				}
+			}
+
+			if ctxLen <= opts.NumCtx {
+				currMsgIdx = i
+				break
+			}
+
+			// Must always include at least the last message
+			if i == lastMsgIdx {
+				currMsgIdx = lastMsgIdx
+				break
+			}
 		}
 	}
 
diff --git a/server/routes.go b/server/routes.go
index 123993af2..a5f6d3f88 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -484,7 +484,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	// the real chat handler, but doing this as a stopgap to get renderer
 	// support for generate
 	if values.Messages != nil && values.Suffix == "" && req.Template == "" {
-		prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
+		genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
+		prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
@@ -2217,6 +2218,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	truncate := req.Truncate == nil || *req.Truncate
+	if m.IsMLX() {
+		truncate = false
+	}
 	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 	if err != nil {
 		slog.Error("chat prompt error", "error", err)
diff --git a/server/sched.go b/server/sched.go
index af768cf56..4a64223e5 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 			}
 
 			// Check for experimental safetensors LLM models
-			if pending.model.Config.ModelFormat == "safetensors" {
+			if pending.model.IsMLX() {
 				if slices.Contains(pending.model.Config.Capabilities, "completion") {
 					// LLM model with safetensors format - use MLX runner
 					if s.loadMLX(pending) {
@@ -764,7 +764,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
 	defer cancel()
 	if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
 		!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
-		!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
+		(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
 		runner.llama.Ping(ctx) != nil {
 		return true
 	}
diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go
index 58f8c87bc..b6ca5339f 100644
--- a/x/mlxrunner/client.go
+++ b/x/mlxrunner/client.go
@@ -8,7 +8,6 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
-	"math"
 	"math/rand"
 	"net"
 	"net/http"
@@ -30,15 +29,16 @@ import (
 
 // Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
 type Client struct {
-	port        int
-	modelName   string
-	memory      atomic.Uint64
-	done        chan error
-	client      *http.Client
-	lastErr     string
-	lastErrLock sync.Mutex
-	mu          sync.Mutex
-	cmd         *exec.Cmd
+	port          int
+	modelName     string
+	contextLength atomic.Int64
+	memory        atomic.Uint64
+	done          chan error
+	client        *http.Client
+	lastErr       string
+	lastErrLock   sync.Mutex
+	mu            sync.Mutex
+	cmd           *exec.Cmd
 }
 
 // NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
@@ -297,7 +297,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
 }
 
 func (c *Client) ContextLength() int {
-	return math.MaxInt
+	return int(c.contextLength.Load())
 }
 
 // Detokenize implements llm.LlamaServer.
@@ -351,9 +351,10 @@ func (c *Client) Pid() int {
 }
 
 type statusResponse struct {
-	Status   int
-	Progress int
-	Memory   uint64
+	Status        int
+	Progress      int
+	ContextLength int
+	Memory        uint64
 }
 
 // Ping implements llm.LlamaServer.
@@ -376,7 +377,10 @@
 	if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
 		return err
 	}
+
+	c.contextLength.Store(int64(status.ContextLength))
 	c.memory.Store(status.Memory)
+
 	return nil
 }
diff --git a/x/mlxrunner/model/base/base.go b/x/mlxrunner/model/base/base.go
index e35eddd2d..4cdf6df33 100644
--- a/x/mlxrunner/model/base/base.go
+++ b/x/mlxrunner/model/base/base.go
@@ -20,6 +20,7 @@ type Model interface {
 	Unembed(x *mlx.Array) *mlx.Array
 	NumLayers() int
 	Tokenizer() *tokenizer.Tokenizer
+	MaxContextLength() int
 
 	// LoadWeights receives all tensors loaded from the manifest and assigns
 	// them to model fields. Model-specific logic (MLA absorption, expert
diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go
index 24aaccdc0..405225c45 100644
--- a/x/mlxrunner/pipeline.go
+++ b/x/mlxrunner/pipeline.go
@@ -6,9 +6,12 @@ import (
 	"bytes"
 	"context"
 	"errors"
+	"fmt"
 	"log/slog"
+	"net/http"
 	"time"
 
+	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 )
@@ -51,9 +54,23 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 		return errors.New("empty prompt")
 	}
 
+	if len(inputs) >= r.contextLength {
+		return api.StatusError{
+			StatusCode:   http.StatusBadRequest,
+			ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
+		}
+	}
+
+	// Cap generation to stay within the model's context length
+	maxGenerate := r.contextLength - len(inputs)
+	if request.Options.MaxTokens <= 0 {
+		request.Options.MaxTokens = maxGenerate
+	} else {
+		request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
+	}
+
 	session := r.cache.begin(r.Model, inputs)
 	defer session.close()
 
-	caches := session.caches
 	tokens := session.remaining
diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go
index fe1ec0f0b..5fe06bcd5 100644
--- a/x/mlxrunner/runner.go
+++ b/x/mlxrunner/runner.go
@@ -45,10 +45,11 @@ type TextCompletionsRequest struct {
 }
 
 type Runner struct {
-	Model     base.Model
-	Tokenizer *tokenizer.Tokenizer
-	Requests  chan Request
-	cache     kvCache
+	Model         base.Model
+	Tokenizer     *tokenizer.Tokenizer
+	Requests      chan Request
+	cache         kvCache
+	contextLength int
 }
 
 func (r *Runner) Load(modelName string) error {
@@ -77,6 +78,7 @@ func (r *Runner) Load(modelName string) error {
 
 	r.Model = m
 	r.Tokenizer = m.Tokenizer()
+	r.contextLength = m.MaxContextLength()
 
 	return nil
 }
diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go
index c44f795b6..436b47e59 100644
--- a/x/mlxrunner/server.go
+++ b/x/mlxrunner/server.go
@@ -51,9 +51,10 @@ func Execute(args []string) error {
 	mux := http.NewServeMux()
 	mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
 		if err := json.NewEncoder(w).Encode(statusResponse{
-			Status:   0,
-			Progress: 100,
-			Memory:   uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
+			Status:        0,
+			Progress:      100,
+			ContextLength: runner.contextLength,
+			Memory:        uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
 		}); err != nil {
 			slog.Error("Failed to encode response", "error", err)
 			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@@ -88,9 +89,6 @@
 	}
 
 	request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
-	if request.Options.MaxTokens < 1 {
-		request.Options.MaxTokens = 16 << 10
-	}
 
 	request.Pipeline = runner.TextGenerationPipeline
 	request.Sampler = sample.New(
diff --git a/x/models/gemma3/gemma3.go b/x/models/gemma3/gemma3.go
index 3d2c398b7..edf66657c 100644
--- a/x/models/gemma3/gemma3.go
+++ b/x/models/gemma3/gemma3.go
@@ -430,6 +430,10 @@ func (m *Model) NumLayers() int {
 	return len(m.Layers)
 }
 
+func (m *Model) MaxContextLength() int {
+	return int(m.MaxPositionEmbeddings)
+}
+
 func (m *Model) Tokenizer() *tokenizer.Tokenizer {
 	return m.tok
 }
diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go
index b79e245da..fb9c4af6f 100644
--- a/x/models/glm4_moe_lite/glm4_moe_lite.go
+++ b/x/models/glm4_moe_lite/glm4_moe_lite.go
@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
 func (m *Model) NumLayers() int { return len(m.Layers) }
 
 // MaxContextLength returns the maximum context length
-func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
+func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
 
 // VocabSize returns the vocabulary size
 func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
diff --git a/x/models/llama/llama.go b/x/models/llama/llama.go
index bef98fbb4..fc7f34488 100644
--- a/x/models/llama/llama.go
+++ b/x/models/llama/llama.go
@@ -262,6 +262,10 @@ func (m *Model) NumLayers() int {
 	return len(m.Layers)
 }
 
+func (m *Model) MaxContextLength() int {
+	return int(m.MaxPositionEmbeddings)
+}
+
 func (m *Model) Tokenizer() *tokenizer.Tokenizer {
 	return m.tok
 }
diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go
index 392f90755..85d427f58 100644
--- a/x/models/qwen3/qwen3.go
+++ b/x/models/qwen3/qwen3.go
@@ -279,6 +279,10 @@ func (m *Model) NumLayers() int {
 	return len(m.Layers)
 }
 
+func (m *Model) MaxContextLength() int {
+	return int(m.MaxPositionEmbeddings)
+}
+
 func (m *Model) Tokenizer() *tokenizer.Tokenizer {
 	return m.tok
 }
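
The change in x/mlxrunner/pipeline.go is the enforcement point: a prompt that already exceeds the model's window is rejected with a 400, and otherwise generation is capped so prompt plus output never exceeds the window. Below is a hedged, self-contained test sketch of that capping rule, with the logic copied out into a standalone helper purely for illustration; the helper name, package name, and the concrete numbers are assumptions, not part of this change set.

package mlxrunner_test // illustrative package name

import "testing"

// capMaxTokens mirrors the rule added to TextGenerationPipeline: a
// non-positive MaxTokens means "generate up to the remaining context",
// and an explicit MaxTokens is clamped so prompt + generation stays
// within the model's context length.
func capMaxTokens(maxTokens, promptLen, contextLength int) int {
	maxGenerate := contextLength - promptLen
	if maxTokens <= 0 {
		return maxGenerate
	}
	return min(maxTokens, maxGenerate)
}

func TestCapMaxTokens(t *testing.T) {
	cases := []struct {
		name                               string
		maxTokens, promptLen, ctxLen, want int
	}{
		{"unset defaults to remaining context", 0, 100, 4096, 3996},
		{"request within budget is unchanged", 256, 100, 4096, 256},
		{"request over budget is clamped", 8192, 100, 4096, 3996},
	}
	for _, tc := range cases {
		if got := capMaxTokens(tc.maxTokens, tc.promptLen, tc.ctxLen); got != tc.want {
			t.Errorf("%s: got %d, want %d", tc.name, got, tc.want)
		}
	}
}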
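
Because MaxContextLength() int is now part of the base.Model interface, every model under x/models must provide it; the implementations above all derive it from max_position_embeddings. A minimal sketch of what a future model would need, assuming its config struct already unmarshals that field from the checkpoint's config.json the way the models touched here do (struct and field names are illustrative):

// Config is a hypothetical model configuration; the real models read
// max_position_embeddings from the safetensors checkpoint's config.json.
type Config struct {
	MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
}

type Model struct {
	Config
	// layers, embeddings, tokenizer, ...
}

// MaxContextLength satisfies the extended base.Model interface. The runner
// reports this value in GET /v1/status, the client caches it for
// ContextLength(), and the pipeline uses it to reject or cap requests.
func (m *Model) MaxContextLength() int {
	return int(m.MaxPositionEmbeddings)
}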