From d18dcd77755b55c9d761e483abee17d1e2b6c58c Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Fri, 13 Feb 2026 22:30:42 -0800 Subject: [PATCH] mlxrunner fixes (#14247) * load glm4_moe_lite from the mlxrunner * fix loading diffusion models * remove log lines * fix --imagegen flag --- cmd/cmd.go | 14 + server/routes.go | 5 +- server/routes_generate_test.go | 1 + server/sched.go | 36 ++- server/sched_test.go | 8 +- x/imagegen/manifest/weights.go | 11 +- x/mlxrunner/client.go | 354 ++++++++++++++++++++---- x/mlxrunner/imports.go | 7 + x/mlxrunner/mlx/array.go | 3 +- x/mlxrunner/mlx/dynamic.go | 81 ++++-- x/mlxrunner/mlx/ops_extra.go | 35 ++- x/mlxrunner/model/base/base.go | 85 ++++++ x/mlxrunner/model/base/base_stub.go | 3 + x/mlxrunner/model/root.go | 97 +++++++ x/mlxrunner/model/root_stub.go | 3 + x/mlxrunner/pipeline.go | 7 +- x/mlxrunner/runner.go | 113 +++++--- x/mlxrunner/sample/sample.go | 2 +- x/models/glm4_moe_lite/glm4_moe_lite.go | 180 ++++-------- 19 files changed, 764 insertions(+), 281 deletions(-) create mode 100644 x/mlxrunner/imports.go create mode 100644 x/mlxrunner/model/base/base.go create mode 100644 x/mlxrunner/model/base/base_stub.go create mode 100644 x/mlxrunner/model/root.go create mode 100644 x/mlxrunner/model/root_stub.go diff --git a/cmd/cmd.go b/cmd/cmd.go index d18e7f66a..bd0648948 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -581,6 +581,17 @@ func RunHandler(cmd *cobra.Command, args []string) error { } opts.WordWrap = !nowrap + useImagegen := false + if cmd.Flags().Lookup("imagegen") != nil { + useImagegen, err = cmd.Flags().GetBool("imagegen") + if err != nil { + return err + } + } + if useImagegen { + opts.Options["use_imagegen_runner"] = true + } + // Fill out the rest of the options based on information about the // model. client, err := api.ClientFromEnvironment() @@ -2141,6 +2152,9 @@ func NewCLI() *cobra.Command { // Image generation flags (width, height, steps, seed, etc.) 
imagegen.RegisterFlags(runCmd) + runCmd.Flags().Bool("imagegen", false, "Use the imagegen runner for LLM inference") + runCmd.Flags().MarkHidden("imagegen") + stopCmd := &cobra.Command{ Use: "stop MODEL", Short: "Stop a running model", diff --git a/server/routes.go b/server/routes.go index 26cec3544..cbe771d9f 100644 --- a/server/routes.go +++ b/server/routes.go @@ -150,12 +150,15 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C return nil, nil, nil, fmt.Errorf("%s %w", name, err) } + useImagegen, _ := requestOpts["use_imagegen_runner"].(bool) + delete(requestOpts, "use_imagegen_runner") + opts, err := s.modelOptions(model, requestOpts) if err != nil { return nil, nil, nil, err } - runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive) + runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive, useImagegen) var runner *runnerRef select { case runner = <-runnerCh: diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 7e21d80aa..677fef369 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -2383,6 +2383,7 @@ func TestImageGenerateStreamFalse(t *testing.T) { llama: &mock, Options: &opts, model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}}, + isImagegen: true, numParallel: 1, }, }, diff --git a/server/sched.go b/server/sched.go index e81c895bc..728ec47b6 100644 --- a/server/sched.go +++ b/server/sched.go @@ -22,6 +22,7 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/x/imagegen" + "github.com/ollama/ollama/x/mlxrunner" ) type LlmRequest struct { @@ -32,6 +33,7 @@ type LlmRequest struct { successCh chan *runnerRef errCh chan error schedAttempts uint + useImagegen bool } type Scheduler struct { @@ -82,7 +84,7 @@ func InitScheduler(ctx context.Context) *Scheduler { } // context must be canceled to decrement ref count and release the runner -func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) { +func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) { if opts.NumCtx < 4 { opts.NumCtx = 4 } @@ -99,6 +101,7 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses sessionDuration: sessionDuration, successCh: make(chan *runnerRef, 1), errCh: make(chan error, 1), + useImagegen: useImagegen, } s.loadedMu.Lock() @@ -566,17 +569,20 @@ iGPUScan: // loadMLX loads an experimental safetensors model using the unified MLX runner. // This supports both LLM (completion) and image generation models. 
func (s *Scheduler) loadMLX(req *LlmRequest) bool { - // Determine mode based on capabilities - var mode imagegen.ModelMode - if slices.Contains(req.model.Config.Capabilities, "image") { - mode = imagegen.ModeImageGen - } else { - mode = imagegen.ModeLLM - } - - // Use model name for MLX (it resolves manifests by name, not file path) modelName := req.model.ShortName - server, err := imagegen.NewServer(modelName, mode) + var server llm.LlamaServer + var err error + + isImagegen := false + if slices.Contains(req.model.Config.Capabilities, "image") { + server, err = imagegen.NewServer(modelName, imagegen.ModeImageGen) + isImagegen = true + } else if req.useImagegen { + server, err = imagegen.NewServer(modelName, imagegen.ModeLLM) + isImagegen = true + } else { + server, err = mlxrunner.NewClient(modelName) + } if err != nil { req.errCh <- err return true @@ -593,6 +599,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool { llama: server, Options: &req.opts, loading: false, + isImagegen: isImagegen, sessionDuration: sessionDuration, totalSize: server.TotalSize(), vramSize: server.VRAMSize(), @@ -667,6 +674,7 @@ type runnerRef struct { loading bool // True only during initial load, then false forever gpus []ml.DeviceID // Recorded at time of provisioning discreteGPUs bool // True if all devices are discrete GPUs - used to skip VRAM recovery check for iGPUs + isImagegen bool // True if loaded via imagegen runner (vs mlxrunner) vramSize uint64 totalSize uint64 @@ -699,6 +707,12 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool runner.refMu.Lock() defer runner.refMu.Unlock() + // Check if runner type (imagegen vs mlxrunner) matches what's requested + wantImagegen := req.useImagegen || slices.Contains(req.model.Config.Capabilities, "image") + if runner.isImagegen != wantImagegen { + return true + } + timeout := 10 * time.Second if runner.loading { timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems... 
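Note: the selection rule loadMLX now applies reduces to a small decision table: image-capable models always get the imagegen runner, the hidden --imagegen flag forces the imagegen runner even for text models, and everything else goes to the MLX runner subprocess. A minimal standalone sketch of that rule follows; pickServer is a hypothetical helper name for illustration (the patch inlines this logic in loadMLX), and it assumes sched.go's imports:

// pickServer mirrors the runner selection loadMLX performs after this patch.
// It returns the chosen server plus the isImagegen flag that needsReload
// later compares against incoming requests.
func pickServer(capabilities []string, useImagegen bool, name string) (llm.LlamaServer, bool, error) {
	if slices.Contains(capabilities, "image") {
		// Image-capable models always use the imagegen runner.
		s, err := imagegen.NewServer(name, imagegen.ModeImageGen)
		return s, true, err
	}
	if useImagegen {
		// The hidden --imagegen flag forces the imagegen runner for LLMs.
		s, err := imagegen.NewServer(name, imagegen.ModeLLM)
		return s, true, err
	}
	// Default: text models run in the MLX runner subprocess.
	s, err := mlxrunner.NewClient(name)
	return s, false, err
}

Because needsReload compares the runner's recorded isImagegen against the same predicate, toggling --imagegen for an already-loaded model triggers a reload instead of reusing a mismatched runner.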
diff --git a/server/sched_test.go b/server/sched_test.go index 7eaf4a9f9..732e5b3cc 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -408,10 +408,10 @@ func TestSchedGetRunner(t *testing.T) { s.getSystemInfoFn = getSystemInfoFn s.newServerFn = a.newServer slog.Info("a") - successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration) + successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration, false) require.Len(t, s.pendingReqCh, 1) slog.Info("b") - successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration) + successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration, false) require.Len(t, s.pendingReqCh, 1) require.Empty(t, successCh1b) require.Len(t, errCh1b, 1) @@ -435,7 +435,7 @@ func TestSchedGetRunner(t *testing.T) { c.req.model.ModelPath = "bad path" slog.Info("c") - successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration) + successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration, false) // Starts in pending channel, then should be quickly processed to return an error time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload require.Empty(t, successCh1c) @@ -509,7 +509,7 @@ func TestSchedPrematureExpired(t *testing.T) { s.getGpuFn = getGpuFn s.getSystemInfoFn = getSystemInfoFn s.newServerFn = scenario1a.newServer - successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration) + successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration, false) require.Len(t, s.pendingReqCh, 1) s.Run(ctx) select { diff --git a/x/imagegen/manifest/weights.go b/x/imagegen/manifest/weights.go index e0ad0399c..e1209c9db 100644 --- a/x/imagegen/manifest/weights.go +++ b/x/imagegen/manifest/weights.go @@ -102,15 +102,20 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error { for _, entry := range entries { name := entry.name - // Try to get tensor by stripped name first, then with component prefix. - // Blobs may store tensors with the full prefixed name (e.g., "text_encoder/model.layers.0.weight") - // while the tensors map uses stripped names (e.g., "model.layers.0.weight"). + // Try to get tensor by stripped name first, then with component prefix, + // then fall back to "data" for legacy blobs created by older versions + // that stored all tensors with the generic key "data". 
lookupName := name arr := sf.Get(lookupName) if arr == nil && mw.component != "" { lookupName = mw.component + "/" + name arr = sf.Get(lookupName) } + if arr == nil { + // Legacy blob format: tensor stored as "data" + lookupName = "data" + arr = sf.Get(lookupName) + } if arr != nil { // Single-tensor blob or tensor found by name if dtype != 0 && arr.Dtype() != dtype { diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index e3e5157ab..19e987736 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -2,76 +2,298 @@ package mlxrunner import ( "bufio" - "bytes" "context" "encoding/json" "errors" + "fmt" + "io" + "log/slog" "math" + "math/rand" "net" "net/http" - "net/url" + "os" "os/exec" + "path/filepath" + "runtime" "strconv" "strings" + "sync" + "time" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/x/imagegen" + "github.com/ollama/ollama/x/imagegen/manifest" ) +// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models. type Client struct { - Port int - *exec.Cmd + port int + modelName string + vramSize uint64 + done chan error + client *http.Client + lastErr string + lastErrLock sync.Mutex + mu sync.Mutex + cmd *exec.Cmd } -func (c *Client) JoinPath(path string) string { - return (&url.URL{ - Scheme: "http", - Host: net.JoinHostPort("127.0.0.1", strconv.Itoa(c.Port)), - }).JoinPath(path).String() +// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready. +func NewClient(modelName string) (*Client, error) { + if err := imagegen.CheckPlatformSupport(); err != nil { + return nil, err + } + + // Find a free port + port := 0 + if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + if l, err := net.ListenTCP("tcp", a); err == nil { + port = l.Addr().(*net.TCPAddr).Port + l.Close() + } + } + if port == 0 { + port = rand.Intn(65535-49152) + 49152 + } + + // Get the current executable path + exe, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("unable to lookup executable path: %w", err) + } + if eval, err := filepath.EvalSymlinks(exe); err == nil { + exe = eval + } + + // Spawn subprocess: ollama runner --mlx-engine --model --port + cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port)) + cmd.Env = os.Environ() + + // On Linux, set LD_LIBRARY_PATH to include MLX library directories + if runtime.GOOS == "linux" { + libraryPaths := []string{ml.LibOllamaPath} + if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil { + libraryPaths = append(libraryPaths, mlxDirs...) + } + + if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok { + libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...) 
+ } + + pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) + + found := false + for i := range cmd.Env { + if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") { + cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal + found = true + break + } + } + if !found { + cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal) + } + slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal) + } + + // Estimate VRAM based on tensor size from manifest + var vramSize uint64 + if modelManifest, err := manifest.LoadManifest(modelName); err == nil { + vramSize = uint64(modelManifest.TotalTensorSize()) + } else { + vramSize = 8 * 1024 * 1024 * 1024 + } + + c := &Client{ + port: port, + modelName: modelName, + vramSize: vramSize, + done: make(chan error, 1), + client: &http.Client{Timeout: 10 * time.Minute}, + cmd: cmd, + } + + // Forward subprocess stdout/stderr to server logs + stdout, _ := cmd.StdoutPipe() + stderr, _ := cmd.StderrPipe() + go func() { + scanner := bufio.NewScanner(stdout) + for scanner.Scan() { + slog.Info("mlx-runner", "msg", scanner.Text()) + } + }() + go func() { + scanner := bufio.NewScanner(stderr) + for scanner.Scan() { + line := scanner.Text() + slog.Warn("mlx-runner", "msg", line) + c.lastErrLock.Lock() + c.lastErr = line + c.lastErrLock.Unlock() + } + }() + + slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port) + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start mlx runner: %w", err) + } + + // Reap subprocess when it exits + go func() { + err := cmd.Wait() + c.done <- err + }() + + // Wait for subprocess to be ready + if err := c.waitUntilRunning(); err != nil { + c.Close() + return nil, err + } + + return c, nil } -func (c *Client) CheckError(w *http.Response) error { - if w.StatusCode >= 400 { - return errors.New(w.Status) +func (c *Client) getLastErr() string { + c.lastErrLock.Lock() + defer c.lastErrLock.Unlock() + return c.lastErr +} + +func (c *Client) waitUntilRunning() error { + ctx := context.Background() + timeout := time.After(2 * time.Minute) + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case err := <-c.done: + errMsg := c.getLastErr() + if errMsg != "" { + return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err) + } + return fmt.Errorf("mlx runner exited unexpectedly: %w", err) + case <-timeout: + errMsg := c.getLastErr() + if errMsg != "" { + return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg) + } + return errors.New("timeout waiting for mlx runner to start") + case <-ticker.C: + if err := c.Ping(ctx); err == nil { + slog.Info("mlx runner is ready", "port", c.port) + return nil + } + } + } +} + +// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization. +type completionRequest struct { + Prompt string `json:"prompt"` + Options *completionOpts `json:"options,omitempty"` +} + +type completionOpts struct { + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + MinP float32 `json:"min_p,omitempty"` + TopK int `json:"top_k,omitempty"` + NumPredict int `json:"num_predict,omitempty"` +} + +// Close terminates the subprocess. 
+func (c *Client) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.cmd != nil && c.cmd.Process != nil { + slog.Info("stopping mlx runner subprocess", "pid", c.cmd.Process.Pid) + c.cmd.Process.Signal(os.Interrupt) + + select { + case <-c.done: + case <-time.After(5 * time.Second): + c.cmd.Process.Kill() + } + c.cmd = nil } return nil } -// Close implements llm.LlamaServer. -func (c *Client) Close() error { - return c.Cmd.Process.Kill() -} - // Completion implements llm.LlamaServer. func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(req); err != nil { - return err + creq := completionRequest{ + Prompt: req.Prompt, + } + if req.Options != nil { + creq.Options = &completionOpts{ + Temperature: req.Options.Temperature, + TopP: req.Options.TopP, + MinP: req.Options.MinP, + TopK: req.Options.TopK, + NumPredict: req.Options.NumPredict, + } } - w, err := http.Post(c.JoinPath("/v1/completions"), "application/json", &b) + body, err := json.Marshal(creq) if err != nil { return err } - defer w.Body.Close() - if err := c.CheckError(w); err != nil { + httpURL := fmt.Sprintf("http://127.0.0.1:%d/completion", c.port) + httpReq, err := http.NewRequestWithContext(ctx, "POST", httpURL, strings.NewReader(string(body))) + if err != nil { return err } + httpReq.Header.Set("Content-Type", "application/json") - scanner := bufio.NewScanner(w.Body) - for scanner.Scan() { - bts := scanner.Bytes() + resp, err := c.client.Do(httpReq) + if err != nil { + return err + } + defer resp.Body.Close() - var resp llm.CompletionResponse - if err := json.Unmarshal(bts, &resp); err != nil { - return err - } - - fn(resp) + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("%s", strings.TrimSpace(string(respBody))) } - return nil + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + var raw struct { + Content string `json:"content,omitempty"` + Done bool `json:"done"` + DoneReason int `json:"done_reason,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + PromptEvalDuration int `json:"prompt_eval_duration,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration int `json:"eval_duration,omitempty"` + } + if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil { + slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes())) + continue + } + + cresp := llm.CompletionResponse{ + Content: raw.Content, + Done: raw.Done, + DoneReason: llm.DoneReason(raw.DoneReason), + PromptEvalCount: raw.PromptEvalCount, + PromptEvalDuration: time.Duration(raw.PromptEvalDuration), + EvalCount: raw.EvalCount, + EvalDuration: time.Duration(raw.EvalDuration), + } + + fn(cresp) + if cresp.Done { + return nil + } + } + + return scanner.Err() } func (c *Client) ContextLength() int { @@ -80,71 +302,89 @@ func (c *Client) ContextLength() int { // Detokenize implements llm.LlamaServer. func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) { - panic("unimplemented") + return "", errors.New("not supported") } // Embedding implements llm.LlamaServer. func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) { - panic("unimplemented") + return nil, 0, errors.New("not supported") } // GetDeviceInfos implements llm.LlamaServer. 
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { - panic("unimplemented") + return nil } // GetPort implements llm.LlamaServer. func (c *Client) GetPort() int { - return c.Port + return c.port } // HasExited implements llm.LlamaServer. func (c *Client) HasExited() bool { - panic("unimplemented") + select { + case <-c.done: + return true + default: + return false + } } // Load implements llm.LlamaServer. func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) { - w, err := http.Post(c.JoinPath("/v1/models"), "application/json", nil) - if err != nil { - return nil, err - } - defer w.Body.Close() - - return []ml.DeviceID{}, nil + return nil, nil } // ModelPath implements llm.LlamaServer. func (c *Client) ModelPath() string { - panic("unimplemented") + return c.modelName } // Pid implements llm.LlamaServer. func (c *Client) Pid() int { - panic("unimplemented") + c.mu.Lock() + defer c.mu.Unlock() + if c.cmd != nil && c.cmd.Process != nil { + return c.cmd.Process.Pid + } + return -1 } // Ping implements llm.LlamaServer. func (c *Client) Ping(ctx context.Context) error { - w, err := http.Get(c.JoinPath("/v1/status")) + reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port) + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) if err != nil { return err } - defer w.Body.Close() - + resp, err := c.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("health check failed: %d", resp.StatusCode) + } return nil } // Tokenize implements llm.LlamaServer. func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) { - w, err := http.Post(c.JoinPath("/v1/tokenize"), "text/plain", strings.NewReader(content)) + reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/tokenize", c.port) + req, err := http.NewRequestWithContext(ctx, "POST", reqURL, strings.NewReader(content)) if err != nil { return nil, err } - defer w.Body.Close() + req.Header.Set("Content-Type", "text/plain") + + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() var tokens []int - if err := json.NewDecoder(w.Body).Decode(&tokens); err != nil { + if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil { return nil, err } @@ -153,22 +393,22 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) { // TotalSize implements llm.LlamaServer. func (c *Client) TotalSize() uint64 { - panic("unimplemented") + return c.vramSize } // VRAMByGPU implements llm.LlamaServer. func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 { - panic("unimplemented") + return c.vramSize } // VRAMSize implements llm.LlamaServer. func (c *Client) VRAMSize() uint64 { - panic("unimplemented") + return c.vramSize } // WaitUntilRunning implements llm.LlamaServer. 
func (c *Client) WaitUntilRunning(ctx context.Context) error { - panic("unimplemented") + return nil } var _ llm.LlamaServer = (*Client)(nil) diff --git a/x/mlxrunner/imports.go b/x/mlxrunner/imports.go new file mode 100644 index 000000000..e8950eff8 --- /dev/null +++ b/x/mlxrunner/imports.go @@ -0,0 +1,7 @@ +//go:build mlx + +package mlxrunner + +import ( + _ "github.com/ollama/ollama/x/models/glm4_moe_lite" +) diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go index bec8d3444..43254d230 100644 --- a/x/mlxrunner/mlx/array.go +++ b/x/mlxrunner/mlx/array.go @@ -133,6 +133,7 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array { } func (t *Array) Set(other *Array) { + Free(t.desc.inputs...) other.desc.numRefs++ t.desc.inputs = []*Array{other} C.mlx_array_set(&t.ctx, other.ctx) @@ -248,9 +249,9 @@ func Free(s ...*Array) (n int) { free := make([]*Array, 0, 8192) fn := func(t *Array) { if t.Valid() { - free = append(free, t.desc.inputs...) t.desc.numRefs-- if t.desc.numRefs <= 0 { + free = append(free, t.desc.inputs...) logutil.Trace("Free", "t", t) n += t.NumBytes() C.mlx_array_free(t.ctx) diff --git a/x/mlxrunner/mlx/dynamic.go b/x/mlxrunner/mlx/dynamic.go index 0c2306e46..b142cc5f8 100644 --- a/x/mlxrunner/mlx/dynamic.go +++ b/x/mlxrunner/mlx/dynamic.go @@ -24,6 +24,37 @@ func CheckInit() error { return initError } +// tryLoadFromDir searches a directory for libmlxc.* and tries to load it. +// Returns true if the library was successfully loaded. +func tryLoadFromDir(dir string) bool { + matches, err := fs.Glob(os.DirFS(dir), "libmlxc.*") + if err != nil || len(matches) == 0 { + return false + } + + for _, match := range matches { + path := filepath.Join(dir, match) + + cPath := C.CString(path) + defer C.free(unsafe.Pointer(cPath)) + + var handle C.mlx_dynamic_handle + if C.mlx_dynamic_load(&handle, cPath) != 0 { + slog.Error("Failed to load MLX dynamic library", "path", path) + continue + } + + if C.mlx_dynamic_load_symbols(handle) != 0 { + slog.Error("Failed to load MLX dynamic library symbols", "path", path) + C.mlx_dynamic_unload(&handle) + continue + } + + return true + } + return false +} + func init() { switch runtime.GOOS { case "darwin": @@ -33,44 +64,34 @@ func init() { return } - paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH") - if !ok { - slog.Debug("OLLAMA_LIBRARY_PATH not set, skipping mlx dynamic loading") - return + // Try OLLAMA_LIBRARY_PATH first + if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok { + for _, dir := range filepath.SplitList(paths) { + if tryLoadFromDir(dir) { + return + } + } } - for _, path := range filepath.SplitList(paths) { - matches, err := fs.Glob(os.DirFS(path), "libmlxc.*") - if err != nil { - initError = fmt.Errorf("failed to glob for MLX libraries in %s: %w", path, err) - slog.Warn("MLX dynamic library not available", "error", initError) - return + // Build search paths: executable directory, then build directories + var searchDirs []string + if exe, err := os.Executable(); err == nil { + if eval, err := filepath.EvalSymlinks(exe); err == nil { + exe = eval } + searchDirs = append(searchDirs, filepath.Dir(exe)) + } - for _, match := range matches { - path := filepath.Join(paths, match) - slog.Info("Loading MLX dynamic library", "path", path) + if cwd, err := os.Getwd(); err == nil { + searchDirs = append(searchDirs, filepath.Join(cwd, "build", "lib", "ollama")) + } - cPath := C.CString(path) - defer C.free(unsafe.Pointer(cPath)) - - var handle C.mlx_dynamic_handle - if C.mlx_dynamic_load(&handle, cPath) != 0 { - 
slog.Error("Failed to load MLX dynamic library", "path", path) - continue - } - - if C.mlx_dynamic_load_symbols(handle) != 0 { - slog.Error("Failed to load MLX dynamic library symbols", "path", path) - C.mlx_dynamic_unload(&handle) - continue - } - - slog.Info("Loaded MLX dynamic library", "path", path) + for _, dir := range searchDirs { + if tryLoadFromDir(dir) { return } } - initError = fmt.Errorf("failed to load any MLX dynamic library from OLLAMA_LIBRARY_PATH=%s", paths) + initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", searchDirs) slog.Warn("MLX dynamic library not available", "error", initError) } diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go index e5444a4f8..f2882e989 100644 --- a/x/mlxrunner/mlx/ops_extra.go +++ b/x/mlxrunner/mlx/ops_extra.go @@ -306,19 +306,42 @@ func AddMM(c, a, b *Array, alpha, beta float32) *Array { // Scalar helpers +// scalarWithDtype creates a scalar array matching the dtype of a. +// Matching dtype is important for graph fusion and avoiding implicit casts. +func scalarWithDtype(s float32, a *Array) C.mlx_array { + f32 := C.mlx_array_new_float(C.float(s)) + dtype := a.DType() + if dtype == DTypeFloat32 { + return f32 + } + casted := C.mlx_array_new() + C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx) + C.mlx_array_free(f32) + return casted +} + func AddScalar(a *Array, s float32) *Array { - scalar := FromValue(s) - return a.Add(scalar) + scalar := scalarWithDtype(s, a) + out := New("ADD_SCALAR", a) + C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx) + C.mlx_array_free(scalar) + return out } func MulScalar(a *Array, s float32) *Array { - scalar := FromValue(s) - return a.Multiply(scalar) + scalar := scalarWithDtype(s, a) + out := New("MUL_SCALAR", a) + C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx) + C.mlx_array_free(scalar) + return out } func DivScalar(a *Array, s float32) *Array { - scalar := FromValue(s) - return a.Divide(scalar) + scalar := scalarWithDtype(s, a) + out := New("DIV_SCALAR", a) + C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx) + C.mlx_array_free(scalar) + return out } func FloorDivideScalar(a *Array, s int32) *Array { diff --git a/x/mlxrunner/model/base/base.go b/x/mlxrunner/model/base/base.go new file mode 100644 index 000000000..fcc8b8627 --- /dev/null +++ b/x/mlxrunner/model/base/base.go @@ -0,0 +1,85 @@ +//go:build mlx + +package base + +import ( + "encoding/json" + "fmt" + "log/slog" + "sync" + + "github.com/ollama/ollama/x/imagegen/tokenizer" + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model" +) + +// Model is the interface that model implementations must satisfy. +type Model interface { + Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array + Unembed(x *mlx.Array) *mlx.Array + NumLayers() int + Tokenizer() *tokenizer.Tokenizer + + // LoadWeights receives all tensors loaded from the manifest and assigns + // them to model fields. Model-specific logic (MLA absorption, expert + // stacking, quantized layer creation) happens here. + LoadWeights(tensors map[string]*mlx.Array) error +} + +var ( + mu sync.Mutex + registry = make(map[string]func(root *model.Root) (Model, error)) +) + +// Register registers a model constructor by architecture name. +// Called from init() in model packages. Panics on duplicate registration. 
+func Register(arch string, fn func(root *model.Root) (Model, error)) { + mu.Lock() + defer mu.Unlock() + + if _, exists := registry[arch]; exists { + panic(fmt.Sprintf("model architecture %q already registered", arch)) + } + registry[arch] = fn +} + +// New reads config.json from the manifest, detects the architecture, looks up +// the registered constructor, and calls it to create the model (with config +// parsed and struct created, but weights not yet loaded). +func New(root *model.Root) (Model, error) { + configData, err := root.Manifest.ReadConfig("config.json") + if err != nil { + return nil, fmt.Errorf("failed to read config.json: %w", err) + } + + var archConfig struct { + Architectures []string `json:"architectures"` + } + if err := json.Unmarshal(configData, &archConfig); err != nil { + return nil, fmt.Errorf("failed to parse config.json: %w", err) + } + + if len(archConfig.Architectures) == 0 { + return nil, fmt.Errorf("no architectures found in config.json") + } + + arch := archConfig.Architectures[0] + slog.Info("Model architecture", "arch", arch) + + mu.Lock() + fn, ok := registry[arch] + mu.Unlock() + + if !ok { + return nil, fmt.Errorf("unsupported architecture: %s", arch) + } + + return fn(root) +} + +// Weights returns the model's LoadWeights method, which encapsulates all +// weight assignment and post-processing (MLA absorption, expert stacking). +func Weights(m Model) func(map[string]*mlx.Array) error { + return m.LoadWeights +} diff --git a/x/mlxrunner/model/base/base_stub.go b/x/mlxrunner/model/base/base_stub.go new file mode 100644 index 000000000..318d8f911 --- /dev/null +++ b/x/mlxrunner/model/base/base_stub.go @@ -0,0 +1,3 @@ +//go:build !mlx + +package base diff --git a/x/mlxrunner/model/root.go b/x/mlxrunner/model/root.go new file mode 100644 index 000000000..885647ab3 --- /dev/null +++ b/x/mlxrunner/model/root.go @@ -0,0 +1,97 @@ +//go:build mlx + +package model + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + "os" + "strings" + + "github.com/ollama/ollama/x/imagegen/manifest" +) + +// Root wraps a ModelManifest with pre-scanned quantization metadata. +type Root struct { + Manifest *manifest.ModelManifest + quantType string + groupSize int +} + +// Open loads a manifest for the given model name and pre-scans the first +// tensor blob for quantization metadata (quant_type, group_size). +func Open(modelName string) (*Root, error) { + m, err := manifest.LoadManifest(modelName) + if err != nil { + return nil, err + } + + root := &Root{Manifest: m} + + // Pre-scan first tensor blob for quantization metadata + for _, layer := range m.GetTensorLayers("") { + blobPath := m.BlobPath(layer.Digest) + meta, err := readBlobMetadata(blobPath) + if err != nil || meta == nil { + continue + } + if qt := meta["quant_type"]; qt != "" { + root.quantType = strings.ToUpper(qt) + } + if gs := meta["group_size"]; gs != "" { + fmt.Sscanf(gs, "%d", &root.groupSize) + } + break // only check the first tensor blob + } + + return root, nil +} + +// Close is a no-op for now (future: release resources). +func (r *Root) Close() {} + +// QuantType returns the quantization type detected from tensor metadata. +func (r *Root) QuantType() string { return r.quantType } + +// GroupSize returns the quantization group size detected from tensor metadata. +func (r *Root) GroupSize() int { return r.groupSize } + +// readBlobMetadata reads the __metadata__ from a safetensors blob header. 
+func readBlobMetadata(path string) (map[string]string, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var headerSize uint64 + if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil { + return nil, err + } + if headerSize > 1024*1024 { + return nil, fmt.Errorf("header too large: %d", headerSize) + } + + data := make([]byte, headerSize) + if _, err := io.ReadFull(f, data); err != nil { + return nil, err + } + + var header map[string]json.RawMessage + if err := json.Unmarshal(data, &header); err != nil { + return nil, err + } + + metaRaw, ok := header["__metadata__"] + if !ok { + return nil, nil + } + + var meta map[string]string + if err := json.Unmarshal(metaRaw, &meta); err != nil { + return nil, err + } + return meta, nil +} diff --git a/x/mlxrunner/model/root_stub.go b/x/mlxrunner/model/root_stub.go new file mode 100644 index 000000000..3fcda9c25 --- /dev/null +++ b/x/mlxrunner/model/root_stub.go @@ -0,0 +1,3 @@ +//go:build !mlx + +package model diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index c094e5a3b..b7650b68d 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -18,6 +18,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error { return errors.New("model not loaded") } + mlx.EnableCompile() + inputs := r.Tokenizer.Encode(request.Prompt, true) caches, tokens := r.FindNearestCache(inputs) @@ -47,7 +49,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) { - logits := r.Model.Unembed(r.Model.Forward(token.ExpandDims(0), caches)) + fwd := r.Model.Forward(token.ExpandDims(0), caches) + logits := r.Model.Unembed(fwd) logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) logprobs := logits.Subtract(logits.Logsumexp(true)) @@ -60,7 +63,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { var b bytes.Buffer now := time.Now() - final := Response{PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1} + final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1} outputs := make([]int32, 0, request.Options.MaxTokens) for i := range request.Options.MaxTokens { nextSample, nextLogprobs := step(sample) diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 0b84b5a44..826281c31 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -4,30 +4,22 @@ package mlxrunner import ( "context" - "encoding/json" - "fmt" "log/slog" "net" "net/http" + "strings" "time" "golang.org/x/sync/errgroup" - "github.com/ollama/ollama/x/imagegen/manifest" "github.com/ollama/ollama/x/imagegen/tokenizer" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model" + "github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/mlxrunner/sample" - "github.com/ollama/ollama/x/models/glm4_moe_lite" ) -// TextModel is the interface that model implementations must satisfy. 
-type TextModel interface {
-	Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
-	Unembed(x *mlx.Array) *mlx.Array
-	NumLayers() int
-}
-
 type Request struct {
 	TextCompletionsRequest
 	Responses chan Response
@@ -66,52 +58,95 @@ type Response struct {
 }
 
 type Runner struct {
-	Model        TextModel
+	Model        base.Model
 	Tokenizer    *tokenizer.Tokenizer
 	Requests     chan Request
 	CacheEntries map[int32]*CacheEntry
 }
 
 func (r *Runner) Load(modelName string) error {
-	modelManifest, err := manifest.LoadManifest(modelName)
+	root, err := model.Open(modelName)
+	if err != nil {
+		return err
+	}
+	defer root.Close()
+
+	m, err := base.New(root)
 	if err != nil {
 		return err
 	}
 
-	// Read config to detect architecture
-	configData, err := modelManifest.ReadConfig("config.json")
+	// Load all tensor blobs from manifest
+	tensors, err := loadTensorsFromManifest(root)
 	if err != nil {
-		return fmt.Errorf("failed to read config.json: %w", err)
+		return err
 	}
 
-	var archConfig struct {
-		Architectures []string `json:"architectures"`
-	}
-	if err := json.Unmarshal(configData, &archConfig); err != nil {
-		return fmt.Errorf("failed to parse config.json: %w", err)
-	}
-
-	if len(archConfig.Architectures) == 0 {
-		return fmt.Errorf("no architectures found in config.json")
-	}
-
-	slog.Info("Model architecture", "arch", archConfig.Architectures[0])
-
-	switch archConfig.Architectures[0] {
-	case "Glm4MoeLiteForCausalLM", "GLM4MoeLite":
-		model, err := glm4_moe_lite.LoadFromManifest(modelManifest)
-		if err != nil {
-			return fmt.Errorf("failed to load GLM4-MoE-Lite model: %w", err)
-		}
-		r.Model = model
-		r.Tokenizer = model.Tokenizer()
-	default:
-		return fmt.Errorf("unsupported architecture: %s", archConfig.Architectures[0])
+	// Assign weights to model (model-specific logic)
+	loadWeights := base.Weights(m)
+	if err := loadWeights(tensors); err != nil {
+		return err
 	}
 
+	r.Model = m
+	r.Tokenizer = m.Tokenizer()
 	return nil
 }
 
+// loadTensorsFromManifest loads all tensor blobs from the manifest into a
+// flat map, deduplicating by digest and remapping safetensors key suffixes.
+//
+// Uses a three-phase approach: load all raw tensors, record which base
+// names have .scale entries, then remap .bias → _qbias with that full
+// knowledge. This avoids an ordering hazard: Go's nondeterministic map
+// iteration could otherwise process a .bias before its .scale in a blob.
+func loadTensorsFromManifest(root *model.Root) (map[string]*mlx.Array, error) { + // Phase 1: Load all tensors raw from all blobs + rawTensors := make(map[string]*mlx.Array) + seen := make(map[string]bool) + for _, layer := range root.Manifest.GetTensorLayers("") { + if seen[layer.Digest] { + continue + } + seen[layer.Digest] = true + blobPath := root.Manifest.BlobPath(layer.Digest) + for name, arr := range mlx.Load(blobPath) { + rawTensors[name] = arr + } + } + + // Phase 2: Identify all base names that have .scale tensors and remap them + scaleBaseNames := make(map[string]bool) + allTensors := make(map[string]*mlx.Array, len(rawTensors)) + for name, arr := range rawTensors { + if strings.HasSuffix(name, ".scale") { + baseName := strings.TrimSuffix(name, ".scale") + allTensors[baseName+"_scale"] = arr + scaleBaseNames[baseName] = true + } + } + + // Phase 3: Process remaining tensors with complete scale knowledge + for name, arr := range rawTensors { + if strings.HasSuffix(name, ".scale") { + continue // already handled + } + if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") { + baseName := strings.TrimSuffix(name, ".bias") + if scaleBaseNames[baseName] { + allTensors[baseName+"_qbias"] = arr + } else { + allTensors[name] = arr + } + } else { + allTensors[name] = arr + } + } + + slog.Info("Loaded tensors from manifest", "count", len(allTensors)) + return allTensors, nil +} + func (r *Runner) Run(host, port string, mux http.Handler) error { g, ctx := errgroup.WithContext(context.Background()) diff --git a/x/mlxrunner/sample/sample.go b/x/mlxrunner/sample/sample.go index 3a2e7577d..b0656973f 100644 --- a/x/mlxrunner/sample/sample.go +++ b/x/mlxrunner/sample/sample.go @@ -52,7 +52,7 @@ func (c chain) Sample(logits *mlx.Array) *mlx.Array { type Temperature float32 func (t Temperature) Sample(logits *mlx.Array) *mlx.Array { - return logits.Multiply(mlx.FromValue(1 / float32(t))).Categorical(-1) + return mlx.DivScalar(logits, float32(t)).Categorical(-1) } type TopP float32 diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go index 091e95839..974213196 100644 --- a/x/models/glm4_moe_lite/glm4_moe_lite.go +++ b/x/models/glm4_moe_lite/glm4_moe_lite.go @@ -5,21 +5,24 @@ package glm4_moe_lite import ( - "encoding/binary" "encoding/json" "fmt" - "io" "math" - "os" "strings" - "github.com/ollama/ollama/x/imagegen/manifest" "github.com/ollama/ollama/x/imagegen/tokenizer" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model" + "github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/models/nn" ) +func init() { + base.Register("Glm4MoeLiteForCausalLM", newModel) + base.Register("GLM4MoeLite", newModel) +} + // RopeScaling holds RoPE scaling configuration type RopeScaling struct { Factor float32 `json:"factor"` @@ -131,7 +134,6 @@ func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Con queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3) out := mlx.ScaledDotProductAttentionCausal(queries, keys, values, cfg.Scale, L > 1) - out = a.UnembedOut.Forward(out) out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim) @@ -386,44 +388,6 @@ func quantizationParams(quantization string) (groupSize, bits int, mode string) } } -// readBlobMetadata reads the __metadata__ from a safetensors blob header. 
-func readBlobMetadata(path string) (map[string]string, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - defer f.Close() - - var headerSize uint64 - if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil { - return nil, err - } - if headerSize > 1024*1024 { - return nil, fmt.Errorf("header too large: %d", headerSize) - } - - data := make([]byte, headerSize) - if _, err := io.ReadFull(f, data); err != nil { - return nil, err - } - - var header map[string]json.RawMessage - if err := json.Unmarshal(data, &header); err != nil { - return nil, err - } - - metaRaw, ok := header["__metadata__"] - if !ok { - return nil, nil - } - - var meta map[string]string - if err := json.Unmarshal(metaRaw, &meta); err != nil { - return nil, err - } - return meta, nil -} - // ExpertWeight holds a single expert's weight with optional quantization components. type ExpertWeight struct { Weight *mlx.Array @@ -569,9 +533,10 @@ func makeLinear(tensors map[string]*mlx.Array, path string, cfg *Config) nn.Line return nn.NewLinear(w, bias) } -// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage). -func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) { - configData, err := modelManifest.ReadConfig("config.json") +// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer, +// no weights loaded yet). Called by the registry via base.New(). +func newModel(root *model.Root) (base.Model, error) { + configData, err := root.Manifest.ReadConfig("config.json") if err != nil { return nil, fmt.Errorf("load config: %w", err) } @@ -584,66 +549,18 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) { cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim cfg.Scale = computeScale(&cfg) - // Load all tensors from manifest blobs into a flat map - allTensors := make(map[string]*mlx.Array) - seen := make(map[string]bool) // dedupe by digest - var quantType string - var quantGroupSize int - - for _, layer := range modelManifest.GetTensorLayers("") { - if seen[layer.Digest] { - continue - } - seen[layer.Digest] = true - blobPath := modelManifest.BlobPath(layer.Digest) - - // Read quantization metadata from first blob - if quantType == "" { - if meta, err := readBlobMetadata(blobPath); err == nil && meta != nil { - if qt := meta["quant_type"]; qt != "" { - quantType = strings.ToUpper(qt) - } - if gs := meta["group_size"]; gs != "" { - fmt.Sscanf(gs, "%d", &quantGroupSize) - } - } - } - - for name, arr := range mlx.Load(blobPath) { - // Map safetensors key naming to our naming convention - // Combined blobs use ".scale" and ".bias" suffixes - if strings.HasSuffix(name, ".scale") { - baseName := strings.TrimSuffix(name, ".scale") - allTensors[baseName+"_scale"] = arr - } else if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") { - // Check if this is a quantization bias or a regular bias - // by checking if there's a corresponding weight - baseName := strings.TrimSuffix(name, ".bias") - if _, hasScale := allTensors[baseName+"_scale"]; hasScale { - allTensors[baseName+"_qbias"] = arr - } else { - allTensors[name] = arr - } - } else { - allTensors[name] = arr - } - } - } - - // Set up quantization parameters - useQuantized := false - if quantType != "" { - _, cfg.QuantBits, cfg.QuantMode = quantizationParams(quantType) - if quantGroupSize > 0 { - cfg.QuantGroupSize = quantGroupSize + // Set up quantization parameters from pre-scanned metadata + if qt := root.QuantType(); qt != 
"" { + _, cfg.QuantBits, cfg.QuantMode = quantizationParams(qt) + if gs := root.GroupSize(); gs > 0 { + cfg.QuantGroupSize = gs } else { - cfg.QuantGroupSize, _, _ = quantizationParams(quantType) + cfg.QuantGroupSize, _, _ = quantizationParams(qt) } - useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits) } // Load tokenizer - tokData, err := modelManifest.ReadConfig("tokenizer.json") + tokData, err := root.Manifest.ReadConfig("tokenizer.json") if err != nil { return nil, fmt.Errorf("load tokenizer config: %w", err) } @@ -652,11 +569,11 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) { ConfigJSON: configData, } - if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil { + if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil { tokConfig.GenerationConfigJSON = genConfigData } - if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil { + if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil { tokConfig.TokenizerConfigJSON = tokConfigData } @@ -671,18 +588,28 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) { tok: tok, } + return m, nil +} + +// LoadWeights receives all tensors loaded from the manifest and assigns them +// to model fields. Handles MLA absorption, expert stacking, and quantized +// layer creation. +func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { + cfg := m.Config + useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits) + // Load embedding - if w := allTensors["model.embed_tokens.weight"]; w != nil { + if w := tensors["model.embed_tokens.weight"]; w != nil { m.EmbedTokens = nn.NewEmbedding(w) } // Load final norm - if w := allTensors["model.norm.weight"]; w != nil { + if w := tensors["model.norm.weight"]; w != nil { m.Norm = nn.NewRMSNorm(w, cfg.RMSNormEps) } // Load LM head - m.LMHead = makeLinear(allTensors, "lm_head", &cfg) + m.LMHead = makeLinear(tensors, "lm_head", cfg) // Load layers for i := int32(0); i < cfg.NumHiddenLayers; i++ { @@ -690,24 +617,24 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) { // Load attention (same for both block types) attn := &MLAAttention{} - attn.QAProj = makeLinear(allTensors, prefix+".self_attn.q_a_proj", &cfg) - if w := allTensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil { + attn.QAProj = makeLinear(tensors, prefix+".self_attn.q_a_proj", cfg) + if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil { attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) } - attn.QBProj = makeLinear(allTensors, prefix+".self_attn.q_b_proj", &cfg) - attn.KVAProjWithMQA = makeLinear(allTensors, prefix+".self_attn.kv_a_proj_with_mqa", &cfg) - if w := allTensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil { + attn.QBProj = makeLinear(tensors, prefix+".self_attn.q_b_proj", cfg) + attn.KVAProjWithMQA = makeLinear(tensors, prefix+".self_attn.kv_a_proj_with_mqa", cfg) + if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil { attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) } - attn.OProj = makeLinear(allTensors, prefix+".self_attn.o_proj", &cfg) + attn.OProj = makeLinear(tensors, prefix+".self_attn.o_proj", cfg) // Sanitize MLA weights for absorbed attention - embedQ, unembedOut := sanitizeMLAWeights(allTensors, prefix, &cfg) + embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg) attn.EmbedQ = nn.NewMultiLinear(embedQ) attn.UnembedOut = 
nn.NewMultiLinear(unembedOut) - inputLN := allTensors[prefix+".input_layernorm.weight"] - postAttnLN := allTensors[prefix+".post_attention_layernorm.weight"] + inputLN := tensors[prefix+".input_layernorm.weight"] + postAttnLN := tensors[prefix+".post_attention_layernorm.weight"] if i < cfg.FirstKDenseReplace { // Dense block @@ -720,9 +647,9 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) { } block.MLP = &DenseMLP{ - GateProj: makeLinear(allTensors, prefix+".mlp.gate_proj", &cfg), - UpProj: makeLinear(allTensors, prefix+".mlp.up_proj", &cfg), - DownProj: makeLinear(allTensors, prefix+".mlp.down_proj", &cfg), + GateProj: makeLinear(tensors, prefix+".mlp.gate_proj", cfg), + UpProj: makeLinear(tensors, prefix+".mlp.up_proj", cfg), + DownProj: makeLinear(tensors, prefix+".mlp.down_proj", cfg), } m.Layers[i] = block @@ -737,7 +664,7 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) { } // Stack expert weights - gate, up, down := sanitizeExpertWeights(allTensors, prefix, cfg.NRoutedExperts, useQuantized, &cfg) + gate, up, down := sanitizeExpertWeights(tensors, prefix, cfg.NRoutedExperts, useQuantized, cfg) switchMLP := &SwitchMLP{UseQuantized: useQuantized} if useQuantized { @@ -763,8 +690,8 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) { } moeGate := &MoEGate{} - moeGate.Gate = makeLinear(allTensors, prefix+".mlp.gate", &cfg) - if bias := allTensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil { + moeGate.Gate = makeLinear(tensors, prefix+".mlp.gate", cfg) + if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil { moeGate.EScoreCorrectionBias = bias } @@ -776,9 +703,9 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) { // Load shared experts if present if cfg.NSharedExperts > 0 { block.MoE.SharedExperts = &SharedExperts{ - GateProj: makeLinear(allTensors, prefix+".mlp.shared_experts.gate_proj", &cfg), - UpProj: makeLinear(allTensors, prefix+".mlp.shared_experts.up_proj", &cfg), - DownProj: makeLinear(allTensors, prefix+".mlp.shared_experts.down_proj", &cfg), + GateProj: makeLinear(tensors, prefix+".mlp.shared_experts.gate_proj", cfg), + UpProj: makeLinear(tensors, prefix+".mlp.shared_experts.up_proj", cfg), + DownProj: makeLinear(tensors, prefix+".mlp.shared_experts.down_proj", cfg), } } @@ -786,9 +713,10 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) { } } - mlx.Eval(mlx.Collect(m)...) + collected := mlx.Collect(m) + mlx.Eval(collected...) - return m, nil + return nil } // Forward computes the forward pass of the model
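
With the registry in place, adding a new architecture no longer requires touching the runner: a model package registers its constructor in init() and satisfies base.Model. Below is a minimal sketch of that contract; the package name "toymodel", the architecture string "ToyForCausalLM", and the stub bodies are illustrative assumptions standing in for real layer code:

//go:build mlx

package toymodel // hypothetical; real packages live under x/models/

import (
	"errors"

	"github.com/ollama/ollama/x/imagegen/tokenizer"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
)

func init() {
	// The key must match config.json's "architectures"[0]; duplicate
	// registrations panic.
	base.Register("ToyForCausalLM", newModel)
}

type Model struct {
	embed *mlx.Array
	tok   *tokenizer.Tokenizer
}

// newModel parses config/tokenizer from the manifest; weights arrive later
// via LoadWeights, after the runner has read every tensor blob.
func newModel(root *model.Root) (base.Model, error) {
	if _, err := root.Manifest.ReadConfig("config.json"); err != nil {
		return nil, err
	}
	// A real model would also build its tokenizer from tokenizer.json here.
	return &Model{}, nil
}

// LoadWeights assigns manifest tensors to model fields; model-specific
// post-processing (quantized layers, expert stacking) would happen here.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
	m.embed = tensors["model.embed_tokens.weight"]
	if m.embed == nil {
		return errors.New("missing model.embed_tokens.weight")
	}
	return nil
}

func (m *Model) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
	return inputs // stub: a real model runs its transformer layers here
}

func (m *Model) Unembed(x *mlx.Array) *mlx.Array { return x } // stub
func (m *Model) NumLayers() int                  { return 0 }
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }

The blank import in x/mlxrunner/imports.go is what pulls each registered package into the mlx build, so a real model package would add itself there the same way glm4_moe_lite does in this patch.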