mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 07:16:38 -05:00
- Collapse MLX sampling state into a single sample.Sampler struct (options + history). - Replace interface-based sampler chain (TopP, TopK, penalty, etc.) with function-based transforms. - Update request/pipeline wiring to use *sample.Sampler, seed history from prompt tokens, and append generated tokens each step. - Implement top_p, min_p, repeat_penalty, and frequency_penalty
201 lines
4.9 KiB
Go
201 lines
4.9 KiB
Go
//go:build mlx
|
|
|
|
package mlxrunner
|
|
|
|
import (
|
|
"bytes"
|
|
"cmp"
|
|
"context"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/envconfig"
|
|
"github.com/ollama/ollama/logutil"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
"github.com/ollama/ollama/x/mlxrunner/sample"
|
|
)
|
|
|
|
func Execute(args []string) error {
|
|
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
|
|
|
if err := mlx.CheckInit(); err != nil {
|
|
return fmt.Errorf("MLX not available: %w", err)
|
|
}
|
|
|
|
var (
|
|
modelName string
|
|
port int
|
|
)
|
|
|
|
flagSet := flag.NewFlagSet("mlxrunner", flag.ExitOnError)
|
|
flagSet.StringVar(&modelName, "model", "", "Model name")
|
|
flagSet.IntVar(&port, "port", 0, "Port to listen on")
|
|
_ = flagSet.Bool("verbose", false, "Enable debug logging")
|
|
flagSet.Parse(args)
|
|
|
|
runner := Runner{
|
|
Requests: make(chan Request),
|
|
}
|
|
|
|
if err := runner.Load(modelName); err != nil {
|
|
return err
|
|
}
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
|
if err := json.NewEncoder(w).Encode(statusResponse{
|
|
Status: 0,
|
|
Progress: 100,
|
|
ContextLength: runner.contextLength,
|
|
Memory: uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
|
|
}); err != nil {
|
|
slog.Error("Failed to encode response", "error", err)
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
})
|
|
|
|
mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case "POST":
|
|
fallthrough
|
|
case "GET":
|
|
if err := json.NewEncoder(w).Encode(map[string]any{
|
|
"Success": true,
|
|
}); err != nil {
|
|
slog.Error("Failed to encode response", "error", err)
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
case "DELETE":
|
|
// TODO: cleanup model and cache
|
|
}
|
|
})
|
|
|
|
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
|
request := Request{Responses: make(chan CompletionResponse)}
|
|
|
|
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
|
slog.Error("Failed to decode request", "error", err)
|
|
http.Error(w, "Bad Request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
|
|
|
request.Pipeline = runner.TextGenerationPipeline
|
|
request.Sampler = sample.New(
|
|
request.Options.Temperature,
|
|
request.Options.TopP,
|
|
request.Options.MinP,
|
|
request.Options.TopK,
|
|
request.Options.RepeatLastN,
|
|
request.Options.PresencePenalty,
|
|
)
|
|
|
|
var cancel context.CancelFunc
|
|
request.Ctx, cancel = context.WithCancel(r.Context())
|
|
defer cancel()
|
|
|
|
select {
|
|
case <-r.Context().Done():
|
|
return
|
|
case runner.Requests <- request:
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/jsonl")
|
|
w.WriteHeader(http.StatusOK)
|
|
enc := json.NewEncoder(w)
|
|
for {
|
|
select {
|
|
case <-r.Context().Done():
|
|
return
|
|
case response, ok := <-request.Responses:
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if err := enc.Encode(response); err != nil {
|
|
slog.Error("Failed to encode response", "error", err)
|
|
return
|
|
}
|
|
|
|
if f, ok := w.(http.Flusher); ok {
|
|
f.Flush()
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) {
|
|
var b bytes.Buffer
|
|
if _, err := io.Copy(&b, r.Body); err != nil {
|
|
slog.Error("Failed to read request body", "error", err)
|
|
http.Error(w, "Bad Request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
tokens := runner.Tokenizer.Encode(b.String(), true)
|
|
|
|
if err := json.NewEncoder(w).Encode(tokens); err != nil {
|
|
slog.Error("Failed to encode response", "error", err)
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
})
|
|
|
|
for source, target := range map[string]string{
|
|
"GET /health": "/v1/status",
|
|
"POST /load": "/v1/models",
|
|
"POST /completion": "/v1/completions",
|
|
} {
|
|
mux.Handle(source, http.RedirectHandler(target, http.StatusPermanentRedirect))
|
|
}
|
|
|
|
return runner.Run("127.0.0.1", strconv.Itoa(port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
recorder := &statusRecorder{ResponseWriter: w, code: http.StatusOK}
|
|
t := time.Now()
|
|
mux.ServeHTTP(recorder, r)
|
|
|
|
var level slog.Level
|
|
switch {
|
|
case recorder.code >= 500:
|
|
level = slog.LevelError
|
|
case recorder.code >= 400:
|
|
level = slog.LevelWarn
|
|
case recorder.code >= 300:
|
|
return
|
|
}
|
|
|
|
slog.Log(r.Context(), level, "ServeHTTP", "method", r.Method, "path", r.URL.Path, "took", time.Since(t), "status", recorder.Status())
|
|
}))
|
|
}
|
|
|
|
type statusRecorder struct {
|
|
http.ResponseWriter
|
|
code int
|
|
}
|
|
|
|
func (w *statusRecorder) WriteHeader(code int) {
|
|
w.code = code
|
|
w.ResponseWriter.WriteHeader(code)
|
|
}
|
|
|
|
func (w *statusRecorder) Status() string {
|
|
return strconv.Itoa(w.code) + " " + http.StatusText(w.code)
|
|
}
|
|
|
|
func (w *statusRecorder) Flush() {
|
|
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
|
f.Flush()
|
|
}
|
|
}
|