mlxrunner fixes (#14247)

* load glm4_moe_lite from the mlxrunner
* fix loading diffusion models
* remove log lines
* fix --imagegen flag
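Per the last bullet, the --imagegen flag is plumbed through as a per-request option: the first hunk below reads a use_imagegen_runner key out of the request's options map, so a raw API client could presumably opt in the same way. A minimal sketch in Go; the endpoint and the surrounding request fields are assumptions based on the public ollama API, and only the use_imagegen_runner key comes from this patch:

    package main

    import (
    	"bytes"
    	"encoding/json"
    	"log"
    	"net/http"
    )

    func main() {
    	// Assumed request shape: /api/generate with an "options" map, as in
    	// the public ollama API. Only use_imagegen_runner is from this patch.
    	body, err := json.Marshal(map[string]any{
    		"model":  "glm4_moe_lite", // model named in the commit message
    		"prompt": "hello",
    		"options": map[string]any{
    			"use_imagegen_runner": true, // route to the imagegen runner
    		},
    	})
    	if err != nil {
    		log.Fatal(err)
    	}
    	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
    	if err != nil {
    		log.Fatal(err)
    	}
    	defer resp.Body.Close()
    }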
@@ -150,12 +150,15 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
 		return nil, nil, nil, fmt.Errorf("%s %w", name, err)
 	}
 
+	useImagegen, _ := requestOpts["use_imagegen_runner"].(bool)
+	delete(requestOpts, "use_imagegen_runner")
+
 	opts, err := s.modelOptions(model, requestOpts)
 	if err != nil {
 		return nil, nil, nil, err
 	}
 
-	runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
+	runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive, useImagegen)
 	var runner *runnerRef
 	select {
 	case runner = <-runnerCh:
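Note that the flag is deleted from requestOpts before s.modelOptions runs: use_imagegen_runner is a scheduler directive, not a model option, so it is stripped first, presumably to keep the option parser from seeing a key that is not a real model option. The extract-and-strip pattern in isolation (popBoolOption is a hypothetical helper, not part of the patch, which inlines the two lines instead):

    // popBoolOption reads a transport-only bool out of a generic options map
    // and removes it, so downstream option parsing never sees the key.
    func popBoolOption(opts map[string]any, key string) bool {
    	v, _ := opts[key].(bool) // a missing key or a non-bool value both yield false
    	delete(opts, key)
    	return v
    }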
@@ -2383,6 +2383,7 @@ func TestImageGenerateStreamFalse(t *testing.T) {
 		llama: &mock,
 		Options: &opts,
 		model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
+		isImagegen: true,
 		numParallel: 1,
 	},
 },
@@ -22,6 +22,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/x/imagegen"
+	"github.com/ollama/ollama/x/mlxrunner"
 )
 
 type LlmRequest struct {
@@ -32,6 +33,7 @@ type LlmRequest struct {
 	successCh     chan *runnerRef
 	errCh         chan error
 	schedAttempts uint
+	useImagegen   bool
 }
 
 type Scheduler struct {
@@ -82,7 +84,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
 }
 
 // context must be canceled to decrement ref count and release the runner
-func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
+func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
 	if opts.NumCtx < 4 {
 		opts.NumCtx = 4
 	}
@@ -99,6 +101,7 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
 		sessionDuration: sessionDuration,
 		successCh:       make(chan *runnerRef, 1),
 		errCh:           make(chan error, 1),
+		useImagegen:     useImagegen,
 	}
 
 	s.loadedMu.Lock()
@@ -566,17 +569,20 @@ iGPUScan:
 // loadMLX loads an experimental safetensors model using the unified MLX runner.
 // This supports both LLM (completion) and image generation models.
 func (s *Scheduler) loadMLX(req *LlmRequest) bool {
-	// Determine mode based on capabilities
-	var mode imagegen.ModelMode
-	if slices.Contains(req.model.Config.Capabilities, "image") {
-		mode = imagegen.ModeImageGen
-	} else {
-		mode = imagegen.ModeLLM
-	}
-
 	// Use model name for MLX (it resolves manifests by name, not file path)
 	modelName := req.model.ShortName
-	server, err := imagegen.NewServer(modelName, mode)
+	var server llm.LlamaServer
+	var err error
+	isImagegen := false
+	if slices.Contains(req.model.Config.Capabilities, "image") {
+		server, err = imagegen.NewServer(modelName, imagegen.ModeImageGen)
+		isImagegen = true
+	} else if req.useImagegen {
+		server, err = imagegen.NewServer(modelName, imagegen.ModeLLM)
+		isImagegen = true
+	} else {
+		server, err = mlxrunner.NewClient(modelName)
+	}
 	if err != nil {
 		req.errCh <- err
 		return true
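The old code only picked between imagegen's two modes; the new code is a three-way dispatch that can also hand the model to mlxrunner. Restated as a standalone function for clarity (illustrative only, not code from the patch):

    package main

    import (
    	"fmt"
    	"slices"
    )

    // pickRunner mirrors loadMLX's new dispatch: "image"-capable models always
    // go to the imagegen runner, text models go there too when the request
    // opts in via use_imagegen_runner, and everything else uses mlxrunner.
    func pickRunner(caps []string, useImagegen bool) (backend string, isImagegen bool) {
    	switch {
    	case slices.Contains(caps, "image"):
    		return "imagegen (ModeImageGen)", true
    	case useImagegen:
    		return "imagegen (ModeLLM)", true
    	default:
    		return "mlxrunner", false
    	}
    }

    func main() {
    	fmt.Println(pickRunner([]string{"image"}, false)) // imagegen (ModeImageGen) true
    	fmt.Println(pickRunner(nil, true))                // imagegen (ModeLLM) true
    	fmt.Println(pickRunner(nil, false))               // mlxrunner false
    }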
@@ -593,6 +599,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
 		llama:           server,
 		Options:         &req.opts,
 		loading:         false,
+		isImagegen:      isImagegen,
 		sessionDuration: sessionDuration,
 		totalSize:       server.TotalSize(),
 		vramSize:        server.VRAMSize(),
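The var server llm.LlamaServer declaration in the previous hunk is what makes this construction work: both imagegen.NewServer and mlxrunner.NewClient must return values satisfying that interface, so the runnerRef above can call the same methods on either backend. A sketch of just the subset visible in this diff (the name modelSizer is invented here; the real llm.LlamaServer interface has many more methods):

    // Illustrative subset only: the two methods loadMLX is seen calling here.
    // Return types inferred from runnerRef's vramSize/totalSize uint64 fields.
    type modelSizer interface {
    	TotalSize() uint64 // total memory attributed to the loaded model
    	VRAMSize() uint64  // portion of that resident in GPU memory
    }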
@@ -667,6 +674,7 @@ type runnerRef struct {
 	loading      bool          // True only during initial load, then false forever
 	gpus         []ml.DeviceID // Recorded at time of provisioning
 	discreteGPUs bool          // True if all devices are discrete GPUs - used to skip VRAM recovery check for iGPUs
+	isImagegen   bool          // True if loaded via imagegen runner (vs mlxrunner)
 	vramSize     uint64
 	totalSize    uint64
 
@@ -699,6 +707,12 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
 	runner.refMu.Lock()
 	defer runner.refMu.Unlock()
 
+	// Check if runner type (imagegen vs mlxrunner) matches what's requested
+	wantImagegen := req.useImagegen || slices.Contains(req.model.Config.Capabilities, "image")
+	if runner.isImagegen != wantImagegen {
+		return true
+	}
+
 	timeout := 10 * time.Second
 	if runner.loading {
 		timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
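This check is what makes the flag effective against an already-loaded model: without it, a text model already served by mlxrunner would be reused even when a request asked for the imagegen runner, and vice versa. The new clause in isolation (an illustrative restatement, not code from the patch):

    // runnerTypeMismatch mirrors the clause added to needsReload: force a
    // reload whenever the loaded runner type differs from what the request
    // requires.
    func runnerTypeMismatch(loadedIsImagegen bool, caps []string, useImagegen bool) bool {
    	wantImagegen := useImagegen || slices.Contains(caps, "image")
    	return loadedIsImagegen != wantImagegen
    }

For example, runnerTypeMismatch(false, nil, true) is true: a text model loaded on plain mlxrunner gets swapped out when --imagegen is requested.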
@@ -408,10 +408,10 @@ func TestSchedGetRunner(t *testing.T) {
 	s.getSystemInfoFn = getSystemInfoFn
 	s.newServerFn = a.newServer
 	slog.Info("a")
-	successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
+	successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration, false)
 	require.Len(t, s.pendingReqCh, 1)
 	slog.Info("b")
-	successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
+	successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration, false)
 	require.Len(t, s.pendingReqCh, 1)
 	require.Empty(t, successCh1b)
 	require.Len(t, errCh1b, 1)
@@ -435,7 +435,7 @@ func TestSchedGetRunner(t *testing.T) {
 
 	c.req.model.ModelPath = "bad path"
 	slog.Info("c")
-	successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
+	successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration, false)
 	// Starts in pending channel, then should be quickly processed to return an error
 	time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload
 	require.Empty(t, successCh1c)
@@ -509,7 +509,7 @@ func TestSchedPrematureExpired(t *testing.T) {
 	s.getGpuFn = getGpuFn
 	s.getSystemInfoFn = getSystemInfoFn
 	s.newServerFn = scenario1a.newServer
-	successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
+	successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration, false)
 	require.Len(t, s.pendingReqCh, 1)
 	s.Run(ctx)
 	select {