diff --git a/cmd/cmd.go b/cmd/cmd.go
index 2301efdfe..92e62ca7f 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -182,6 +182,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			mfConfig.System = cmd.Args
 		case "license":
 			mfConfig.License = cmd.Args
+		case "parser":
+			mfConfig.Parser = cmd.Args
+		case "renderer":
+			mfConfig.Renderer = cmd.Args
 		}
 	}
 
diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go
index 2b471d1da..fa9f8b598 100644
--- a/model/parsers/parsers.go
+++ b/model/parsers/parsers.go
@@ -45,6 +45,10 @@ func ParserForName(name string) Parser {
 	var p Parser
 	switch name {
+	case "qwen3":
+		p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
+	case "qwen3-thinking":
+		p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
 	case "qwen3-coder":
 		p = &Qwen3CoderParser{}
 	case "qwen3-vl-instruct":
diff --git a/model/parsers/parsers_test.go b/model/parsers/parsers_test.go
index 4f8566de3..15c2f664f 100644
--- a/model/parsers/parsers_test.go
+++ b/model/parsers/parsers_test.go
@@ -54,6 +54,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
 		name string
 	}{
 		{"passthrough"},
+		{"qwen3"},
+		{"qwen3-thinking"},
 		{"qwen3-coder"},
 		{"harmony"},
 	}
diff --git a/model/parsers/qwen3.go b/model/parsers/qwen3.go
new file mode 100644
index 000000000..e49111fb5
--- /dev/null
+++ b/model/parsers/qwen3.go
@@ -0,0 +1,335 @@
+package parsers
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"log/slog"
+	"strings"
+	"unicode"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/logutil"
+)
+
+type qwen3ParserState int
+
+const (
+	qwen3ParserStateLookingForThinkingOpen qwen3ParserState = iota
+	qwen3ParserStateThinkingStartedEatingWhitespace
+	qwen3ParserStateCollectingThinking
+	qwen3ParserStateThinkingDoneEatingWhitespace
+	qwen3ParserStateCollectingContent
+	qwen3ParserStateToolStartedEatingWhitespace
+	qwen3ParserStateCollectingToolContent
+)
+
+const (
+	// NOTE(review): these tag literals were empty ("") in the garbled source;
+	// restored to the standard Qwen3 special tokens — confirm against the
+	// model's chat template.
+	qwen3ThinkingOpenTag  = "<think>"
+	qwen3ThinkingCloseTag = "</think>"
+	qwen3ToolOpenTag  = "<tool_call>"
+	qwen3ToolCloseTag = "</tool_call>"
+)
+
+// Qwen3Parser parses Qwen3 output to extract thinking and tool calls.
+// Qwen3 prompts end with <think> when thinking is enabled, so output begins
+// with thinking content directly (without an opening tag).
+type Qwen3Parser struct {
+	state                  qwen3ParserState
+	buffer                 strings.Builder
+	tools                  []api.Tool
+	hasThinkingSupport     bool
+	defaultThinking        bool
+	maybeThinkingOpenAtBOL bool
+}
+
+func (p *Qwen3Parser) HasToolSupport() bool {
+	return true
+}
+
+func (p *Qwen3Parser) HasThinkingSupport() bool {
+	return p.hasThinkingSupport
+}
+
+func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
+	p.tools = tools
+	p.buffer.Reset()
+
+	thinkingEnabled := thinkValue != nil && thinkValue.Bool()
+	if thinkValue == nil {
+		thinkingEnabled = p.defaultThinking
+	}
+
+	if p.hasThinkingSupport && thinkingEnabled {
+		p.state = qwen3ParserStateCollectingThinking
+		p.maybeThinkingOpenAtBOL = true
+	} else {
+		p.state = qwen3ParserStateCollectingContent
+		p.maybeThinkingOpenAtBOL = false
+	}
+	return tools
+}
+
+type qwen3Event interface {
+	isQwen3Event()
+}
+
+type qwen3EventContent struct {
+	content string
+}
+
+func (qwen3EventContent) isQwen3Event() {}
+
+type qwen3EventRawToolCall struct {
+	raw string
+}
+
+func (qwen3EventRawToolCall) isQwen3Event() {}
+
+type qwen3EventThinkingContent struct {
+	content string
+}
+
+func (qwen3EventThinkingContent) isQwen3Event() {}
+
+func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
+	p.buffer.WriteString(s)
+	events := p.parseEvents()
+
+	var contentSb strings.Builder
+	var thinkingSb strings.Builder
+	for _, event := range events {
+		switch event := event.(type) {
+		case qwen3EventRawToolCall:
+			toolCall, err := parseQwen3ToolCall(event, p.tools)
+			if err != nil {
+				slog.Warn("qwen3 tool call parsing failed", "error", err)
+				return "", "", nil, err
+			}
+			calls =
append(calls, toolCall) + case qwen3EventThinkingContent: + thinkingSb.WriteString(event.content) + case qwen3EventContent: + contentSb.WriteString(event.content) + } + } + + return contentSb.String(), thinkingSb.String(), calls, nil +} + +func (p *Qwen3Parser) parseEvents() []qwen3Event { + var all []qwen3Event + + keepLooping := true + for keepLooping { + var events []qwen3Event + events, keepLooping = p.eat() + if len(events) > 0 { + all = append(all, events...) + } + } + + if len(all) > 0 { + slog.Log(context.TODO(), logutil.LevelTrace, "qwen3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String()) + } + + return all +} + +func (p *Qwen3Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen3ParserState) ([]qwen3Event, bool) { + trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace) + p.buffer.Reset() + if trimmed == "" { + return nil, false + } + p.state = nextState + p.buffer.WriteString(trimmed) + return nil, true +} + +func (p *Qwen3Parser) splitAtTag(tag string, trimAfter bool) (string, string) { + return splitAtTag(&p.buffer, tag, trimAfter) +} + +func (p *Qwen3Parser) eat() ([]qwen3Event, bool) { + var events []qwen3Event + + switch p.state { + case qwen3ParserStateLookingForThinkingOpen: + trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace) + if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) { + after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag) + after = strings.TrimLeftFunc(after, unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(after) + if after == "" { + p.state = qwen3ParserStateThinkingStartedEatingWhitespace + } else { + p.state = qwen3ParserStateCollectingThinking + } + return events, true + } else if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) { + return events, false + } else if trimmed == "" { + return events, false + } + p.state = qwen3ParserStateCollectingContent + return events, true + + case qwen3ParserStateThinkingStartedEatingWhitespace: + return 
p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingThinking) + + case qwen3ParserStateCollectingThinking: + acc := p.buffer.String() + + // Some qwen3 checkpoints emit an explicit opening tag even + // though the prompt already ended with . Strip exactly one + // leading opening tag if present. + if p.maybeThinkingOpenAtBOL { + trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace) + if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) { + after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag) + after = strings.TrimLeftFunc(after, unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(after) + if after == "" { + return events, false + } + p.maybeThinkingOpenAtBOL = false + return events, true + } + if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) { + return events, false + } + p.maybeThinkingOpenAtBOL = false + } + + if strings.Contains(acc, qwen3ThinkingCloseTag) { + thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true) + if len(thinking) > 0 { + events = append(events, qwen3EventThinkingContent{content: thinking}) + } + if remaining == "" { + p.state = qwen3ParserStateThinkingDoneEatingWhitespace + } else { + p.state = qwen3ParserStateCollectingContent + } + return events, true + } else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 { + beforePartialTag := acc[:len(acc)-overlapLen] + trailingWsLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWsLen + + unambiguous := acc[:ambiguousStart] + ambiguous := acc[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, qwen3EventThinkingContent{content: unambiguous}) + } + return events, false + } + + whitespaceLen := trailingWhitespaceLen(acc) + ambiguousStart := len(acc) - whitespaceLen + unambiguous := acc[:ambiguousStart] + ambiguous := acc[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = 
append(events, qwen3EventThinkingContent{content: unambiguous}) + } + return events, false + + case qwen3ParserStateThinkingDoneEatingWhitespace: + return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingContent) + + case qwen3ParserStateCollectingContent: + acc := p.buffer.String() + if strings.Contains(acc, qwen3ToolOpenTag) { + before, after := p.splitAtTag(qwen3ToolOpenTag, true) + if len(before) > 0 { + events = append(events, qwen3EventContent{content: before}) + } + if after == "" { + p.state = qwen3ParserStateToolStartedEatingWhitespace + } else { + p.state = qwen3ParserStateCollectingToolContent + } + return events, true + } else if overlapLen := overlap(acc, qwen3ToolOpenTag); overlapLen > 0 { + beforePartialTag := acc[:len(acc)-overlapLen] + trailingWsLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWsLen + + unambiguous := acc[:ambiguousStart] + ambiguous := acc[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, qwen3EventContent{content: unambiguous}) + } + return events, false + } + + whitespaceLen := trailingWhitespaceLen(acc) + ambiguousStart := len(acc) - whitespaceLen + unambiguous := acc[:ambiguousStart] + ambiguous := acc[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, qwen3EventContent{content: unambiguous}) + } + return events, false + + case qwen3ParserStateToolStartedEatingWhitespace: + return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingToolContent) + + case qwen3ParserStateCollectingToolContent: + acc := p.buffer.String() + if strings.Contains(acc, qwen3ToolCloseTag) { + toolContent, _ := p.splitAtTag(qwen3ToolCloseTag, true) + if len(toolContent) == 0 { + slog.Warn("qwen3 tool call closing tag found but no content before it") + } + events = append(events, qwen3EventRawToolCall{raw: toolContent}) + p.state = 
qwen3ParserStateCollectingContent
+			return events, true
+		}
+		return events, false
+
+	default:
+		panic("unreachable")
+	}
+}
+
+func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
+	var parsed struct {
+		Name      string         `json:"name"`
+		Arguments map[string]any `json:"arguments"`
+	}
+
+	if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil {
+		return api.ToolCall{}, fmt.Errorf("failed to parse JSON: %w", err)
+	}
+
+	if parsed.Name == "" {
+		return api.ToolCall{}, fmt.Errorf("empty function name")
+	}
+
+	_ = tools // qwen3 uses direct JSON args and does not require schema coercion here.
+
+	toolCall := api.ToolCall{
+		Function: api.ToolCallFunction{
+			Name:      parsed.Name,
+			Arguments: api.NewToolCallFunctionArguments(),
+		},
+	}
+
+	for key, value := range parsed.Arguments {
+		toolCall.Function.Arguments.Set(key, value)
+	}
+
+	return toolCall, nil
+}
diff --git a/model/parsers/qwen3_test.go b/model/parsers/qwen3_test.go
new file mode 100644
index 000000000..853874ded
--- /dev/null
+++ b/model/parsers/qwen3_test.go
@@ -0,0 +1,147 @@
+package parsers
+
+import (
+	"testing"
+
+	"github.com/ollama/ollama/api"
+)
+
+func TestQwen3ParserThinkingEnabled(t *testing.T) {
+	parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
+	parser.Init(nil, nil, &api.ThinkValue{Value: true})
+
+	content, thinking, calls, err := parser.Add("Let me think...</think>Answer.", true)
+	if err != nil {
+		t.Fatalf("parse failed: %v", err)
+	}
+
+	if thinking != "Let me think..." {
+		t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
+	}
+	if content != "Answer." {
+		t.Fatalf("expected content %q, got %q", "Answer.", content)
+	}
+	if len(calls) != 0 {
+		t.Fatalf("expected no tool calls, got %d", len(calls))
+	}
+}
+
+func TestQwen3ParserThinkingEnabledWithExplicitOpeningTag(t *testing.T) {
+	parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
+	parser.Init(nil, nil, &api.ThinkValue{Value: true})
+
+	content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
+	if err != nil {
+		t.Fatalf("parse failed: %v", err)
+	}
+
+	if thinking != "Let me think..." {
+		t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
+	}
+	if content != "Answer." {
+		t.Fatalf("expected content %q, got %q", "Answer.", content)
+	}
+	if len(calls) != 0 {
+		t.Fatalf("expected no tool calls, got %d", len(calls))
+	}
+}
+
+func TestQwen3ParserThinkingEnabledWithSplitOpeningTag(t *testing.T) {
+	parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
+	parser.Init(nil, nil, &api.ThinkValue{Value: true})
+
+	// NOTE(review): the first streamed chunk was lost in extraction; reconstructed
+	// so the optional opening tag arrives split across two Add calls — verify
+	// against the original patch.
+	if _, _, _, err := parser.Add("<thi", false); err != nil {
+		t.Fatalf("parse failed on first chunk: %v", err)
+	}
+	content, thinking, calls, err := parser.Add("nk>Let me think...</think>Answer.", true)
+	if err != nil {
+		t.Fatalf("parse failed on second chunk: %v", err)
+	}
+	if thinking != "Let me think..." {
+		t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
+	}
+	if content != "Answer." {
+		t.Fatalf("expected content %q, got %q", "Answer.", content)
+	}
+	if len(calls) != 0 {
+		t.Fatalf("expected no tool calls, got %d", len(calls))
+	}
+}
+
+func TestQwen3ParserThinkingDisabled(t *testing.T) {
+	parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
+	parser.Init(nil, nil, &api.ThinkValue{Value: false})
+
+	content, thinking, calls, err := parser.Add("Direct answer", true)
+	if err != nil {
+		t.Fatalf("parse failed: %v", err)
+	}
+
+	if thinking != "" {
+		t.Fatalf("expected no thinking, got %q", thinking)
+	}
+	if content != "Direct answer" {
+		t.Fatalf("expected content %q, got %q", "Direct answer", content)
+	}
+	if len(calls) != 0 {
+		t.Fatalf("expected no tool calls, got %d", len(calls))
+	}
+}
+
+func TestQwen3ParserNilThinkDefaultsToContentForInstructParser(t *testing.T) {
+	parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
+	parser.Init(nil, nil, nil)
+
+	content, thinking, calls, err := parser.Add("Direct answer", true)
+	if err != nil {
+		t.Fatalf("parse failed: %v", err)
+	}
+
+	if thinking != "" {
+		t.Fatalf("expected no thinking, got %q", thinking)
+	}
+	if content != "Direct answer" {
+		t.Fatalf("expected content %q, got %q", "Direct answer", content)
+	}
+	if len(calls) != 0 {
+		t.Fatalf("expected no tool calls, got %d", len(calls))
+	}
+}
+
+func TestQwen3ParserToolCall(t *testing.T) {
+	parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
+	parser.Init(nil, nil, &api.ThinkValue{Value: false})
+
+	input := "<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
+	content, thinking, calls, err := parser.Add(input, true)
+	if err != nil {
+		t.Fatalf("parse failed: %v", err)
+	}
+
+	if content != "" {
+		t.Fatalf("expected empty content, got %q", content)
+	}
+	if thinking != "" {
+		t.Fatalf("expected empty thinking, got %q", thinking)
+	}
+	if len(calls) != 1 {
+		t.Fatalf("expected 1 tool call, got %d", len(calls))
+	}
+	if 
calls[0].Function.Name != "get_weather" { + t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name) + } + + location, ok := calls[0].Function.Arguments.Get("location") + if !ok || location != "San Francisco" { + t.Fatalf("expected location %q, got %v", "San Francisco", location) + } + unit, ok := calls[0].Function.Arguments.Get("unit") + if !ok || unit != "celsius" { + t.Fatalf("expected unit %q, got %v", "celsius", unit) + } +} diff --git a/x/create/client/create.go b/x/create/client/create.go index f89f9fc98..b8062f3d4 100644 --- a/x/create/client/create.go +++ b/x/create/client/create.go @@ -30,6 +30,8 @@ type ModelfileConfig struct { Template string System string License string + Parser string + Renderer string } // CreateOptions holds all options for model creation. @@ -37,7 +39,7 @@ type CreateOptions struct { ModelName string ModelDir string Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization - Modelfile *ModelfileConfig // template/system/license from Modelfile + Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile } // CreateModel imports a model from a local directory. 
@@ -267,8 +269,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re ModelFormat: "safetensors", Capabilities: caps, Requires: MinOllamaVersion, - Parser: parserName, - Renderer: rendererName, + Parser: resolveParserName(opts.Modelfile, parserName), + Renderer: resolveRendererName(opts.Modelfile, rendererName), } configJSON, err := json.Marshal(configData) if err != nil { @@ -305,6 +307,22 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re } } +func resolveParserName(mf *ModelfileConfig, inferred string) string { + if mf != nil && mf.Parser != "" { + return mf.Parser + } + + return inferred +} + +func resolveRendererName(mf *ModelfileConfig, inferred string) string { + if mf != nil && mf.Renderer != "" { + return mf.Renderer + } + + return inferred +} + // createModelfileLayers creates layers for template, system, and license from Modelfile config. func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) { var layers []manifest.Layer @@ -410,7 +428,7 @@ func getParserName(modelDir string) string { return "deepseek3" } if strings.Contains(archLower, "qwen3") { - return "qwen3-coder" + return "qwen3" } } @@ -424,7 +442,7 @@ func getParserName(modelDir string) string { return "deepseek3" } if strings.Contains(typeLower, "qwen3") { - return "qwen3-coder" + return "qwen3" } } diff --git a/x/create/client/create_test.go b/x/create/client/create_test.go index b41807279..1e7062237 100644 --- a/x/create/client/create_test.go +++ b/x/create/client/create_test.go @@ -10,6 +10,8 @@ func TestModelfileConfig(t *testing.T) { Template: "{{ .Prompt }}", System: "You are a helpful assistant.", License: "MIT", + Parser: "qwen3", + Renderer: "qwen3", } if config.Template != "{{ .Prompt }}" { @@ -21,6 +23,12 @@ func TestModelfileConfig(t *testing.T) { if config.License != "MIT" { t.Errorf("License = %q, want %q", config.License, "MIT") } + if config.Parser != "qwen3" { + t.Errorf("Parser = %q, want %q", 
config.Parser, "qwen3") + } + if config.Renderer != "qwen3" { + t.Errorf("Renderer = %q, want %q", config.Renderer, "qwen3") + } } func TestModelfileConfig_Empty(t *testing.T) { @@ -35,6 +43,12 @@ func TestModelfileConfig_Empty(t *testing.T) { if config.License != "" { t.Errorf("License should be empty, got %q", config.License) } + if config.Parser != "" { + t.Errorf("Parser should be empty, got %q", config.Parser) + } + if config.Renderer != "" { + t.Errorf("Renderer should be empty, got %q", config.Renderer) + } } func TestModelfileConfig_PartialFields(t *testing.T) { @@ -53,6 +67,12 @@ func TestModelfileConfig_PartialFields(t *testing.T) { if config.License != "" { t.Error("License should be empty") } + if config.Parser != "" { + t.Error("Parser should be empty") + } + if config.Renderer != "" { + t.Error("Renderer should be empty") + } } func TestMinOllamaVersion(t *testing.T) { @@ -98,6 +118,8 @@ func TestCreateOptions(t *testing.T) { Template: "test", System: "system", License: "MIT", + Parser: "qwen3-thinking", + Renderer: "qwen3", }, } @@ -116,6 +138,92 @@ func TestCreateOptions(t *testing.T) { if opts.Modelfile.Template != "test" { t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test") } + if opts.Modelfile.Parser != "qwen3-thinking" { + t.Errorf("Modelfile.Parser = %q, want %q", opts.Modelfile.Parser, "qwen3-thinking") + } + if opts.Modelfile.Renderer != "qwen3" { + t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3") + } +} + +func TestResolveParserName(t *testing.T) { + tests := []struct { + name string + mf *ModelfileConfig + inferred string + want string + }{ + { + name: "nil modelfile uses inferred", + mf: nil, + inferred: "qwen3", + want: "qwen3", + }, + { + name: "empty parser uses inferred", + mf: &ModelfileConfig{ + Parser: "", + }, + inferred: "qwen3", + want: "qwen3", + }, + { + name: "explicit parser overrides inferred", + mf: &ModelfileConfig{ + Parser: "qwen3-thinking", + }, + inferred: 
"qwen3", + want: "qwen3-thinking", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := resolveParserName(tt.mf, tt.inferred); got != tt.want { + t.Fatalf("resolveParserName() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestResolveRendererName(t *testing.T) { + tests := []struct { + name string + mf *ModelfileConfig + inferred string + want string + }{ + { + name: "nil modelfile uses inferred", + mf: nil, + inferred: "qwen3-coder", + want: "qwen3-coder", + }, + { + name: "empty renderer uses inferred", + mf: &ModelfileConfig{ + Renderer: "", + }, + inferred: "qwen3-coder", + want: "qwen3-coder", + }, + { + name: "explicit renderer overrides inferred", + mf: &ModelfileConfig{ + Renderer: "qwen3", + }, + inferred: "qwen3-coder", + want: "qwen3", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := resolveRendererName(tt.mf, tt.inferred); got != tt.want { + t.Fatalf("resolveRendererName() = %q, want %q", got, tt.want) + } + }) + } } func TestCreateOptions_Defaults(t *testing.T) { diff --git a/x/mlxrunner/imports.go b/x/mlxrunner/imports.go index a8111b056..5e0cfe86d 100644 --- a/x/mlxrunner/imports.go +++ b/x/mlxrunner/imports.go @@ -6,4 +6,5 @@ import ( _ "github.com/ollama/ollama/x/models/gemma3" _ "github.com/ollama/ollama/x/models/glm4_moe_lite" _ "github.com/ollama/ollama/x/models/llama" + _ "github.com/ollama/ollama/x/models/qwen3" ) diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index cd4d78620..0da5862c8 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -18,7 +18,15 @@ func (r *Runner) TextGenerationPipeline(request Request) error { return errors.New("model not loaded") } - mlx.EnableCompile() + enableCompile := true + if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok { + enableCompile = modelCompile.EnableCompile() + } + if enableCompile { + mlx.EnableCompile() + } else { + mlx.DisableCompile() + } inputs := 
r.Tokenizer.Encode(request.Prompt, true) diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go new file mode 100644 index 000000000..7a49cf37a --- /dev/null +++ b/x/models/qwen3/qwen3.go @@ -0,0 +1,338 @@ +//go:build mlx + +// Package qwen3 provides the Qwen3 text model implementation for MLX. +package qwen3 + +import ( + "encoding/json" + "fmt" + "math" + + "github.com/ollama/ollama/x/imagegen/tokenizer" + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model" + "github.com/ollama/ollama/x/mlxrunner/model/base" + "github.com/ollama/ollama/x/models/nn" +) + +func init() { + base.Register("Qwen3ForCausalLM", newModel) +} + +// Config holds Qwen3 model configuration. +type Config struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + HeadDim int32 `json:"head_dim"` + MaxPositionEmbeddings int32 `json:"max_position_embeddings"` + TieWordEmbeddings bool `json:"tie_word_embeddings"` + + // Quantization parameters (set during load based on model quantization). + QuantGroupSize int `json:"-"` + QuantBits int `json:"-"` + QuantMode string `json:"-"` + TensorQuant map[string]*model.TensorQuantInfo `json:"-"` + + // Computed fields. + Scale float32 `json:"-"` + QKNormEps float32 `json:"-"` +} + +// Model is the Qwen3 text-only model. +type Model struct { + EmbedTokens *nn.Embedding + Layers []*Layer + Norm *nn.RMSNorm + LMHead nn.LinearLayer + + tok *tokenizer.Tokenizer + *Config + + weightPrefix string +} + +// Layer is a single Qwen3 decoder block. 
+type Layer struct { + Attention *Attention + MLP *MLP + AttentionNorm *nn.RMSNorm + MLPNorm *nn.RMSNorm +} + +// Attention implements Qwen3 attention with Q/K norms. +type Attention struct { + QProj nn.LinearLayer + KProj nn.LinearLayer + VProj nn.LinearLayer + OProj nn.LinearLayer + QNorm *nn.RMSNorm + KNorm *nn.RMSNorm +} + +// MLP is the feed-forward network with SwiGLU activation. +type MLP struct { + GateProj nn.LinearLayer + UpProj nn.LinearLayer + DownProj nn.LinearLayer +} + +func resolveWeightPrefix(tensors map[string]*mlx.Array) string { + for _, prefix := range []string{"", "language_model."} { + if tensors[prefix+"model.embed_tokens.weight"] != nil { + return prefix + } + } + return "" +} + +func newModel(root *model.Root) (base.Model, error) { + configData, err := root.Manifest.ReadConfig("config.json") + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } + + var cfg Config + if err := json.Unmarshal(configData, &cfg); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + + if cfg.HiddenSize <= 0 { + return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize) + } + if cfg.NumAttentionHeads <= 0 { + return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads) + } + if cfg.NumKeyValueHeads <= 0 { + cfg.NumKeyValueHeads = cfg.NumAttentionHeads + } + if cfg.HeadDim == 0 { + if cfg.HiddenSize%cfg.NumAttentionHeads != 0 { + return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads) + } + cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads + } + if cfg.HeadDim <= 0 { + return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim) + } + if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 { + return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads) + } + if cfg.RMSNormEps == 0 { + cfg.RMSNormEps = 1e-6 + } + if cfg.RopeTheta == 0 { + cfg.RopeTheta = 
1000000 + } + cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) + cfg.QKNormEps = 1e-6 + + if qt := root.QuantType(); qt != "" { + cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt) + if gs := root.GroupSize(); gs > 0 { + cfg.QuantGroupSize = gs + } + } else { + cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("") + } + cfg.TensorQuant = root.AllTensorQuant() + + tokData, err := root.Manifest.ReadConfig("tokenizer.json") + if err != nil { + return nil, fmt.Errorf("load tokenizer config: %w", err) + } + + tokConfig := &tokenizer.TokenizerConfig{ + ConfigJSON: configData, + } + if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil { + tokConfig.GenerationConfigJSON = genConfigData + } + if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil { + tokConfig.TokenizerConfigJSON = tokConfigData + } + + tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig) + if err != nil { + return nil, fmt.Errorf("parse tokenizer: %w", err) + } + + m := &Model{ + Layers: make([]*Layer, cfg.NumHiddenLayers), + Config: &cfg, + tok: tok, + } + + return m, nil +} + +// LoadWeights receives all tensors loaded from the manifest and assigns them +// to model fields. 
+func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { + m.weightPrefix = resolveWeightPrefix(tensors) + prefix := m.weightPrefix + linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) + + embedWeight := tensors[prefix+"model.embed_tokens.weight"] + if embedWeight == nil { + return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix) + } + m.EmbedTokens = nn.NewEmbedding(embedWeight) + + normWeight := tensors[prefix+"model.norm.weight"] + if normWeight == nil { + return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix) + } + m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps) + + if m.TieWordEmbeddings { + m.LMHead = nn.NewLinear(embedWeight, nil) + } else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil { + m.LMHead = lmHead + } else if lmHead := linears.Make("lm_head"); lmHead != nil { + m.LMHead = lmHead + } else { + // Qwen3 checkpoints commonly tie output projection to embeddings. 
+ m.LMHead = nn.NewLinear(embedWeight, nil) + } + + for i := int32(0); i < m.NumHiddenLayers; i++ { + layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i) + + layer := &Layer{ + Attention: &Attention{}, + MLP: &MLP{}, + } + + if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil { + layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil { + layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + + layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj") + layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj") + layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj") + layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj") + + if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil { + layer.Attention.QNorm = nn.NewRMSNorm(w, m.QKNormEps) + } + if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil { + layer.Attention.KNorm = nn.NewRMSNorm(w, m.QKNormEps) + } + + layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj") + layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj") + layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj") + + if layer.AttentionNorm == nil { + return fmt.Errorf("layer %d: missing input_layernorm", i) + } + if layer.MLPNorm == nil { + return fmt.Errorf("layer %d: missing post_attention_layernorm", i) + } + if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil { + return fmt.Errorf("layer %d: missing attention projections", i) + } + if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil { + return fmt.Errorf("layer %d: missing attention q/k norms", i) + } + if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil { + return fmt.Errorf("layer %d: missing mlp projections", i) + } + + m.Layers[i] = layer + } + + collected := 
mlx.Collect(m) + mlx.Eval(collected...) + + return nil +} + +func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + dims := tokens.Dims() + B, L := int32(dims[0]), int32(dims[1]) + + h := m.EmbedTokens.Forward(tokens) + for i, layer := range m.Layers { + var c cache.Cache + if caches != nil && i < len(caches) { + c = caches[i] + } + h = layer.Forward(h, c, B, L, m.Config) + } + + return m.Norm.Forward(h, m.RMSNormEps) +} + +func (m *Model) Unembed(x *mlx.Array) *mlx.Array { + return m.LMHead.Forward(x) +} + +func (m *Model) NumLayers() int { + return len(m.Layers) +} + +func (m *Model) Tokenizer() *tokenizer.Tokenizer { + return m.tok +} + +func (m *Model) NewCaches() []cache.Cache { + caches := make([]cache.Cache, len(m.Layers)) + for i := range caches { + caches[i] = cache.NewKVCache() + } + return caches +} + +func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)) + return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps))) +} + +func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + q := a.QProj.Forward(x) + k := a.KProj.Forward(x) + v := a.VProj.Forward(x) + + q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim) + k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim) + v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim) + + q = a.QNorm.Forward(q, cfg.QKNormEps) + k = a.KNorm.Forward(k, cfg.QKNormEps) + + q = mlx.Transpose(q, 0, 2, 1, 3) + k = mlx.Transpose(k, 0, 2, 1, 3) + v = mlx.Transpose(v, 0, 2, 1, 3) + + offset := 0 + if c != nil { + offset = c.Offset() + } + q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) + k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) + + if c != nil { + k, v = c.Update(k, v) + } + + // MLX SDPA supports grouped-query attention directly (Q heads 
can be a + // multiple of K/V heads), so avoid materializing repeated K/V tensors. + out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1) + out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) + return a.OProj.Forward(out) +} + +func (m *MLP) Forward(x *mlx.Array) *mlx.Array { + return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x))) +}