diff --git a/convert/convert.go b/convert/convert.go index bd3c84344..b2e6f5e37 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -311,6 +311,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) { conv = &deepseekocr{} case "DeepseekV3ForCausalLM": conv = &deepseek2Model{} + case "Glm4MoeLiteForCausalLM": + conv = &glm4MoeLiteModel{} default: return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0]) } diff --git a/convert/convert_glm4moelite.go b/convert/convert_glm4moelite.go new file mode 100644 index 000000000..a74a2fee6 --- /dev/null +++ b/convert/convert_glm4moelite.go @@ -0,0 +1,150 @@ +package convert + +import ( + "cmp" + "fmt" + "log/slog" + "regexp" + "strconv" + + "github.com/ollama/ollama/fs/ggml" +) + +type glm4MoeLiteModel struct { + ModelParameters + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + HiddenSize uint32 `json:"hidden_size"` + HiddenLayers uint32 `json:"num_hidden_layers"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + RMSNormEPS float32 `json:"rms_norm_eps"` + + RopeTheta float32 `json:"rope_theta"` + QKNopeHeadDim uint32 `json:"qk_nope_head_dim"` + QKRopeHeadDim uint32 `json:"qk_rope_head_dim"` + KVLoraRank uint32 `json:"kv_lora_rank"` + QLoraRank uint32 `json:"q_lora_rank"` + VHeadDim uint32 `json:"v_head_dim"` + + ExpertCount uint32 `json:"n_routed_experts"` + ExpertSharedCount uint32 `json:"n_shared_experts"` + ExpertIntermediateSize uint32 `json:"moe_intermediate_size"` + ExpertUsedCount uint32 `json:"num_experts_per_tok"` + ExpertWeightsNorm bool `json:"norm_topk_prob"` + ExpertWeightsScale float32 `json:"routed_scaling_factor"` + + LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"` +} + +func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV { + kv := p.ModelParameters.KV(t) + kv["general.architecture"] = "glm4moelite" + kv["general.type"] = "model" + kv["glm4moelite.block_count"] = p.HiddenLayers + + numHeads := p.NumAttentionHeads + numKVHeads := p.NumKeyValueHeads + + kv["glm4moelite.attention.head_count"] = numHeads + kv["glm4moelite.attention.head_count_kv"] = numKVHeads + kv["glm4moelite.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim + kv["glm4moelite.attention.kv_lora_rank"] = p.KVLoraRank + kv["glm4moelite.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS + kv["glm4moelite.attention.q_lora_rank"] = p.QLoraRank + kv["glm4moelite.attention.value_length"] = p.VHeadDim + kv["glm4moelite.context_length"] = p.MaxPositionEmbeddings + kv["glm4moelite.embedding_length"] = p.HiddenSize + kv["glm4moelite.expert_count"] = p.ExpertCount + kv["glm4moelite.expert_feed_forward_length"] = p.ExpertIntermediateSize + kv["glm4moelite.expert_shared_count"] = p.ExpertSharedCount + + kv["glm4moelite.expert_gating_func"] = uint32(2) + kv["glm4moelite.expert_used_count"] = p.ExpertUsedCount + kv["glm4moelite.expert_weights_norm"] = p.ExpertWeightsNorm + kv["glm4moelite.expert_weights_scale"] = p.ExpertWeightsScale + kv["glm4moelite.feed_forward_length"] = p.IntermediateSize + kv["glm4moelite.leading_dense_block_count"] = p.LeadingDenseBlockCount + + kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim + kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0)) + + kv["tokenizer.ggml.pre"] = "glm4" + + return kv +} + +func (p *glm4MoeLiteModel) Replacements() []string { + return []string{ + "lm_head", "output", + "model.embed_tokens", "token_embd", + "model.norm", "output_norm", + "model.layers", "blk", + "input_layernorm", "attn_norm", + "self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa", + "self_attn.kv_a_layernorm", "attn_kv_a_norm", + "self_attn.kv_b_proj", "attn_kv_b", + "self_attn.q_a_proj", "attn_q_a", + "self_attn.q_a_layernorm", "attn_q_a_norm", + "self_attn.q_b_proj", "attn_q_b", + "self_attn.o_proj", "attn_output", + "post_attention_layernorm", "ffn_norm", + "mlp.shared_experts.down_proj", "ffn_down_shexp", + "mlp.shared_experts.gate_proj", "ffn_gate_shexp", + "mlp.shared_experts.up_proj", "ffn_up_shexp", + "mlp.gate_proj", "ffn_gate", + "mlp.down_proj", "ffn_down", + "mlp.up_proj", "ffn_up", + "mlp.gate.e_score_correction_bias", "exp_probs_b.bias", + "mlp.gate", "ffn_gate_inp", + } +} + +func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) { + merges := make([]merge, p.HiddenLayers*3) + for i := range p.HiddenLayers { + merges[i*3+0] = merge{ + fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i), + fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i), + } + merges[i*3+1] = merge{ + fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i), + fmt.Sprintf("blk.%d.ffn_up_exps.weight", i), + } + merges[i*3+2] = merge{ + fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i), + fmt.Sprintf("blk.%d.ffn_down_exps.weight", i), + } + } + + skipLayer := func(n string, minValue uint32) bool { + re := regexp.MustCompile(`^blk\.(\d+)`) + matches := re.FindStringSubmatch(n) + if matches == nil { + return false + } + + blkNum, err := strconv.Atoi(matches[1]) + if err != nil { + return false + } + + return uint32(blkNum) >= minValue + } + + out, s = mergeTensors(s, merges...) + for _, t := range s { + // skip any additional layers (such as the Multi-Token Prediction layer) + if skipLayer(t.Name(), p.HiddenLayers) { + slog.Debug("skipping layer", "name", t.Name()) + continue + } + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } + return out +} diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 4d0dcb07c..6db305f77 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -269,6 +269,7 @@ func (kv KV) OllamaEngineRequired() bool { "qwen25vl", "qwen3", "qwen3moe", "qwen3vl", "qwen3vlmoe", + "glm4moelite", }, kv.Architecture()) } @@ -856,6 +857,7 @@ func (f GGML) FlashAttention() bool { return slices.Contains([]string{ "bert", "gemma3", + "glm4moelite", "gptoss", "gpt-oss", "mistral3", "olmo3", diff --git a/model/models/glm4moelite/model.go b/model/models/glm4moelite/model.go new file mode 100644 index 000000000..2e51f7d56 --- /dev/null +++ b/model/models/glm4moelite/model.go @@ -0,0 +1,304 @@ +package glm4moelite + +import ( + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Options struct { + numExpertsUsed int + numExperts int + normTopKProb bool + routedScalingFactor float32 + + kvLoraRank, + qkNopeHeadDim, + qkRopeHeadDim, + kqNopeHeadDim, + qkHeadDim int + qLoraRank int + vHeadDim int + + hiddenSize, + numHeads, + numKVHeads int + + eps, + ropeBase float32 + kqScale float64 +} + +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1.0) +} + +type Attention struct { + Q *nn.Linear `gguf:"attn_q"` + + QA *nn.Linear `gguf:"attn_q_a"` + QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"` + QB *nn.Linear `gguf:"attn_q_b"` + + KVA *nn.Linear `gguf:"attn_kv_a_mqa"` + KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"` + KVB *nn.Linear `gguf:"attn_kv_b"` + + Output *nn.Linear `gguf:"attn_out,alt:attn_output"` +} + +func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + seqLength := hiddenStates.Dim(1) + + var query ml.Tensor + if opts.qLoraRank == 0 { + query = attn.Q.Forward(ctx, hiddenStates) + } else { + query = attn.QA.Forward(ctx, hiddenStates) + query = attn.QANorm.Forward(ctx, query, opts.eps) + query = attn.QB.Forward(ctx, query) + } + + query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength) + queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim) + + compressedKV := attn.KVA.Forward(ctx, hiddenStates) + kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1) + kRot := compressedKV.View(ctx, + opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim, + compressedKV.Stride(1), 1, + compressedKV.Stride(1), compressedKV.Dim(1), + ) + + qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions) + kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions) + kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps) + kPass = attn.KVB.Forward(ctx, kPass) + + kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength) + kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim) + + kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1)) + query = qRot.Concat(ctx, queryChunks[0], 0) + key := kRot.Concat(ctx, kvChunks[0], 0) + attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache) + + attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength) + return attn.Output.Forward(ctx, attention) +} + +type MLP interface { + Forward(ml.Context, ml.Tensor, *Options) ml.Tensor +} + +type sparse struct { + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate *nn.Linear `gguf:"ffn_gate_exps"` + Up *nn.Linear `gguf:"ffn_up_exps"` + Down *nn.Linear `gguf:"ffn_down_exps"` + SharedExpert *dense `gguf:",suf:_shexp"` + ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"` +} + +func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor { + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) + + upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices) + hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices) + hiddenStates = hiddenStates.SILU(ctx, upStates) + + experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices) + experts = experts.Mul(ctx, topKWeights) + + nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))) + } + return nextStates +} + +func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor { + if moe.ExpProbsBias != nil { + scores = scores.Add(ctx, moe.ExpProbsBias) + } + topKIndices := scores.TopK(ctx, opts.numExpertsUsed) + return topKIndices +} + +func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + residuals := hiddenStates + + routerLogits := moe.Router.Forward(ctx, hiddenStates) + scores := routerLogits.Sigmoid(ctx) + topKIndices := moe.topKIndices(ctx, scores, opts) + topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices) + + if opts.normTopKProb { + topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1)) + topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx)) + topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1)) + } + + topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor)) + hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts) + sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts) + + hiddenStates = hiddenStates.Add(ctx, sharedExpertResult) + return hiddenStates +} + +type dense struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + Attention *Attention + + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP MLP +} + +func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + residual := hiddenStates + hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts) + + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenStates = hiddenStates.Add(ctx, residual) + residual = hiddenStates + + hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + return hiddenStates +} + +type Model struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + *Options +} + +func New(c fs.Config) (model.Model, error) { + layers := make([]Layer, c.Uint("block_count")) + + firstDenseLayerIndex := int(c.Uint("leading_dense_block_count")) + for i := range layers { + if i < firstDenseLayerIndex { + layers[i].MLP = &dense{} + } else { + layers[i].MLP = &sparse{} + } + } + + keyLength := int(c.Uint("attention.key_length")) + valueLength := int(c.Uint("attention.value_length")) + + kqScale := 1.0 / math.Sqrt(float64(keyLength)) + + var pre []string + switch c.String("tokenizer.ggml.pre") { + case "glm4": + pre = []string{ + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, + } + default: + return nil, model.ErrUnsupportedTokenizer + } + + m := Model{ + BytePairEncoding: model.NewBytePairEncoding( + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + pre..., + ), + Layers: layers, + Options: &Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + normTopKProb: c.Bool("expert_weights_norm", true), + + qLoraRank: int(c.Uint("attention.q_lora_rank")), + kvLoraRank: int(c.Uint("attention.kv_lora_rank")), + qkHeadDim: keyLength, + vHeadDim: valueLength, + qkRopeHeadDim: int(c.Uint("rope.dimension_count")), + qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")), + kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")), + + routedScalingFactor: c.Float("expert_weights_scale"), + + kqScale: kqScale, + }, + } + + m.Cache = kvcache.NewCausalCache(m.Shift) + return &m, nil +} + +func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) + + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs = batch.Outputs + } + + hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + return m.Output.Forward(ctx, hiddenStates), nil +} + +func init() { + model.Register("glm4moelite", New) +} diff --git a/model/models/models.go b/model/models/models.go index b471e8166..d900f7cc3 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -7,6 +7,7 @@ import ( _ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3n" + _ "github.com/ollama/ollama/model/models/glm4moelite" _ "github.com/ollama/ollama/model/models/gptoss" _ "github.com/ollama/ollama/model/models/llama" _ "github.com/ollama/ollama/model/models/llama4" diff --git a/model/parsers/glm46.go b/model/parsers/glm46.go new file mode 100644 index 000000000..05826d9ed --- /dev/null +++ b/model/parsers/glm46.go @@ -0,0 +1,410 @@ +package parsers + +import ( + "context" + "encoding/xml" + "fmt" + "log/slog" + "strings" + "unicode" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/logutil" +) + +type glm46ParserState int + +const ( + glm46ParserState_LookingForThinkingOpen glm46ParserState = iota + glm46ParserState_ThinkingStartedEatingWhitespace + glm46ParserState_CollectingThinking + glm46ParserState_ThinkingDoneEatingWhitespace + glm46ParserState_CollectingContent + glm46ParserState_ToolStartedEatingWhitespace + glm46ParserState_CollectingToolContent +) + +const ( + glm46ThinkingOpenTag = "" + glm46ThinkingCloseTag = "" + glm46ToolOpenTag = "" + glm46ToolCloseTag = "" +) + +type GLM46Parser struct { + state glm46ParserState + buffer strings.Builder + tools []api.Tool +} + +func (p *GLM46Parser) HasToolSupport() bool { + return true +} + +func (p *GLM46Parser) HasThinkingSupport() bool { + return true +} + +// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { +func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { + p.tools = tools + return tools +} + +type glm46Event interface { + isGLM46Event() +} + +type glm46EventContent struct { + content string +} + +func (glm46EventContent) isGLM46Event() {} + +type glm46EventRawToolCall struct { + raw string +} + +func (glm46EventRawToolCall) isGLM46Event() {} + +type glm46EventThinkingContent struct { + content string +} + +func (glm46EventThinkingContent) isGLM46Event() {} + +func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.buffer.WriteString(s) + events := p.parseEvents() + + var toolCalls []api.ToolCall + var contentSb strings.Builder + var thinkingSb strings.Builder + + for _, event := range events { + switch event := event.(type) { + case glm46EventRawToolCall: + toolCall, err := parseGLM46ToolCall(event, p.tools) + if err != nil { + slog.Warn("glm-4.6 tool call parsing failed", "error", err) + return "", "", nil, err + } + toolCalls = append(toolCalls, toolCall) + case glm46EventThinkingContent: + thinkingSb.WriteString(event.content) + case glm46EventContent: + // TODO(drifkin): if the same turn contains multiple interleaved content + // events, we naively append them together here. + contentSb.WriteString(event.content) + } + } + + return contentSb.String(), thinkingSb.String(), toolCalls, nil +} + +func (p *GLM46Parser) parseEvents() []glm46Event { + var all []glm46Event + + keepLooping := true + for keepLooping { + var events []glm46Event + events, keepLooping = p.eat() + if len(events) > 0 { + all = append(all, events...) + } + } + + if len(all) > 0 { + slog.Log(context.TODO(), logutil.LevelTrace, "glm-4.6 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String()) + } + + return all +} + +// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer +// and transitions to the next state. Returns (nil, false) if only whitespace remains +// in the buffer (needs more input), or (nil, true) if we successfully transitioned. +func (p *GLM46Parser) eatLeadingWhitespaceAndTransitionTo(nextState glm46ParserState) ([]glm46Event, bool) { + trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace) + p.buffer.Reset() + if trimmed == "" { + return nil, false // Still only whitespace, keep waiting for more input + } + p.state = nextState + p.buffer.WriteString(trimmed) + return nil, true // Successfully transitioned +} + +// glm46SplitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace), +// the content after (optionally trimmed of leading whitespace), and updates the buffer +func glm46SplitAtTag(p *GLM46Parser, tag string, trimAfter bool) (string, string) { + split := strings.SplitN(p.buffer.String(), tag, 2) + before := split[0] + before = strings.TrimRightFunc(before, unicode.IsSpace) + after := split[1] + if trimAfter { + after = strings.TrimLeftFunc(after, unicode.IsSpace) + } + p.buffer.Reset() + p.buffer.WriteString(after) + return before, after +} + +func (p *GLM46Parser) eat() ([]glm46Event, bool) { + var events []glm46Event + + switch p.state { + case glm46ParserState_LookingForThinkingOpen: + trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace) + if strings.HasPrefix(trimmed, glm46ThinkingOpenTag) { + // Found opening tag + after := strings.TrimPrefix(trimmed, glm46ThinkingOpenTag) + after = strings.TrimLeftFunc(after, unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(after) + if after == "" { + p.state = glm46ParserState_ThinkingStartedEatingWhitespace + } else { + p.state = glm46ParserState_CollectingThinking + } + return events, true + } else if strings.HasPrefix(glm46ThinkingOpenTag, trimmed) { + // Partial opening tag seen, keep accumulating + return events, false + } else if trimmed == "" { + // Only whitespace, keep accumulating + return events, false + } else { + // No thinking tag found, skip to content collection + p.state = glm46ParserState_CollectingContent + // Don't trim - we want to keep the original content + return events, true + } + + case glm46ParserState_ThinkingStartedEatingWhitespace: + return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingThinking) + + case glm46ParserState_CollectingThinking: + acc := p.buffer.String() + if strings.Contains(acc, glm46ThinkingCloseTag) { + thinking, remaining := glm46SplitAtTag(p, glm46ThinkingCloseTag, true) + if len(thinking) > 0 { + events = append(events, glm46EventThinkingContent{content: thinking}) + } + if remaining == "" { + p.state = glm46ParserState_ThinkingDoneEatingWhitespace + } else { + p.state = glm46ParserState_CollectingContent + } + return events, true + } else if overlapLen := overlap(acc, glm46ThinkingCloseTag); overlapLen > 0 { + // Partial closing tag - withhold it along with any trailing whitespace before it + beforePartialTag := acc[:len(acc)-overlapLen] + trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen + + unambiguous := acc[:ambiguousStart] + ambiguous := acc[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, glm46EventThinkingContent{content: unambiguous}) + } + return events, false + } else { + // Pure thinking content - withhold trailing whitespace (might precede closing tag) + whitespaceLen := trailingWhitespaceLen(acc) + ambiguousStart := len(acc) - whitespaceLen + + unambiguous := acc[:ambiguousStart] + ambiguous := acc[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, glm46EventThinkingContent{content: unambiguous}) + } + return events, false + } + + case glm46ParserState_ThinkingDoneEatingWhitespace: + return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingContent) + + case glm46ParserState_CollectingContent: + if strings.Contains(p.buffer.String(), glm46ToolOpenTag) { + before, after := glm46SplitAtTag(p, glm46ToolOpenTag, true) + if len(before) > 0 { + events = append(events, glm46EventContent{content: before}) + } + if after == "" { + p.state = glm46ParserState_ToolStartedEatingWhitespace + } else { + p.state = glm46ParserState_CollectingToolContent + } + return events, true + } else if overlapLen := overlap(p.buffer.String(), glm46ToolOpenTag); overlapLen > 0 { + beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen] + trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen + + unambiguous := p.buffer.String()[:ambiguousStart] + ambiguous := p.buffer.String()[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, glm46EventContent{content: unambiguous}) + } + return events, false + } else { + whitespaceLen := trailingWhitespaceLen(p.buffer.String()) + ambiguousStart := len(p.buffer.String()) - whitespaceLen + + unambiguous := p.buffer.String()[:ambiguousStart] + ambiguous := p.buffer.String()[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, glm46EventContent{content: unambiguous}) + } + return events, false + } + + case glm46ParserState_ToolStartedEatingWhitespace: + return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingToolContent) + + case glm46ParserState_CollectingToolContent: + acc := p.buffer.String() + if strings.Contains(acc, glm46ToolCloseTag) { + toolContent, _ := glm46SplitAtTag(p, glm46ToolCloseTag, true) + if len(toolContent) == 0 { + slog.Warn("glm46 tool call closing tag found but no content before it") + } + events = append(events, glm46EventRawToolCall{raw: toolContent}) + p.state = glm46ParserState_CollectingContent + return events, true + } else { + // Keep accumulating - tool calls are not streamed + // We just wait for the closing tag + return events, false + } + + default: + panic("unreachable") + } +} + +// GLMToolCallXML represents the structure of a GLM-4.6 tool call for XML parsing +type GLMToolCallXML struct { + XMLName xml.Name `xml:"tool_call"` + Content string `xml:",chardata"` // Function name (text nodes between tags) + Keys []string `xml:"arg_key"` // All arg_key elements in document order + Values []string `xml:"arg_value"` // All arg_value elements in document order +} + +// escapeGLM46Content escapes XML entities in text content while preserving arg_key/arg_value tags +func escapeGLM46Content(s string) string { + var result strings.Builder + inTag := false + + for i := range len(s) { + ch := s[i] + + if ch == '<' { + // Check if this is a known tag + if strings.HasPrefix(s[i:], "") || + strings.HasPrefix(s[i:], "") || + strings.HasPrefix(s[i:], "") || + strings.HasPrefix(s[i:], "") { + inTag = true + } + } + + if inTag { + result.WriteByte(ch) + if ch == '>' { + inTag = false + } + } else { + // Escape special characters in text content + switch ch { + case '&': + result.WriteString("&") + case '<': + result.WriteString("<") + case '>': + result.WriteString(">") + default: + result.WriteByte(ch) + } + } + } + + return result.String() +} + +func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) { + // Escape any unescaped entities in text content + // We need to escape text between tags, but not the tags themselves + escaped := escapeGLM46Content(raw.raw) + + // Wrap the content in a root element to make it valid XML + xmlString := "" + escaped + "" + + // Parse XML into struct + var parsed GLMToolCallXML + if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil { + return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err) + } + + // Extract and trim function name + functionName := strings.TrimSpace(parsed.Content) + if functionName == "" { + return api.ToolCall{}, fmt.Errorf("empty function name") + } + + // Verify keys and values are paired correctly + if len(parsed.Keys) != len(parsed.Values) { + return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values)) + } + + // Find the matching tool to get parameter types + var matchedTool *api.Tool + for i := range tools { + if tools[i].Function.Name == functionName { + matchedTool = &tools[i] + break + } + } + + // Build arguments map by pairing keys and values + toolCall := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: functionName, + Arguments: api.NewToolCallFunctionArguments(), + }, + } + + for i := range parsed.Keys { + key := strings.TrimSpace(parsed.Keys[i]) + value := parsed.Values[i] // Don't trim here - parseValue handles it + + // Look up parameter type + var paramType api.PropertyType + if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil { + if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok { + // Handle anyOf by collecting all types from the union + if len(prop.AnyOf) > 0 { + for _, anyOfProp := range prop.AnyOf { + paramType = append(paramType, anyOfProp.Type...) + } + } else { + paramType = prop.Type + } + } + } + + // Parse value with type coercion + toolCall.Function.Arguments.Set(key, parseValue(value, paramType)) + } + + return toolCall, nil +} diff --git a/model/parsers/glm46_test.go b/model/parsers/glm46_test.go new file mode 100644 index 000000000..341b93fbe --- /dev/null +++ b/model/parsers/glm46_test.go @@ -0,0 +1,862 @@ +package parsers + +import ( + "encoding/xml" + "reflect" + "testing" + + "github.com/ollama/ollama/api" +) + +func TestGLM46ParserStreaming(t *testing.T) { + type step struct { + input string + wantEvents []glm46Event + } + + cases := []struct { + desc string + steps []step + only bool + }{ + { + desc: "leading whitespace before think tag", + steps: []step{ + { + input: " \n\t ", + wantEvents: []glm46Event{}, + }, + { + input: "thinking", + wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking"}}, + }, + }, + }, + { + desc: "think tag with whitespace inside", + steps: []step{ + { + input: " \n thinking content \n regular content", + wantEvents: []glm46Event{ + glm46EventThinkingContent{content: "thinking content"}, + glm46EventContent{content: "regular content"}, + }, + }, + }, + }, + { + desc: "tool call with leading whitespace after opening tag", + steps: []step{ + { + input: " \n test \n ", + wantEvents: []glm46Event{ + glm46EventRawToolCall{raw: "test"}, + }, + }, + }, + }, + { + desc: "simple thinking then content", + steps: []step{ + { + input: "I am thinkingNow I respond", + wantEvents: []glm46Event{ + glm46EventThinkingContent{content: "I am thinking"}, + glm46EventContent{content: "Now I respond"}, + }, + }, + }, + }, + { + desc: "streamed thinking content", + steps: []step{ + { + input: "hello", + wantEvents: []glm46Event{glm46EventThinkingContent{content: "hello"}}, + }, + { + input: " world", + wantEvents: []glm46Event{glm46EventThinkingContent{content: " world"}}, + }, + { + input: "content", + wantEvents: []glm46Event{ + glm46EventContent{content: "content"}, + }, + }, + }, + }, + { + desc: "content before tool call", + steps: []step{ + { + input: "Let me call a toolhere is text", + wantEvents: []glm46Event{ + glm46EventThinkingContent{content: "Let me call a tool"}, + glm46EventContent{content: "here is text"}, + }, + }, + { + input: "function_name\nparam\nvalue\n", + wantEvents: []glm46Event{ + glm46EventRawToolCall{raw: "function_name\nparam\nvalue"}, + }, + }, + }, + }, + { + desc: "tool call with content after", + steps: []step{ + { + input: "thinkingtestafter tool", + wantEvents: []glm46Event{ + glm46EventThinkingContent{content: "thinking"}, + glm46EventRawToolCall{raw: "test"}, + glm46EventContent{content: "after tool"}, + }, + }, + }, + }, + { + desc: "trailing whitespace between content and tool call is trimmed", + steps: []step{ + { + input: "thinkingcontent\n \t test", + wantEvents: []glm46Event{ + glm46EventThinkingContent{content: "thinking"}, + glm46EventContent{content: "content"}, + glm46EventRawToolCall{raw: "test"}, + }, + }, + }, + }, + { + desc: "trailing whitespace between tool call and content is trimmed", + steps: []step{ + { + input: "thinktest\n\t after", + wantEvents: []glm46Event{ + glm46EventThinkingContent{content: "think"}, + glm46EventRawToolCall{raw: "test"}, + glm46EventContent{content: "after"}, + }, + }, + }, + }, + { + desc: "split thinking close tag", + steps: []step{ + { + input: "thinking contentafter", + wantEvents: []glm46Event{ + glm46EventContent{content: "after"}, + }, + }, + }, + }, + { + desc: "split thinking open tag", + steps: []step{ + { + input: " content", + wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}}, + }, + }, + }, + { + desc: "split tool open tag", + steps: []step{ + { + input: "thinkcontentinside", + wantEvents: []glm46Event{}, + }, + { + input: "", + wantEvents: []glm46Event{ + glm46EventRawToolCall{raw: "inside"}, + }, + }, + }, + }, + { + desc: "partial thinking close tag fakeout", + steps: []step{ + { + input: "contentcontent\ncontent", + wantEvents: []glm46Event{ + glm46EventRawToolCall{raw: "contentcontent here", + wantEvents: []glm46Event{ + glm46EventContent{content: "content here"}, + }, + }, + }, + }, + { + desc: "multiple tool calls in sequence", + steps: []step{ + { + input: "thinkfirstbetweensecondend", + wantEvents: []glm46Event{ + glm46EventThinkingContent{content: "think"}, + glm46EventRawToolCall{raw: "first"}, + glm46EventContent{content: "between"}, + glm46EventRawToolCall{raw: "second"}, + glm46EventContent{content: "end"}, + }, + }, + }, + }, + { + desc: "no thinking tag - direct to content", + steps: []step{ + { + input: "just content here", + wantEvents: []glm46Event{ + glm46EventContent{content: "just content here"}, + }, + }, + }, + }, + { + desc: "no thinking tag - skip to content then tool call", + steps: []step{ + { + input: "Here's the answer:testdone", + wantEvents: []glm46Event{ + glm46EventContent{content: "Here's the answer:"}, + glm46EventRawToolCall{raw: "test"}, + glm46EventContent{content: "done"}, + }, + }, + }, + }, + { + desc: "no thinking tag - whitespace preserved when no tags", + steps: []step{ + { + input: " \n content with leading whitespace", + wantEvents: []glm46Event{ + glm46EventContent{content: " \n content with leading whitespace"}, + }, + }, + }, + }, + { + desc: "whitespace after think close tag gets eaten", + steps: []step{ + { + input: "thinking \n\t content", + wantEvents: []glm46Event{ + glm46EventThinkingContent{content: "thinking"}, + glm46EventContent{content: "content"}, + }, + }, + }, + }, + { + desc: "whitespace after tool_call close tag gets eaten", + steps: []step{ + { + input: "test \n\t content", + wantEvents: []glm46Event{ + glm46EventRawToolCall{raw: "test"}, + glm46EventContent{content: "content"}, + }, + }, + }, + }, + { + desc: "thinking content withholds trailing whitespace (single chunk)", + steps: []step{ + { + input: "thinking content ", + wantEvents: []glm46Event{ + glm46EventThinkingContent{content: "thinking content"}, + }, + }, + { + input: "after", + wantEvents: []glm46Event{ + glm46EventContent{content: "after"}, + }, + }, + }, + }, + { + desc: "thinking content withholds trailing whitespace with newlines", + steps: []step{ + { + input: "thinking\n\n ", + wantEvents: []glm46Event{ + glm46EventThinkingContent{content: "thinking"}, + }, + }, + { + input: "content", + wantEvents: []glm46Event{ + glm46EventContent{content: "content"}, + }, + }, + }, + }, + { + desc: "thinking content trailing whitespace emitted when more content arrives", + steps: []step{ + { + input: "thinking ", + wantEvents: []glm46Event{ + glm46EventThinkingContent{content: "thinking"}, + }, + }, + { + input: "more thinking", + wantEvents: []glm46Event{ + glm46EventThinkingContent{content: " more thinking"}, + }, + }, + { + input: "", + wantEvents: []glm46Event{}, + }, + }, + }, + { + desc: "thinking content withholds trailing whitespace before partial close tag", + steps: []step{ + { + input: "thinking content", + wantEvents: []glm46Event{ + glm46EventContent{content: "content"}, + }, + }, + }, + }, + } + + anyOnlies := false + for _, tc := range cases { + if tc.only { + anyOnlies = true + } + } + + for _, tc := range cases { + if anyOnlies && !tc.only { + continue + } + + t.Run(tc.desc, func(t *testing.T) { + parser := GLM46Parser{} + + for i, step := range tc.steps { + parser.buffer.WriteString(step.input) + gotEvents := parser.parseEvents() + + if len(gotEvents) == 0 && len(step.wantEvents) == 0 { + // avoid deep equal on empty vs. nil slices + continue + } + + if !reflect.DeepEqual(gotEvents, step.wantEvents) { + t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents) + } + } + }) + } +} + +// TestGLMToolCallXMLOrderPreservation verifies that xml.Unmarshal preserves +// document order when collecting multiple elements with the same tag name into slices. +// This is a critical assumption for the GLM-4.6 parser's struct-based approach. +func TestGLMToolCallXMLOrderPreservation(t *testing.T) { + testCases := []struct { + name string + xml string + wantKeys []string + wantValues []string + }{ + { + name: "alternating keys and values", + xml: ` +function_name +first +A +second +B +third +C +`, + wantKeys: []string{"first", "second", "third"}, + wantValues: []string{"A", "B", "C"}, + }, + { + name: "all keys then all values", + xml: ` +function_name +key1 +key2 +key3 +val1 +val2 +val3 +`, + wantKeys: []string{"key1", "key2", "key3"}, + wantValues: []string{"val1", "val2", "val3"}, + }, + { + name: "mixed grouping", + xml: ` +function_name +a +1 +b +c +2 +3 +`, + wantKeys: []string{"a", "b", "c"}, + wantValues: []string{"1", "2", "3"}, + }, + { + name: "reverse order - all values then all keys", + xml: ` +function_name +X +Y +Z +x +y +z +`, + wantKeys: []string{"x", "y", "z"}, + wantValues: []string{"X", "Y", "Z"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var parsed GLMToolCallXML + err := xml.Unmarshal([]byte(tc.xml), &parsed) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if !reflect.DeepEqual(parsed.Keys, tc.wantKeys) { + t.Errorf("Keys order mismatch:\ngot: %v\nwant: %v", parsed.Keys, tc.wantKeys) + } + + if !reflect.DeepEqual(parsed.Values, tc.wantValues) { + t.Errorf("Values order mismatch:\ngot: %v\nwant: %v", parsed.Values, tc.wantValues) + } + }) + } +} + +func TestGLM46ToolCallParsing(t *testing.T) { + type testCase struct { + name string + rawToolCall string + tools []api.Tool + wantToolCall api.ToolCall + } + + cases := []testCase{ + { + name: "simple tool call", + tools: []api.Tool{}, + rawToolCall: `get-current-weather +location +New York, NY +unit +celsius`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get-current-weather", + Arguments: args(`{"location": "New York, NY", "unit": "celsius"}`), + }, + }, + }, + { + name: "tool call with typed parameters", + tools: []api.Tool{ + tool("calculate", map[string]api.ToolProperty{ + "x": {Type: api.PropertyType{"number"}}, + "y": {Type: api.PropertyType{"integer"}}, + "enabled": {Type: api.PropertyType{"boolean"}}, + "items": {Type: api.PropertyType{"array"}}, + }), + }, + rawToolCall: `calculate +x +3.14 +y +42 +enabled +true +items +["a", "b", "c"]`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "calculate", + Arguments: args(`{"enabled": true, "items": ["a", "b", "c"], "x": 3.14, "y": 42}`), + }, + }, + }, + { + name: "function name with whitespace", + tools: []api.Tool{}, + rawToolCall: ` get-weather +city +Paris`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get-weather", + Arguments: args(`{"city": "Paris"}`), + }, + }, + }, + { + name: "values with special characters", + tools: []api.Tool{}, + rawToolCall: `execute-command +command +ls && echo "done" +message +a < b and c > d`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "execute-command", + Arguments: args(`{"command": "ls && echo \"done\"", "message": "a < b and c > d"}`), + }, + }, + }, + { + name: "unicode in function names and values", + tools: []api.Tool{}, + rawToolCall: `获取天气 +城市 +北京 +message +Hello! 你好! 🌟`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "获取天气", + Arguments: args(`{"message": "Hello! 你好! 🌟", "城市": "北京"}`), + }, + }, + }, + { + name: "empty value", + tools: []api.Tool{}, + rawToolCall: `test-function +param1 +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "test-function", + Arguments: args(`{"param1": ""}`), + }, + }, + }, + { + name: "special chars in arg_key names", + tools: []api.Tool{}, + rawToolCall: `test-function +param<1> +value1 +a&b +value2 +x>y +value3`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "test-function", + Arguments: args(`{"a&b": "value2", "param<1>": "value1", "x>y": "value3"}`), + }, + }, + }, + { + name: "multiple consecutive ampersands", + tools: []api.Tool{}, + rawToolCall: `test-function +param +test &&&& more`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "test-function", + Arguments: args(`{"param": "test &&&& more"}`), + }, + }, + }, + { + name: "mixed special chars together", + tools: []api.Tool{}, + rawToolCall: `test-function +param +<>&<>&`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "test-function", + Arguments: args(`{"param": "<>&<>&"}`), + }, + }, + }, + { + name: "newlines and tabs in parameter values", + tools: []api.Tool{}, + rawToolCall: `test-function +multiline +line1 + indented line2 +line3`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "test-function", + Arguments: args(`{"multiline": "line1\n\tindented line2\nline3"}`), + }, + }, + }, + { + name: "single and double quotes in values", + tools: []api.Tool{}, + rawToolCall: `test-function +quotes +She said "Hello's there!"`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "test-function", + Arguments: args(`{"quotes": "She said \"Hello's there!\""}`), + }, + }, + }, + { + name: "CDATA-like content that should be treated as text", + tools: []api.Tool{}, + rawToolCall: `test-function +cdata +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "test-function", + Arguments: args(`{"cdata": ""}`), + }, + }, + }, + { + name: "all special XML entities", + tools: []api.Tool{}, + rawToolCall: `test-function +entities +<>&'"`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "test-function", + Arguments: args(`{"entities": "<>&'""}`), + }, + }, + }, + { + name: "order preservation with multiple parameters", + tools: []api.Tool{}, + rawToolCall: `test-function +first +value1 +second +value2 +third +value3 +fourth +value4 +fifth +value5`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "test-function", + Arguments: args(`{"fifth": "value5", "first": "value1", "fourth": "value4", "second": "value2", "third": "value3"}`), + }, + }, + }, + { + name: "order preservation with identical key names but different positions", + tools: []api.Tool{}, + rawToolCall: `test-function +param +first occurrence +other +middle +param +second occurrence`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "test-function", + // Later occurrence should overwrite earlier one + Arguments: args(`{"other": "middle", "param": "second occurrence"}`), + }, + }, + }, + { + name: "array with mixed types", + tools: []api.Tool{ + tool("process", map[string]api.ToolProperty{ + "items": {Type: api.PropertyType{"array"}}, + }), + }, + rawToolCall: `process +items +[1, "hello", true, null]`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "process", + Arguments: args(`{"items": [1, "hello", true, null]}`), + }, + }, + }, + { + name: "empty array", + tools: []api.Tool{ + tool("test", map[string]api.ToolProperty{ + "tags": {Type: api.PropertyType{"array"}}, + }), + }, + rawToolCall: `test +tags +[]`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "test", + Arguments: args(`{"tags": []}`), + }, + }, + }, + { + name: "anyOf array or string - with array of objects", + tools: []api.Tool{ + tool("TodoWrite", map[string]api.ToolProperty{ + "todos": {AnyOf: []api.ToolProperty{{Type: api.PropertyType{"array"}}, {Type: api.PropertyType{"string"}}}}, + }), + }, + // TodoWrite + // todos + // [{"content": "Set up HTML file and basic structure", "id": "1", "priority": "high", "status": "pending"}, {"content": "Create 3D scene with Three.js", "id": "2", "priority": "high", "status": "pending"}, {"content": "Implement terrain generation with blocks", "id": "3", "priority": "high", "status": "pending"}, {"content": "Add player controls (movement, camera)", "id": "4", "priority": "high", "status": "pending"}, {"content": "Implement block placement/destruction", "id": "5", "priority": "medium", "status": "pending"}, {"content": "Add lighting and textures", "id": "6", "priority": "medium", "status": "pending"}, {"content": "Test and optimize performance", "id": "7", "priority": "low", "status": "pending"}] + // + rawToolCall: `TodoWrite +todos +[{"content": "task 1", "status": "pending", "priority": "high", "id": "1"}, {"content": "task 2", "status": "completed", "priority": "low", "id": "2"}]`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "TodoWrite", + Arguments: args(`{"todos": [{"content": "task 1", "id": "1", "priority": "high", "status": "pending"}, {"content": "task 2", "id": "2", "priority": "low", "status": "completed"}]}`), + }, + }, + }, + { + name: "anyOf array or string - with plain string", + tools: []api.Tool{ + tool("TodoWrite", map[string]api.ToolProperty{ + "todos": {Type: api.PropertyType{"array", "string"}}, + }), + }, + rawToolCall: `TodoWrite +todos +Error: could not load todos`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "TodoWrite", + Arguments: args(`{"todos": "Error: could not load todos"}`), + }, + }, + }, + } + + for i, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + gotToolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: tc.rawToolCall}, tc.tools) + if err != nil { + t.Errorf("case %d (%s): %v", i, tc.name, err) + } + if !toolCallEqual(gotToolCall, tc.wantToolCall) { + t.Errorf("case %d (%s): got tool call %#v, want %#v", i, tc.name, gotToolCall, tc.wantToolCall) + } + }) + } +} diff --git a/model/parsers/glm47.go b/model/parsers/glm47.go new file mode 100644 index 000000000..4b49934e8 --- /dev/null +++ b/model/parsers/glm47.go @@ -0,0 +1,20 @@ +package parsers + +import "github.com/ollama/ollama/api" + +// GLM47Parser extends GLM46Parser with thinking-aware initialization. +// GLM-4.7's prompt ends with when thinking is enabled, so the parser +// must start in CollectingThinking state (the model outputs thinking content directly). +type GLM47Parser struct { + GLM46Parser +} + +func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { + p.tools = tools + // When thinking is enabled (nil or true), the prompt ends with , + // so model output starts directly with thinking content (no opening tag). + if thinkValue == nil || thinkValue.Bool() { + p.state = glm46ParserState_CollectingThinking + } + return tools +} diff --git a/model/parsers/glm47_test.go b/model/parsers/glm47_test.go new file mode 100644 index 000000000..26c5d7113 --- /dev/null +++ b/model/parsers/glm47_test.go @@ -0,0 +1,99 @@ +package parsers + +import ( + "reflect" + "testing" + + "github.com/ollama/ollama/api" +) + +func TestGLM47ParserAdd(t *testing.T) { + parser := GLM47Parser{} + parser.Init([]api.Tool{ + tool("calculate", map[string]api.ToolProperty{ + "count": {Type: api.PropertyType{"integer"}}, + "enabled": {Type: api.PropertyType{"boolean"}}, + }), + }, nil, nil) + + // When thinking is enabled (thinkValue nil), the prompt ends with , + // so the model output does NOT include the opening tag. + content, thinking, calls, err := parser.Add("planAnswercalculatecount3enabledtrue", true) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + if thinking != "plan" { + t.Fatalf("expected thinking 'plan', got %q", thinking) + } + if content != "Answer" { + t.Fatalf("expected content 'Answer', got %q", content) + } + if len(calls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(calls)) + } + expectedArgs := args(`{"count": 3, "enabled": true}`) + if !toolCallEqual(api.ToolCall{Function: api.ToolCallFunction{Arguments: calls[0].Function.Arguments}}, api.ToolCall{Function: api.ToolCallFunction{Arguments: expectedArgs}}) { + t.Fatalf("expected args %#v, got %#v", expectedArgs.ToMap(), calls[0].Function.Arguments.ToMap()) + } +} + +func TestGLM47ParserNoThinkingContent(t *testing.T) { + parser := GLM47Parser{} + parser.Init(nil, nil, nil) + + // When thinking is enabled but model has no thinking to output, + // it should output immediately followed by content. + content, thinking, calls, err := parser.Add("Plain answer", true) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + if thinking != "" { + t.Fatalf("expected empty thinking, got %q", thinking) + } + if content != "Plain answer" { + t.Fatalf("expected content 'Plain answer', got %q", content) + } + if len(calls) != 0 { + t.Fatalf("expected no tool calls, got %d", len(calls)) + } +} + +func TestGLM47ParserThinkingDisabled(t *testing.T) { + parser := GLM47Parser{} + // When thinking is disabled, parser stays in LookingForThinkingOpen state + parser.Init(nil, nil, &api.ThinkValue{Value: false}) + + // Model outputs plain content (prompt ended with ) + content, thinking, calls, err := parser.Add("Plain answer", true) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + if thinking != "" { + t.Fatalf("expected empty thinking, got %q", thinking) + } + if content != "Plain answer" { + t.Fatalf("expected content 'Plain answer', got %q", content) + } + if len(calls) != 0 { + t.Fatalf("expected no tool calls, got %d", len(calls)) + } +} + +func TestGLM47ParserToolCallEscaping(t *testing.T) { + toolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: `exec +expr +a < b && c > d`}, nil) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + + expected := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "exec", + Arguments: args(`{"expr": "a < b && c > d"}`), + }, + } + if !reflect.DeepEqual(toolCall, expected) { + t.Fatalf("expected %#v, got %#v", expected, toolCall) + } +} diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go index 79039e52c..3a3261a04 100644 --- a/model/parsers/parsers.go +++ b/model/parsers/parsers.go @@ -68,6 +68,8 @@ func ParserForName(name string) Parser { return &Nemotron3NanoParser{} case "functiongemma": return &FunctionGemmaParser{} + case "glm-4.7": + return &GLM47Parser{} default: return nil } diff --git a/model/parsers/testhelpers_test.go b/model/parsers/testhelpers_test.go index 0c252be83..dc07c4536 100644 --- a/model/parsers/testhelpers_test.go +++ b/model/parsers/testhelpers_test.go @@ -96,3 +96,11 @@ func testArgs(m map[string]any) api.ToolCallFunctionArguments { } return args } + +func args(s string) api.ToolCallFunctionArguments { + var result api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(s), &result); err != nil { + panic("invalid JSON in args(): " + err.Error()) + } + return result +} diff --git a/model/renderers/glm46.go b/model/renderers/glm46.go new file mode 100644 index 000000000..d200e55f2 --- /dev/null +++ b/model/renderers/glm46.go @@ -0,0 +1,110 @@ +package renderers + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/ollama/ollama/api" +) + +type GLM46Renderer struct{} + +func (r *GLM46Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) { + var sb strings.Builder + + sb.WriteString("[gMASK]") + + var lastUserIndex int + for i, message := range messages { + if message.Role == "user" { + lastUserIndex = i + } + } + + if len(tools) > 0 { + sb.WriteString("<|system|>\n") + sb.WriteString("# Tools\n\n") + sb.WriteString("You may call one or more functions to assist with the user query.\n\n") + sb.WriteString("You are provided with function signatures within XML tags:\n") + sb.WriteString("\n") + for _, tool := range tools { + d, _ := json.Marshal(tool) + sb.WriteString(string(d) + "\n") + } + sb.WriteString("\n\n") + sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n") + sb.WriteString("{function-name}\n") + sb.WriteString("{arg-key-1}\n") + sb.WriteString("{arg-value-1}\n") + sb.WriteString("{arg-key-2}\n") + sb.WriteString("{arg-value-2}\n") + sb.WriteString("...\n") + sb.WriteString("") + } + + for i, message := range messages { + switch message.Role { + case "user": + sb.WriteString("<|user|>\n") + sb.WriteString(message.Content) + if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") { + sb.WriteString("/nothink") + } + case "assistant": + sb.WriteString("<|assistant|>") + if i > lastUserIndex { + if message.Thinking != "" { + sb.WriteString("\n" + message.Thinking + "") + } else { + sb.WriteString("\n") + } + } + if message.Content != "" { + sb.WriteString("\n" + message.Content) + } + if len(message.ToolCalls) > 0 { + for _, toolCall := range message.ToolCalls { + sb.WriteString("\n" + toolCall.Function.Name + "\n") + for key, value := range toolCall.Function.Arguments.All() { + sb.WriteString("" + key + "\n") + + var valueStr string + if str, ok := value.(string); ok { + valueStr = str + } else { + jsonBytes, err := json.Marshal(value) + if err != nil { + valueStr = fmt.Sprintf("%v", value) + } else { + valueStr = string(jsonBytes) + } + } + + sb.WriteString("" + valueStr + "\n") + } + + sb.WriteString("") + } + } + case "tool": + if i == 0 || messages[i-1].Role != "tool" { + sb.WriteString("<|observation|>") + } + sb.WriteString("\n\n") + sb.WriteString(message.Content) + sb.WriteString("\n") + case "system": + sb.WriteString("<|system|>\n") + sb.WriteString(message.Content) + } + } + + // Add generation prompt + sb.WriteString("<|assistant|>") + if thinkValue != nil && !thinkValue.Bool() { + sb.WriteString("\n\n") + } + + return sb.String(), nil +} diff --git a/model/renderers/glm46_test.go b/model/renderers/glm46_test.go new file mode 100644 index 000000000..92967521c --- /dev/null +++ b/model/renderers/glm46_test.go @@ -0,0 +1,223 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" +) + +func TestGLM46Renderer(t *testing.T) { + tests := []struct { + name string + messages []api.Message + tools []api.Tool + thinkValue *api.ThinkValue + expected string + skip string + }{ + { + name: "basic", + messages: []api.Message{ + {Role: "user", Content: "Hello, how are you?"}, + }, + expected: `[gMASK]<|user|> +Hello, how are you?<|assistant|>`, + }, + { + name: "basic with system message", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello, how are you?"}, + }, + expected: `[gMASK]<|system|> +You are a helpful assistant.<|user|> +Hello, how are you?<|assistant|>`, + }, + { + name: "basic with user assistant user", + messages: []api.Message{ + {Role: "user", Content: "What is the capital of France?"}, + {Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."}, + {Role: "user", Content: "Fantastic!"}, + }, + expected: `[gMASK]<|user|> +What is the capital of France?<|assistant|> +The capital of France is Paris.<|user|> +Fantastic!<|assistant|>`, + }, + { + skip: "tool call ordering not guaranteed yet", + name: "tools", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant with access to tools."}, + {Role: "user", Content: "What is the weather like in Tokyo?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`), + }, + }, + }, + }, + expected: `[gMASK]<|system|> +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}} + + +For each function call, output the function name and arguments within the following XML format: +{function-name} +{arg-key-1} +{arg-value-1} +{arg-key-2} +{arg-value-2} +... +<|system|> +You are a helpful assistant with access to tools.<|user|> +What is the weather like in Tokyo?<|assistant|>`, + }, + { + skip: "tool call ordering not guaranteed yet", + name: "tool calls", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant with access to tools."}, + {Role: "user", Content: "What is the weather like in Tokyo?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: args(`{"location": "Tokyo, Japan", "unit": "celsius"}`), + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: args(`{"location": "Japan", "unit": "fahrenheit"}`), + }, + }, + }, + }, + { + Role: "tool", + Content: "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}", + ToolName: "get_weather", + }, + { + Role: "tool", + Content: "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}", + ToolName: "get_weather", + }, + { + Role: "assistant", + Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.", + }, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`), + }, + }, + }, + }, + expected: `[gMASK]<|system|> +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}} + + +For each function call, output the function name and arguments within the following XML format: +{function-name} +{arg-key-1} +{arg-value-1} +{arg-key-2} +{arg-value-2} +... +<|system|> +You are a helpful assistant with access to tools.<|user|> +What is the weather like in Tokyo?<|assistant|> + +get_weather +location +Tokyo, Japan +unit +celsius + +get_weather +location +Japan +unit +fahrenheit +<|observation|> + +{"temperature": 22, "weather": "partly cloudy", "humidity": 65} + + +{"temperature": 68, "weather": "sunny", "humidity": 75} +<|assistant|> + +The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`, + }, + { + name: "think true", + messages: []api.Message{ + {Role: "user", Content: "Hello, how are you?"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `[gMASK]<|user|> +Hello, how are you?<|assistant|>`, + }, + { + name: "think false", + messages: []api.Message{ + {Role: "user", Content: "Hello, how are you?"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `[gMASK]<|user|> +Hello, how are you?/nothink<|assistant|> + +`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skip != "" { + t.Skip(tt.skip) + } + renderer := &GLM46Renderer{} + rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(rendered, tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + t.Logf("Got:\n%s", rendered) + t.Logf("Expected:\n%s", tt.expected) + } + }) + } +} diff --git a/model/renderers/glm47.go b/model/renderers/glm47.go new file mode 100644 index 000000000..d095b1d6e --- /dev/null +++ b/model/renderers/glm47.go @@ -0,0 +1,170 @@ +package renderers + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/ollama/ollama/api" +) + +// GLM47Renderer renders messages for GLM-4.7 models. +// +// GLM-4.7 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode): +// +// 1. INTERLEAVED THINKING +// The model thinks between tool calls and after receiving tool results. +// This enables complex step-by-step reasoning: interpreting each tool output +// before deciding what to do next. Thinking blocks are preserved and returned +// with tool results to maintain reasoning continuity. +// +// 2. PRESERVED THINKING +// The model retains reasoning content from previous assistant turns in context. +// This preserves reasoning continuity across multi-turn conversations. The +// upstream API has a "clear_thinking" parameter to control this: +// - clear_thinking=true: clears reasoning from previous turns (outputs ) +// - clear_thinking=false: preserves ... blocks from previous turns +// +// 3. TURN-LEVEL THINKING +// Controls whether the model should reason on each turn. The upstream API +// uses "enable_thinking" parameter: +// - enable_thinking=true: outputs to start reasoning +// - enable_thinking=false: outputs to skip reasoning +// +// OLLAMA DEFAULTS: +// - Thinking is ENABLED by default (thinkValue=nil or true outputs ) +// - Thinking is PRESERVED by default (reasoning content from previous turns is always +// included in ... blocks, equivalent to clear_thinking=false) +// - Users can disable thinking per-turn via thinkValue=false +type GLM47Renderer struct{} + +func (r *GLM47Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) { + var sb strings.Builder + + sb.WriteString("[gMASK]") + + if len(tools) > 0 { + sb.WriteString("<|system|>\n") + sb.WriteString("# Tools\n\n") + sb.WriteString("You may call one or more functions to assist with the user query.\n\n") + sb.WriteString("You are provided with function signatures within XML tags:\n") + sb.WriteString("\n") + for _, tool := range tools { + d, _ := json.Marshal(tool) + sb.WriteString(formatGLM47ToolJSON(d)) + sb.WriteString("\n") + } + sb.WriteString("\n\n") + sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n") + sb.WriteString("{function-name}{arg-key-1}{arg-value-1}{arg-key-2}{arg-value-2}...") + } + + think := true + if thinkValue != nil && !thinkValue.Bool() { + think = false + } + + for i, message := range messages { + switch message.Role { + case "user": + sb.WriteString("<|user|>") + sb.WriteString(message.Content) + case "assistant": + sb.WriteString("<|assistant|>") + if message.Thinking != "" { + sb.WriteString("" + message.Thinking + "") + } else { + sb.WriteString("") + } + if message.Content != "" { + sb.WriteString(message.Content) + } + if len(message.ToolCalls) > 0 { + for _, toolCall := range message.ToolCalls { + sb.WriteString("" + toolCall.Function.Name) + sb.WriteString(renderGLM47ToolArguments(toolCall.Function.Arguments)) + sb.WriteString("") + } + } + case "tool": + if i == 0 || messages[i-1].Role != "tool" { + sb.WriteString("<|observation|>") + } + sb.WriteString("") + sb.WriteString(message.Content) + sb.WriteString("") + case "system": + sb.WriteString("<|system|>") + sb.WriteString(message.Content) + } + } + + sb.WriteString("<|assistant|>") + if think { + sb.WriteString("") + } else { + sb.WriteString("") + } + + return sb.String(), nil +} + +func renderGLM47ToolArguments(args api.ToolCallFunctionArguments) string { + var sb strings.Builder + for key, value := range args.All() { + sb.WriteString("" + key + "") + var valueStr string + if str, ok := value.(string); ok { + valueStr = str + } else { + jsonBytes, err := json.Marshal(value) + if err != nil { + valueStr = fmt.Sprintf("%v", value) + } else { + valueStr = string(jsonBytes) + } + } + + sb.WriteString("" + valueStr + "") + } + + return sb.String() +} + +func formatGLM47ToolJSON(raw []byte) string { + var sb strings.Builder + sb.Grow(len(raw) + len(raw)/10) + + inString := false + escaped := false + for i := range raw { + ch := raw[i] + sb.WriteByte(ch) + + if inString { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + inString = false + } + continue + } + + if ch == '"' { + inString = true + continue + } + + if ch == ':' || ch == ',' { + sb.WriteByte(' ') + } + } + + return sb.String() +} diff --git a/model/renderers/glm47_test.go b/model/renderers/glm47_test.go new file mode 100644 index 000000000..e44ce80de --- /dev/null +++ b/model/renderers/glm47_test.go @@ -0,0 +1,191 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" +) + +func TestGLM47Renderer(t *testing.T) { + tests := []struct { + name string + messages []api.Message + tools []api.Tool + thinkValue *api.ThinkValue + expected string + }{ + { + name: "basic user message", + messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + expected: "[gMASK]<|user|>Hello<|assistant|>", + }, + { + name: "thinking disabled", + messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: "[gMASK]<|user|>Hello<|assistant|>", + }, + { + name: "system and user", + messages: []api.Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "Hello"}, + }, + expected: "[gMASK]<|system|>You are helpful.<|user|>Hello<|assistant|>", + }, + { + name: "multi-turn conversation", + messages: []api.Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "Hello there"}, + {Role: "user", Content: "How are you?"}, + }, + expected: "[gMASK]<|user|>Hi<|assistant|>Hello there<|user|>How are you?<|assistant|>", + }, + { + name: "assistant with reasoning_content", + messages: []api.Message{ + {Role: "user", Content: "Answer with reasoning."}, + {Role: "assistant", Thinking: "Plan.", Content: "Done."}, + }, + expected: "[gMASK]<|user|>Answer with reasoning.<|assistant|>Plan.Done.<|assistant|>", + }, + { + name: "tool call with empty content", + messages: []api.Message{ + {Role: "user", Content: "Weather?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: args(`{"location": "Tokyo", "unit": "celsius"}`), + }, + }, + }, + }, + {Role: "tool", Content: `{"temperature":22}`}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: propsMap(`{"location": {"type": "string"}}`), + }, + }, + }, + }, + expected: "[gMASK]<|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n\n\nFor each function call, output the function name and arguments within the following XML format:\n{function-name}{arg-key-1}{arg-value-1}{arg-key-2}{arg-value-2}...<|user|>Weather?<|assistant|>get_weatherlocationTokyounitcelsius<|observation|>{\"temperature\":22}<|assistant|>", + }, + { + name: "tool call with content", + messages: []api.Message{ + {Role: "user", Content: "Weather?"}, + { + Role: "assistant", + Content: "Let me check", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: args(`{"location": "Tokyo"}`), + }, + }, + }, + }, + {Role: "tool", Content: `{"temperature":22}`}, + {Role: "assistant", Content: "It is 22C."}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: propsMap(`{"location": {"type": "string"}}`), + }, + }, + }, + }, + expected: "[gMASK]<|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n\n\nFor each function call, output the function name and arguments within the following XML format:\n{function-name}{arg-key-1}{arg-value-1}{arg-key-2}{arg-value-2}...<|user|>Weather?<|assistant|>Let me checkget_weatherlocationTokyo<|observation|>{\"temperature\":22}<|assistant|>It is 22C.<|assistant|>", + }, + { + name: "multiple tool calls and responses", + messages: []api.Message{ + {Role: "user", Content: "Compare weather"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: args(`{"location": "Tokyo"}`), + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: args(`{"location": "Paris"}`), + }, + }, + }, + }, + {Role: "tool", Content: `{"temperature":22}`}, + {Role: "tool", Content: `{"temperature":18}`}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: propsMap(`{"location": {"type": "string"}}`), + }, + }, + }, + }, + expected: "[gMASK]<|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n\n\nFor each function call, output the function name and arguments within the following XML format:\n{function-name}{arg-key-1}{arg-value-1}{arg-key-2}{arg-value-2}...<|user|>Compare weather<|assistant|>get_weatherlocationTokyoget_weatherlocationParis<|observation|>{\"temperature\":22}{\"temperature\":18}<|assistant|>", + }, + { + name: "preserved thinking in multi-turn", + messages: []api.Message{ + {Role: "user", Content: "Think step by step"}, + {Role: "assistant", Thinking: "Let me think...", Content: "Here's my answer."}, + {Role: "user", Content: "Continue"}, + }, + expected: "[gMASK]<|user|>Think step by step<|assistant|>Let me think...Here's my answer.<|user|>Continue<|assistant|>", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + renderer := &GLM47Renderer{} + rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(rendered, tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + t.Logf("Got:\n%s", rendered) + t.Logf("Expected:\n%s", tt.expected) + } + }) + } +} diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go index 2aed5dca0..dbb63b07c 100644 --- a/model/renderers/renderer.go +++ b/model/renderers/renderer.go @@ -80,6 +80,8 @@ func rendererForName(name string) Renderer { return &Nemotron3NanoRenderer{} case "functiongemma": return &FunctionGemmaRenderer{} + case "glm-4.7": + return &GLM47Renderer{} default: return nil } diff --git a/model/renderers/testhelpers_test.go b/model/renderers/testhelpers_test.go index 8b628e1a9..a121bcc03 100644 --- a/model/renderers/testhelpers_test.go +++ b/model/renderers/testhelpers_test.go @@ -1,6 +1,26 @@ package renderers -import "github.com/ollama/ollama/api" +import ( + "encoding/json" + + "github.com/ollama/ollama/api" +) + +func args(s string) api.ToolCallFunctionArguments { + var result api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(s), &result); err != nil { + panic("invalid JSON in args(): " + err.Error()) + } + return result +} + +func propsMap(s string) *api.ToolPropertiesMap { + var result api.ToolPropertiesMap + if err := json.Unmarshal([]byte(s), &result); err != nil { + panic("invalid JSON in propsMap(): " + err.Error()) + } + return &result +} // testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved) func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {