Mirror of https://github.com/ollama/ollama.git (synced 2025-12-05 19:16:53 -06:00)
model: ministral w/ llama4 scaling (#13292)
This change:

* fixes rope scaling in the mistral converter
* updates ministral to include llama4 scaling
* includes a new ministral parser for parsing reasoning and tool calling

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
@@ -11,7 +11,6 @@ linters:
    - errorlint
    - exptostd
    - gocheckcompilerdirectives
    - gocritic
    - govet
    - ineffassign
    - intrange
@@ -1430,7 +1430,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		latest.Summary()
 	}
 
-	return &api.Message{Role: role, Content: fullResponse.String()}, nil
+	return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
 }
 
 func generate(cmd *cobra.Command, opts runOptions) error {
@@ -29,6 +29,15 @@ type mistral3Model struct {
 		SlidingWindow  *uint32 `json:"sliding_window"`
 		HiddenAct      string  `json:"hidden_act"`
 		VocabSize      uint32  `json:"vocab_size"`
+		RopeParameters struct {
+			BetaFast                  float32 `json:"beta_fast"`
+			BetaSlow                  float32 `json:"beta_slow"`
+			Factor                    float32 `json:"factor"`
+			ScalingBeta               float32 `json:"llama_4_scaling_beta"`
+			OrigMaxPositionEmbeddings uint32  `json:"original_max_position_embeddings"`
+			RopeType                  string  `json:"rope_type"`
+			RopeTheta                 float32 `json:"rope_theta"`
+		} `json:"rope_parameters"`
 	} `json:"text_config"`
 	VisionModel struct {
 		NumAttentionHeads uint32 `json:"num_attention_heads"`
@@ -61,8 +70,13 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
 	kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
 	kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
 	kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
-	kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
-	kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
+	kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
+	kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
+
+	if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
+		kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
+		kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
+	}
 
 	// Vision configuration
 	kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
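The converter now prefers the top-level HF config fields and falls back to the nested rope_parameters values via cmp.Or. A standalone illustration of that fallback behavior, separate from the diff above and with made-up numbers:

package main

import (
	"cmp"
	"fmt"
)

func main() {
	// cmp.Or returns its first non-zero argument, so a zero top-level
	// rope_theta falls through to the nested rope_parameters value.
	var topLevel float32 = 0
	var nested float32 = 1_000_000
	fmt.Println(cmp.Or(topLevel, nested)) // prints 1e+06
}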
@@ -159,8 +159,9 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
+	positionsScale := m.getScale(ctx, batch.Positions)
 
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, positionsScale, batch.Outputs, batch, m.Cache), nil
 }
 
 func init() {
@@ -16,6 +16,8 @@ type TextOptions struct {
 	hiddenSize, numHeads, numKVHeads int
 	headDim, ropeDim                 int
 	eps, ropeBase, ropeScale         float32
+	ropeOrigPosEmbeddings            int
+	ropeScalingBeta                  float32
 }
 
 type TextModel struct {
@@ -34,7 +36,7 @@ type SelfAttention struct {
 	Output *nn.Linear `gguf:"attn_output"`
 }
 
-func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
@@ -49,6 +51,10 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
+	if opts.ropeOrigPosEmbeddings > 0 {
+		q = q.Mul(ctx, positionsScale)
+	}
+
 	kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
 	kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
 	return sa.Output.Forward(ctx, kqv)
@@ -76,11 +82,11 @@ type Layer struct {
 	MLP           *MLP
 }
 
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
-	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
+	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, positionsScale, cache, opts)
 
 	// In the final layer (outputs != nil), optimize by pruning to just the token positions
 	// we need logits for.
@@ -97,7 +103,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
 
 	// image embeddings
@@ -114,25 +120,36 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 			lastLayerOutputs = outputs
 		}
 
-		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
+		hiddenState = layer.Forward(ctx, hiddenState, positions, positionsScale, lastLayerOutputs, cache, m.TextOptions)
 	}
 
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 	return m.Output.Forward(ctx, hiddenState)
 }
 
+func (m *TextModel) getScale(ctx ml.Context, positions []int32) ml.Tensor {
+	posScale := make([]float32, len(positions))
+	for n, pos := range positions {
+		interval := math.Floor(float64(pos) / float64(m.ropeOrigPosEmbeddings))
+		posScale[n] = float32(1.0 + float64(m.ropeScalingBeta)*math.Log(1.0+interval))
+	}
+	return ctx.Input().FromFloats(posScale, 1, 1, len(posScale))
+}
+
 func newTextModel(c fs.Config) *TextModel {
 	return &TextModel{
 		Layers: make([]Layer, c.Uint("block_count")),
 		TextOptions: &TextOptions{
-			hiddenSize: int(c.Uint("embedding_length")),
-			numHeads:   int(c.Uint("attention.head_count")),
-			numKVHeads: int(c.Uint("attention.head_count_kv")),
-			headDim:    int(c.Uint("attention.key_length")),
-			ropeDim:    int(c.Uint("rope.dimension_count")),
-			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:   c.Float("rope.freq_base"),
-			ropeScale:  c.Float("rope.scaling.factor", 1),
+			hiddenSize:            int(c.Uint("embedding_length")),
+			numHeads:              int(c.Uint("attention.head_count")),
+			numKVHeads:            int(c.Uint("attention.head_count_kv")),
+			headDim:               int(c.Uint("attention.key_length")),
+			ropeDim:               int(c.Uint("rope.dimension_count")),
+			eps:                   c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase:              c.Float("rope.freq_base"),
+			ropeScale:             c.Float("rope.scaling.factor", 1),
+			ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
+			ropeScalingBeta:       c.Float("rope.scaling_beta"),
 		},
 	}
}
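The new getScale helper above implements the llama4-style long-context scaling that the converter exports as rope.scaling.original_context_length and rope.scaling_beta: each query is multiplied by 1 + beta*ln(1 + floor(pos/original_context_length)), so positions inside the original window keep a scale of exactly 1. A minimal standalone sketch of the same arithmetic (the context length and beta below are placeholders, not the model's real configuration):

package main

import (
	"fmt"
	"math"
)

// llama4Scale mirrors the computation in getScale above:
// scale(pos) = 1 + beta * ln(1 + floor(pos / origCtx)).
func llama4Scale(pos int32, origCtx int, beta float32) float32 {
	interval := math.Floor(float64(pos) / float64(origCtx))
	return float32(1.0 + float64(beta)*math.Log(1.0+interval))
}

func main() {
	// Placeholder values for illustration only.
	origCtx, beta := 4096, float32(0.1)
	for _, pos := range []int32{0, 4095, 4096, 16384} {
		fmt.Printf("pos=%5d scale=%.3f\n", pos, llama4Scale(pos, origCtx, beta))
	}
}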
model/parsers/ministral.go (new file, 136 lines)
@@ -0,0 +1,136 @@
package parsers

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/ollama/ollama/api"
)

type ministralParserState int

const (
	ministralCollectingContent = iota
	ministralCollectingThinkingContent
	ministralCollectingToolName
	ministralCollectingToolArgs
)

type MinistralParser struct {
	state              ministralParserState
	buffer             strings.Builder
	tools              []api.Tool
	hasThinkingSupport bool
	currentTool        *api.Tool
}

func (p *MinistralParser) HasToolSupport() bool {
	return true
}

func (p *MinistralParser) HasThinkingSupport() bool {
	return p.hasThinkingSupport
}

func (p *MinistralParser) setInitialState(lastMessage *api.Message) {
	prefill := lastMessage != nil && lastMessage.Role == "assistant"
	if !p.HasThinkingSupport() {
		p.state = ministralCollectingContent
		return
	}

	if prefill && lastMessage.Content != "" {
		p.state = ministralCollectingContent
		return
	}

	p.state = ministralCollectingThinkingContent
}

func (p *MinistralParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	p.setInitialState(lastMessage)
	return tools
}

func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
	for i := range tools {
		if tools[i].Function.Name == n {
			return &tools[i], nil
		}
	}
	return nil, fmt.Errorf("tool '%s' not found", n)
}

func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)

	switch p.state {
	case ministralCollectingContent:
		if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
			before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
			if before != "" {
				return before, "", calls, nil
			}
			p.state = ministralCollectingToolName
		} else if strings.Contains(p.buffer.String(), "[THINK]") {
			p.state = ministralCollectingThinkingContent
			return "", "", calls, nil
		} else {
			p.buffer.Reset()
			return s, "", calls, nil
		}
	case ministralCollectingThinkingContent:
		if strings.Contains(p.buffer.String(), "[/THINK]") {
			thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
			p.state = ministralCollectingContent
			if after != "" {
				p.buffer.Reset()
				return after, thinkingContent, calls, nil
			}
			return "", thinkingContent, calls, nil
		} else {
			p.buffer.Reset()
			return "", s, calls, nil
		}
	case ministralCollectingToolName:
		if strings.Contains(p.buffer.String(), "[ARGS]") {
			name, _ := splitAtTag(&p.buffer, "[ARGS]", false)

			t, err := toolByName(p.tools, name)
			if err != nil {
				return "", "", calls, err
			}
			p.currentTool = t
			p.state = ministralCollectingToolArgs
			return "", "", calls, nil
		}
		return "", "", calls, nil
	case ministralCollectingToolArgs:
		if strings.Contains(p.buffer.String(), "}") {
			before, _ := splitAtTag(&p.buffer, "}", false)
			before += "}"

			var data map[string]any
			if err := json.Unmarshal([]byte(before), &data); err != nil {
				// todo - throw a better error
				return "", "", calls, err
			}

			p.state = ministralCollectingContent

			call := api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      p.currentTool.Function.Name,
					Arguments: api.ToolCallFunctionArguments(data),
				},
			}
			calls = append(calls, call)
			return "", "", calls, nil
		}
		return "", "", calls, nil
	}

	return p.buffer.String(), thinking, calls, nil
}
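For context, a hedged sketch of how the new parser's state machine behaves when driven chunk by chunk, written as if it were an in-package test (the tool literal, chunk boundaries, and test name are illustrative assumptions, not part of the commit). ParserForName("ministral") wires this parser up with thinking disabled; the sketch enables thinking directly to exercise both paths:

package parsers

import (
	"testing"

	"github.com/ollama/ollama/api"
)

func TestMinistralParserSketch(t *testing.T) {
	p := &MinistralParser{hasThinkingSupport: true}
	tools := []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}}
	p.Init(tools, nil, nil)

	// With thinking enabled and no assistant prefill, the parser starts in the
	// thinking state and emits thinking text until it sees "[/THINK]".
	_, thinking, _, _ := p.Add("plan the lookup[/THINK]", false)
	if thinking == "" {
		t.Fatal("expected thinking content")
	}

	// Tool calls stream in as [TOOL_CALLS]<name>[ARGS]{json}; the call is only
	// surfaced once the closing brace of the JSON arguments has arrived.
	p.Add("[TOOL_CALLS]", false)
	p.Add("get_weather", false)
	p.Add("[ARGS]", false)
	_, _, calls, err := p.Add(`{"location": "Paris"}`, false)
	if err != nil || len(calls) != 1 || calls[0].Function.Name != "get_weather" {
		t.Fatalf("expected one get_weather call, got %v (err=%v)", calls, err)
	}
}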
@@ -1,6 +1,9 @@
 package parsers
 
 import (
+	"strings"
+	"unicode"
+
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/harmony"
 )
@@ -38,16 +41,17 @@ func ParserForName(name string) Parser {
 	if parser, ok := registry.constructors[name]; ok {
 		return parser()
 	}
+	var p Parser
 
 	switch name {
 	case "qwen3-coder":
-		parser := &Qwen3CoderParser{}
-		return parser
+		p = &Qwen3CoderParser{}
 	case "qwen3-vl-instruct":
-		parser := &Qwen3VLParser{hasThinkingSupport: false}
-		return parser
+		p = &Qwen3VLParser{hasThinkingSupport: false}
 	case "qwen3-vl-thinking":
-		parser := &Qwen3VLParser{hasThinkingSupport: true}
-		return parser
+		p = &Qwen3VLParser{hasThinkingSupport: true}
+	case "ministral":
+		p = &MinistralParser{hasThinkingSupport: false}
 	case "passthrough":
 		return &PassthroughParser{}
 	case "harmony":
@@ -57,6 +61,7 @@ func ParserForName(name string) Parser {
 	default:
 		return nil
 	}
+	return p
 }
 
 type PassthroughParser struct{}
@@ -76,3 +81,20 @@ func (p *PassthroughParser) HasToolSupport() bool {
 func (p *PassthroughParser) HasThinkingSupport() bool {
 	return false
 }
+
+func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string) {
+	split := strings.SplitN(sb.String(), tag, 2)
+	if len(split) == 1 {
+		sb.Reset()
+		return split[0], ""
+	}
+	before := split[0]
+	before = strings.TrimRightFunc(before, unicode.IsSpace)
+	after := split[1]
+	if trimAfter {
+		after = strings.TrimLeftFunc(after, unicode.IsSpace)
+	}
+	sb.Reset()
+	sb.WriteString(after)
+	return before, after // return events
+}
@@ -1,6 +1,7 @@
 package parsers
 
 import (
+	"strings"
 	"testing"
 
 	"github.com/ollama/ollama/api"
@@ -95,3 +96,164 @@ func TestUnknownParserReturnsNil(t *testing.T) {
		t.Error("expected nil for unknown parser")
	}
}

func TestSplitAtTag(t *testing.T) {
	tests := []struct {
		name       string
		input      string
		tag        string
		trimAfter  bool
		wantBefore string
		wantAfter  string
		wantSB     string // expected content of strings.Builder after operation
	}{
		{
			name:       "basic split with trimAfter true",
			input:      "hello <!-- split --> world",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "hello",
			wantAfter:  "world",
			wantSB:     "world",
		},
		{
			name:       "basic split with trimAfter false",
			input:      "hello <!-- split --> world",
			tag:        "<!-- split -->",
			trimAfter:  false,
			wantBefore: "hello",
			wantAfter:  " world",
			wantSB:     " world",
		},
		{
			name:       "tag at beginning with trimAfter true",
			input:      "<!-- split -->world",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "",
			wantAfter:  "world",
			wantSB:     "world",
		},
		{
			name:       "tag at beginning with trimAfter false",
			input:      "<!-- split --> world",
			tag:        "<!-- split -->",
			trimAfter:  false,
			wantBefore: "",
			wantAfter:  " world",
			wantSB:     " world",
		},
		{
			name:       "tag at end with trimAfter true",
			input:      "hello <!-- split -->",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "hello",
			wantAfter:  "",
			wantSB:     "",
		},
		{
			name:       "tag at end with trimAfter false",
			input:      "hello <!-- split -->",
			tag:        "<!-- split -->",
			trimAfter:  false,
			wantBefore: "hello",
			wantAfter:  "",
			wantSB:     "",
		},
		{
			name:       "multiple tags splits at first occurrence",
			input:      "hello <!-- split --> world <!-- split --> end",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "hello",
			wantAfter:  "world <!-- split --> end",
			wantSB:     "world <!-- split --> end",
		},
		{
			name:       "tag not present",
			input:      "hello world",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "hello world",
			wantAfter:  "",
			wantSB:     "",
		},
		{
			name:       "empty input",
			input:      "",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "",
			wantAfter:  "",
			wantSB:     "",
		},
		{
			name:       "only whitespace before tag",
			input:      " \t\n<!-- split -->world",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "",
			wantAfter:  "world",
			wantSB:     "world",
		},
		{
			name:       "only whitespace after tag with trimAfter true",
			input:      "hello<!-- split --> \t\n",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "hello",
			wantAfter:  "",
			wantSB:     "",
		},
		{
			name:       "only whitespace after tag with trimAfter false",
			input:      "hello<!-- split --> \t\n",
			tag:        "<!-- split -->",
			trimAfter:  false,
			wantBefore: "hello",
			wantAfter:  " \t\n",
			wantSB:     " \t\n",
		},
		{
			name:       "complex whitespace trimming",
			input:      " hello \t\n <!-- split --> \n\t world ",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: " hello",
			wantAfter:  "world ",
			wantSB:     "world ",
		},
		{
			name:       "tag with special characters",
			input:      "text <tag attr=\"value\"> more text",
			tag:        "<tag attr=\"value\">",
			trimAfter:  true,
			wantBefore: "text",
			wantAfter:  "more text",
			wantSB:     "more text",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sb := &strings.Builder{}
			sb.WriteString(tt.input)

			before, after := splitAtTag(sb, tt.tag, tt.trimAfter)

			// Check return values
			if before != tt.wantBefore {
				t.Errorf("splitAtTag() before = %q, want %q", before, tt.wantBefore)
			}
			if after != tt.wantAfter {
				t.Errorf("splitAtTag() after = %q, want %q", after, tt.wantAfter)
			}

			// Check strings.Builder state
			if sb.String() != tt.wantSB {
				t.Errorf("strings.Builder after split = %q, want %q", sb.String(), tt.wantSB)
			}
		})
	}
}
@@ -70,7 +70,6 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
 	p.buffer.WriteString(s)
 	events := p.parseEvents()
 
-	var toolCalls []api.ToolCall
 	var contentSb strings.Builder
 	var thinkingSb strings.Builder
 	for _, event := range events {
@@ -81,7 +80,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
 				slog.Warn("qwen tool call parsing failed", "error", err)
 				return "", "", nil, err
 			}
-			toolCalls = append(toolCalls, toolCall)
+			calls = append(calls, toolCall)
 		case qwenEventThinkingContent:
 			thinkingSb.WriteString(event.content)
 		case qwenEventContent:
@@ -91,7 +90,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
 		}
 	}
 
-	return contentSb.String(), thinkingSb.String(), toolCalls, nil
+	return contentSb.String(), thinkingSb.String(), calls, nil
 }
 
 func (p *Qwen3VLParser) parseEvents() []qwenEvent {
@@ -113,19 +112,6 @@ func (p *Qwen3VLParser) parseEvents() []qwenEvent {
 	return all
 }
 
-func splitAtTag(p *Qwen3VLParser, tag string, trimAfter bool) (string, string) {
-	split := strings.SplitN(p.buffer.String(), tag, 2)
-	before := split[0]
-	before = strings.TrimRightFunc(before, unicode.IsSpace)
-	after := split[1]
-	if trimAfter {
-		after = strings.TrimLeftFunc(after, unicode.IsSpace)
-	}
-	p.buffer.Reset()
-	p.buffer.WriteString(after)
-	return before, after // return events
-}
-
 func (p *Qwen3VLParser) eatLeadingWhitespaceAndTransitionTo(nextState qwenParserState) ([]qwenEvent, bool) {
 	trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
 	p.buffer.Reset()
@@ -144,7 +130,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
 	case CollectingContent:
 		if strings.Contains(p.buffer.String(), toolOpenTag) {
 			// events = emitContentBeforeTag(p, events, toolOpenTag)
-			before, _ := splitAtTag(p, toolOpenTag, false)
+			before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
 			if len(before) > 0 {
 				events = append(events, qwenEventContent{content: before})
 			}
@@ -195,7 +181,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
 		}
 	case CollectingThinkingContent:
 		if strings.Contains(p.buffer.String(), thinkingCloseTag) {
-			thinking, remaining := splitAtTag(p, thinkingCloseTag, true)
+			thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
 			if len(thinking) > 0 {
 				events = append(events, qwenEventThinkingContent{content: thinking})
 			}