diff --git a/convert/convert.go b/convert/convert.go index abb0bc336..f40682330 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -316,8 +316,10 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) { conv = &glm4MoeLiteModel{} case "GlmOcrForConditionalGeneration": conv = &glmOcrModel{} - case "Lfm2ForCausalLM": + case "Lfm2ForCausalLM", "Lfm2MoeForCausalLM": conv = &lfm2Model{} + case "Lfm2VlForConditionalGeneration": + conv = &lfm2VLTextModel{} case "Qwen3NextForCausalLM": conv = &qwen3NextModel{} case "NemotronHForCausalLM": diff --git a/convert/convert_lfm2.go b/convert/convert_lfm2.go index fdae1074c..76cb78777 100644 --- a/convert/convert_lfm2.go +++ b/convert/convert_lfm2.go @@ -1,6 +1,8 @@ package convert import ( + "cmp" + "fmt" "slices" "strings" @@ -13,42 +15,149 @@ type lfm2Model struct { NumHiddenLayers uint32 `json:"num_hidden_layers"` MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` IntermediateSize uint32 `json:"intermediate_size"` + BlockFFDim uint32 `json:"block_ff_dim"` + BlockMultipleOf uint32 `json:"block_multiple_of"` + BlockAutoAdjustFFDim bool `json:"block_auto_adjust_ff_dim"` + BlockFFNDimMultiplier float32 `json:"block_ffn_dim_multiplier"` NumAttentionHeads uint32 `json:"num_attention_heads"` NumKeyValueHeads uint32 `json:"num_key_value_heads"` RopeTheta float32 `json:"rope_theta"` NormEps float32 `json:"norm_eps"` ConvLCache uint32 `json:"conv_L_cache"` + MoEIntermediateSize uint32 `json:"moe_intermediate_size"` + NumExperts uint32 `json:"num_experts"` + NumLocalExperts uint32 `json:"num_local_experts"` + NumExpertsPerToken uint32 `json:"num_experts_per_tok"` + NumDenseLayers uint32 `json:"num_dense_layers"` + RoutedScalingFactor float32 `json:"routed_scaling_factor"` LayerTypes []string `json:"layer_types"` TieEmbedding bool `json:"tie_embedding"` + RopeParameters struct { + RopeTheta float32 `json:"rope_theta"` + } `json:"rope_parameters"` } var _ ModelConverter = (*lfm2Model)(nil) +const ( + defaultMaxPositionEmbeddings = uint32(128_000) + fallbackContextLength = uint32(32_768) +) + +func (p *lfm2Model) isMoE() bool { + return p.ModelType == "lfm2_moe" || p.expertCount() > 0 +} + +func (p *lfm2Model) ropeFreqBase() float32 { + if p.RopeTheta != 0 { + return p.RopeTheta + } + + return p.RopeParameters.RopeTheta +} + +func (p *lfm2Model) expertCount() uint32 { + if p.NumLocalExperts > 0 { + return p.NumLocalExperts + } + return p.NumExperts +} + +func (p *lfm2Model) feedForwardLength() uint32 { + ff := p.IntermediateSize + if p.BlockFFDim != 0 { + ff = p.BlockFFDim + } + + if !p.BlockAutoAdjustFFDim || p.BlockMultipleOf == 0 { + return ff + } + + ff = (2 * ff) / 3 + + // Keep default multiplier behavior consistent with llama.cpp conversion. 
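+ // Worked example (matching TestLFM2KVFeedForwardLengthAutoAdjust below):
+ // block_ff_dim=12288 scales to 2*12288/3 = 8192; a multiplier of 1.0 leaves
+ // it unchanged, and 8192 is already a multiple of 256, so 8192 is returned.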
+ if p.BlockFFNDimMultiplier != 0 { + ff = uint32(float32(ff) * p.BlockFFNDimMultiplier) + } + + m := p.BlockMultipleOf + return m * ((ff + m - 1) / m) +} + +func (p *lfm2Model) hasKnownContextLengthFallbackSignature() bool { + return p.isMoE() && + p.VocabSize == 65536 && + p.HiddenSize == 2048 && + p.NumHiddenLayers == 40 && + p.IntermediateSize == 11776 && + p.NumAttentionHeads == 32 && + p.NumKeyValueHeads == 8 && + p.NumDenseLayers == 2 && + p.expertCount() == 64 && + p.NumExpertsPerToken == 4 && + p.MoEIntermediateSize == 1536 +} + +func (p *lfm2Model) contextLength() uint32 { + if p.MaxPositionEmbeddings == defaultMaxPositionEmbeddings && p.hasKnownContextLengthFallbackSignature() { + return fallbackContextLength + } + + return p.MaxPositionEmbeddings +} + func (p *lfm2Model) KV(t *Tokenizer) KV { + architecture := "lfm2" + if p.isMoE() { + architecture = "lfm2moe" + } + kv := p.ModelParameters.KV(t) - kv["general.architecture"] = "lfm2" - kv["lfm2.vocab_size"] = p.VocabSize - kv["lfm2.block_count"] = p.NumHiddenLayers - kv["lfm2.embedding_length"] = p.HiddenSize - kv["lfm2.feed_forward_length"] = p.IntermediateSize - kv["lfm2.context_length"] = p.MaxPositionEmbeddings + kv["general.architecture"] = architecture + kv["tokenizer.ggml.pre"] = "lfm2" + kv["vocab_size"] = p.VocabSize + kv["block_count"] = p.NumHiddenLayers + kv["embedding_length"] = p.HiddenSize + kv["feed_forward_length"] = p.feedForwardLength() + kv["context_length"] = p.contextLength() // Build per-layer KV head count array based on layer_types - // (0 = shortconv layer, non-zero = attention layer with that many KV heads) + // (0 = shortconv layer, non-zero = attention layer with that many KV heads). + // + // Dense LFM2 in HF defaults to all attention layers when layer_types is absent. + // Preserve that behavior to avoid accidentally emitting all-conv metadata. 
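+ // Example: layer_types = ["conv", "full_attention", "conv"] with
+ // num_key_value_heads = 8 produces head_count_kv = [0, 8, 0].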
kvHeadCounts := make([]uint32, p.NumHiddenLayers) - for i := range p.NumHiddenLayers { - if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" { + if len(p.LayerTypes) == 0 { + for i := range p.NumHiddenLayers { kvHeadCounts[i] = p.NumKeyValueHeads } + } else { + for i := range p.NumHiddenLayers { + if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" { + kvHeadCounts[i] = p.NumKeyValueHeads + } + } } - kv["lfm2.attention.head_count"] = p.NumAttentionHeads - kv["lfm2.attention.head_count_kv"] = kvHeadCounts - kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads - kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads - kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps - kv["lfm2.rope.freq_base"] = p.RopeTheta - kv["lfm2.shortconv.l_cache"] = p.ConvLCache + kv["attention.head_count"] = p.NumAttentionHeads + kv["attention.head_count_kv"] = kvHeadCounts + kv["attention.key_length"] = p.HiddenSize / p.NumAttentionHeads + kv["attention.value_length"] = p.HiddenSize / p.NumAttentionHeads + kv["attention.layer_norm_rms_epsilon"] = p.NormEps + kv["shortconv.l_cache"] = p.ConvLCache + + if ropeFreqBase := p.ropeFreqBase(); ropeFreqBase != 0 { + kv["rope.freq_base"] = ropeFreqBase + } + + if p.isMoE() { + kv["expert_count"] = p.expertCount() + kv["expert_used_count"] = p.NumExpertsPerToken + kv["expert_feed_forward_length"] = p.MoEIntermediateSize + kv["leading_dense_block_count"] = p.NumDenseLayers + kv["expert_gating_func"] = uint32(2) // sigmoid + kv["expert_weights_scale"] = cmp.Or(p.RoutedScalingFactor, float32(1.0)) + } return kv } @@ -56,6 +165,30 @@ func (p *lfm2Model) KV(t *Tokenizer) KV { func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor { var out []*ggml.Tensor + if p.isMoE() { + merges := make([]merge, 0, p.NumHiddenLayers*3) + for i := range p.NumHiddenLayers { + if i < p.NumDenseLayers { + continue + } + + merges = append(merges, merge{ + fmt.Sprintf("blk.%d.feed_forward.experts.*.w1.weight", i), + fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i), + }, merge{ + fmt.Sprintf("blk.%d.feed_forward.experts.*.w2.weight", i), + fmt.Sprintf("blk.%d.ffn_down_exps.weight", i), + }, merge{ + fmt.Sprintf("blk.%d.feed_forward.experts.*.w3.weight", i), + fmt.Sprintf("blk.%d.ffn_up_exps.weight", i), + }) + } + + merged, remaining := mergeTensors(ts, merges...) + out = append(out, merged...) 
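+ // Per the tests below, merged contains one stacked tensor per expert
+ // projection, with the expert count as the leading dimension (e.g.
+ // [n_expert, 1536, 2048]); remaining flows through the per-tensor loop below.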
+ ts = remaining + } + for _, t := range ts { shape := t.Shape() @@ -80,7 +213,7 @@ func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor { func (p *lfm2Model) Replacements() []string { return []string{ "model.embed_tokens", "token_embd", - "model.embedding_norm", "output_norm", + "model.embedding_norm", "token_embd_norm", "model.layers", "blk", "operator_norm", "attn_norm", "self_attn.q_proj", "attn_q", @@ -92,6 +225,8 @@ func (p *lfm2Model) Replacements() []string { "conv.conv", "shortconv.conv", "conv.in_proj", "shortconv.in_proj", "conv.out_proj", "shortconv.out_proj", + "feed_forward.gate", "ffn_gate_inp", + "feed_forward.expert_bias", "exp_probs_b.bias", "feed_forward.w1", "ffn_gate", "feed_forward.w2", "ffn_down", "feed_forward.w3", "ffn_up", diff --git a/convert/convert_lfm2_test.go b/convert/convert_lfm2_test.go new file mode 100644 index 000000000..f260fa07d --- /dev/null +++ b/convert/convert_lfm2_test.go @@ -0,0 +1,271 @@ +package convert + +import ( + "io" + "slices" + "strings" + "testing" +) + +type lfm2StubTensor struct { + tensorBase +} + +func newLFM2StubTensor(name string, shape []uint64) *lfm2StubTensor { + return &lfm2StubTensor{ + tensorBase: tensorBase{ + name: name, + shape: shape, + }, + } +} + +func (t *lfm2StubTensor) WriteTo(io.Writer) (int64, error) { + return 0, nil +} + +func (t *lfm2StubTensor) Clone() Tensor { + return &lfm2StubTensor{ + tensorBase: tensorBase{ + name: t.name, + shape: slices.Clone(t.shape), + }, + } +} + +func TestLFM2MoEKV(t *testing.T) { + var p lfm2Model + p.ModelParameters.ModelType = "lfm2_moe" + p.VocabSize = 65536 + p.HiddenSize = 2048 + p.NumHiddenLayers = 4 + p.MaxPositionEmbeddings = 128000 + p.IntermediateSize = 11776 + p.NumAttentionHeads = 32 + p.NumKeyValueHeads = 8 + p.LayerTypes = []string{"conv", "full_attention", "conv", "full_attention"} + p.NormEps = 1e-5 + p.ConvLCache = 3 + p.MoEIntermediateSize = 1536 + p.NumExperts = 64 + p.NumExpertsPerToken = 4 + p.NumDenseLayers = 2 + p.RopeParameters.RopeTheta = 1_000_000 + + kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}}) + + if got, want := kv["general.architecture"], "lfm2moe"; got != want { + t.Fatalf("general.architecture = %v, want %v", got, want) + } + if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want { + t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want) + } + + if got, want := kv["expert_count"], uint32(64); got != want { + t.Fatalf("expert_count = %v, want %v", got, want) + } + + if got, want := kv["expert_used_count"], uint32(4); got != want { + t.Fatalf("expert_used_count = %v, want %v", got, want) + } + + if got, want := kv["expert_feed_forward_length"], uint32(1536); got != want { + t.Fatalf("expert_feed_forward_length = %v, want %v", got, want) + } + + if got, want := kv["leading_dense_block_count"], uint32(2); got != want { + t.Fatalf("leading_dense_block_count = %v, want %v", got, want) + } + + if got, want := kv["expert_gating_func"], uint32(2); got != want { + t.Fatalf("expert_gating_func = %v, want %v", got, want) + } + + gotHeadCounts, ok := kv["attention.head_count_kv"].([]uint32) + if !ok { + t.Fatalf("attention.head_count_kv has unexpected type %T", kv["attention.head_count_kv"]) + } + + wantHeadCounts := []uint32{0, 8, 0, 8} + if !slices.Equal(gotHeadCounts, wantHeadCounts) { + t.Fatalf("attention.head_count_kv = %v, want %v", gotHeadCounts, wantHeadCounts) + } + + if got, want := kv["rope.freq_base"], float32(1_000_000); got != want { + t.Fatalf("rope.freq_base = %v, want %v", got, want) + } +} + +func TestLFM2DenseKV(t 
*testing.T) { + p := lfm2Model{ + ModelParameters: ModelParameters{ModelType: "lfm2", VocabSize: 32000}, + HiddenSize: 1024, + NumHiddenLayers: 2, + MaxPositionEmbeddings: 32768, + IntermediateSize: 4096, + NumAttentionHeads: 16, + NumKeyValueHeads: 4, + LayerTypes: []string{"conv", "full_attention"}, + NormEps: 1e-5, + ConvLCache: 3, + RopeTheta: 10000, + } + + kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}}) + + if got, want := kv["general.architecture"], "lfm2"; got != want { + t.Fatalf("general.architecture = %v, want %v", got, want) + } + if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want { + t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want) + } + + if _, ok := kv["expert_count"]; ok { + t.Fatalf("expert_count should not be set for dense lfm2") + } +} + +func TestLFM2MoETensors(t *testing.T) { + p := lfm2Model{ + ModelParameters: ModelParameters{ModelType: "lfm2_moe"}, + NumHiddenLayers: 4, + NumDenseLayers: 2, + } + + in := []Tensor{ + newLFM2StubTensor("blk.2.feed_forward.experts.0.w1.weight", []uint64{1536, 2048}), + newLFM2StubTensor("blk.2.feed_forward.experts.1.w1.weight", []uint64{1536, 2048}), + newLFM2StubTensor("blk.2.feed_forward.experts.0.w2.weight", []uint64{2048, 1536}), + newLFM2StubTensor("blk.2.feed_forward.experts.1.w2.weight", []uint64{2048, 1536}), + newLFM2StubTensor("blk.2.feed_forward.experts.0.w3.weight", []uint64{1536, 2048}), + newLFM2StubTensor("blk.2.feed_forward.experts.1.w3.weight", []uint64{1536, 2048}), + newLFM2StubTensor("blk.0.shortconv.conv.weight", []uint64{2048, 1, 3}), + } + + out := p.Tensors(in) + + byName := make(map[string][]uint64, len(out)) + for _, tns := range out { + byName[tns.Name] = tns.Shape + } + + if got, ok := byName["blk.2.ffn_gate_exps.weight"]; !ok { + t.Fatalf("missing merged tensor blk.2.ffn_gate_exps.weight") + } else if !slices.Equal(got, []uint64{2, 1536, 2048}) { + t.Fatalf("blk.2.ffn_gate_exps.weight shape = %v, want [2 1536 2048]", got) + } + + if got, ok := byName["blk.2.ffn_down_exps.weight"]; !ok { + t.Fatalf("missing merged tensor blk.2.ffn_down_exps.weight") + } else if !slices.Equal(got, []uint64{2, 2048, 1536}) { + t.Fatalf("blk.2.ffn_down_exps.weight shape = %v, want [2 2048 1536]", got) + } + + if got, ok := byName["blk.2.ffn_up_exps.weight"]; !ok { + t.Fatalf("missing merged tensor blk.2.ffn_up_exps.weight") + } else if !slices.Equal(got, []uint64{2, 1536, 2048}) { + t.Fatalf("blk.2.ffn_up_exps.weight shape = %v, want [2 1536 2048]", got) + } + + if got, ok := byName["blk.0.shortconv.conv.weight"]; !ok { + t.Fatalf("missing shortconv tensor") + } else if !slices.Equal(got, []uint64{2048, 3}) { + t.Fatalf("blk.0.shortconv.conv.weight shape = %v, want [2048 3]", got) + } + + if _, ok := byName["blk.2.feed_forward.experts.0.w1.weight"]; ok { + t.Fatalf("unmerged expert tensor should not be present") + } +} + +func TestLFM2MoEReplacements(t *testing.T) { + p := lfm2Model{} + replacer := strings.NewReplacer(p.Replacements()...) 
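+ // Replacements() is a flat from/to pair list, so it can be handed directly
+ // to strings.NewReplacer to rewrite full HF tensor names.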
+ + if got, want := replacer.Replace("model.layers.2.feed_forward.expert_bias"), "blk.2.exp_probs_b.bias"; got != want { + t.Fatalf("expert bias replacement = %q, want %q", got, want) + } + + if got, want := replacer.Replace("model.layers.2.feed_forward.gate.weight"), "blk.2.ffn_gate_inp.weight"; got != want { + t.Fatalf("gate replacement = %q, want %q", got, want) + } +} + +func TestLFM2KVContextLengthEdgeCaseFallbackOverride(t *testing.T) { + p := lfm2Model{ + ModelParameters: ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536}, + HiddenSize: 2048, + NumHiddenLayers: 40, + MaxPositionEmbeddings: 128000, + IntermediateSize: 11776, + NumAttentionHeads: 32, + NumKeyValueHeads: 8, + LayerTypes: make([]string, 40), + NormEps: 1e-5, + ConvLCache: 3, + MoEIntermediateSize: 1536, + NumExperts: 64, + NumExpertsPerToken: 4, + NumDenseLayers: 2, + } + for i := 0; i < len(p.LayerTypes); i++ { + p.LayerTypes[i] = "conv" + } + p.LayerTypes[2] = "full_attention" + + kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}}) + + if got, want := kv["context_length"], uint32(32768); got != want { + t.Fatalf("context_length = %v, want %v", got, want) + } +} + +func TestLFM2KVContextLengthNoOverride(t *testing.T) { + p := lfm2Model{ + ModelParameters: ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536}, + HiddenSize: 2048, + NumHiddenLayers: 39, // mismatch: should not trigger edge case + MaxPositionEmbeddings: 128000, + IntermediateSize: 11776, + NumAttentionHeads: 32, + NumKeyValueHeads: 8, + LayerTypes: []string{"conv", "full_attention"}, + NormEps: 1e-5, + ConvLCache: 3, + MoEIntermediateSize: 1536, + NumExperts: 64, + NumExpertsPerToken: 4, + NumDenseLayers: 2, + } + + kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}}) + + if got, want := kv["context_length"], uint32(128000); got != want { + t.Fatalf("context_length = %v, want %v", got, want) + } +} + +func TestLFM2KVFeedForwardLengthAutoAdjust(t *testing.T) { + p := lfm2Model{ + ModelParameters: ModelParameters{ModelType: "lfm2", VocabSize: 65536}, + HiddenSize: 2048, + NumHiddenLayers: 16, + MaxPositionEmbeddings: 128000, + IntermediateSize: 12288, // should be ignored when block_ff_dim is set + BlockFFDim: 12288, + BlockAutoAdjustFFDim: true, + BlockMultipleOf: 256, + BlockFFNDimMultiplier: 1.0, + NumAttentionHeads: 32, + NumKeyValueHeads: 8, + LayerTypes: []string{"conv", "full_attention"}, + NormEps: 1e-5, + ConvLCache: 3, + } + + kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}}) + + if got, want := kv["feed_forward_length"], uint32(8192); got != want { + t.Fatalf("feed_forward_length = %v, want %v", got, want) + } +} diff --git a/convert/convert_lfm2_vl.go b/convert/convert_lfm2_vl.go new file mode 100644 index 000000000..a0b4c46dc --- /dev/null +++ b/convert/convert_lfm2_vl.go @@ -0,0 +1,417 @@ +package convert + +import ( + "cmp" + "encoding/json" + "errors" + "fmt" + "io/fs" + "slices" + "strings" + + "github.com/ollama/ollama/fs/ggml" +) + +// lfm2VLTextModel converts the language model component of LFM2 VL checkpoints. 
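+// The JSON tags mirror the HF config layout: the language model config lives
+// under "text_config", the vision tower under "vision_config", and processor
+// defaults are merged in from processor_config.json via parseMore.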
+type lfm2VLTextModel struct { + TextConfig lfm2Model `json:"text_config"` + DoImageSplitting *bool `json:"do_image_splitting"` + DownsampleFactor uint32 `json:"downsample_factor"` + EncoderPatchSize uint32 `json:"encoder_patch_size"` + ImageTokenID uint32 `json:"image_token_id"` + MaxImageTokens uint32 `json:"max_image_tokens"` + MinImageTokens uint32 `json:"min_image_tokens"` + MaxTiles uint32 `json:"max_tiles"` + MinTiles uint32 `json:"min_tiles"` + TileSize uint32 `json:"tile_size"` + MaxPixelsTolerance float32 `json:"max_pixels_tolerance"` + ProjectorUseLayernorm bool `json:"projector_use_layernorm"` + ProjectorHiddenSize uint32 `json:"projector_hidden_size"` + ProjectorHiddenAct string `json:"projector_hidden_act"` + UseImageSpecialTokens *bool `json:"use_image_special_tokens"` + UseThumbnail *bool `json:"use_thumbnail"` + VisionConfig struct { + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + NumChannels uint32 `json:"num_channels"` + PatchSize uint32 `json:"patch_size"` + LayerNormEpsilon float32 `json:"layer_norm_eps"` + } `json:"vision_config"` + Processor struct { + ImageProcessor struct { + DoImageSplitting *bool `json:"do_image_splitting"` + DownsampleFactor uint32 `json:"downsample_factor"` + MaxImageTokens uint32 `json:"max_image_tokens"` + MinImageTokens uint32 `json:"min_image_tokens"` + MaxTiles uint32 `json:"max_tiles"` + MinTiles uint32 `json:"min_tiles"` + MaxPixelsTol float32 `json:"max_pixels_tolerance"` + TileSize uint32 `json:"tile_size"` + UseThumbnail *bool `json:"use_thumbnail"` + ImageMean []float32 `json:"image_mean"` + ImageStd []float32 `json:"image_std"` + Size struct { + Height uint32 `json:"height"` + Width uint32 `json:"width"` + } `json:"size"` + } `json:"image_processor"` + } +} + +func (p *lfm2VLTextModel) textModel() *lfm2Model { + return &p.TextConfig +} + +func (p *lfm2VLTextModel) specialTokenTypes() []string { + return p.textModel().specialTokenTypes() +} + +func (p *lfm2VLTextModel) parseMore(fsys fs.FS) error { + bts, err := fs.ReadFile(fsys, "processor_config.json") + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return err + } + + return json.Unmarshal(bts, &p.Processor) +} + +func (p *lfm2VLTextModel) visionImageSize() uint32 { + // LFM2-VL image processor operates on 512 tiles and downsamples by factor 2 + // before projection. Keep a fixed square image size compatible with position + // embeddings and the simplified runtime image pipeline. 
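+ // Example with the assumed defaults: tile_size 512 and downsample_factor 2
+ // give an effective image size of 512/2 = 256, which the tests assert.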
+ tile := cmp.Or( + p.Processor.ImageProcessor.TileSize, + p.Processor.ImageProcessor.Size.Height, + p.Processor.ImageProcessor.Size.Width, + uint32(512), + ) + downsample := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2)) + if downsample == 0 { + return tile + } + + return max(uint32(1), tile/downsample) +} + +func (p *lfm2VLTextModel) KV(t *Tokenizer) KV { + kv := p.textModel().KV(t) + + boolOr := func(defaultValue bool, values ...*bool) bool { + for _, v := range values { + if v != nil { + return *v + } + } + return defaultValue + } + + kv["vision.block_count"] = cmp.Or(p.VisionConfig.NumHiddenLayers, uint32(27)) + kv["vision.embedding_length"] = cmp.Or(p.VisionConfig.HiddenSize, uint32(1152)) + kv["vision.feed_forward_length"] = cmp.Or(p.VisionConfig.IntermediateSize, uint32(4304)) + kv["vision.attention.head_count"] = cmp.Or(p.VisionConfig.NumAttentionHeads, uint32(16)) + kv["vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionConfig.LayerNormEpsilon, float32(1e-6)) + kv["vision.patch_size"] = cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16)) + kv["vision.num_channels"] = cmp.Or(p.VisionConfig.NumChannels, uint32(3)) + kv["vision.image_size"] = p.visionImageSize() + kv["vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2)) + kv["vision.projector.use_layernorm"] = p.ProjectorUseLayernorm + kv["vision.do_image_splitting"] = boolOr(true, p.DoImageSplitting, p.Processor.ImageProcessor.DoImageSplitting) + kv["vision.min_tiles"] = cmp.Or(p.MinTiles, p.Processor.ImageProcessor.MinTiles, uint32(2)) + kv["vision.max_tiles"] = cmp.Or(p.MaxTiles, p.Processor.ImageProcessor.MaxTiles, uint32(10)) + kv["vision.tile_size"] = cmp.Or(p.TileSize, p.Processor.ImageProcessor.TileSize, uint32(512)) + kv["vision.min_image_tokens"] = cmp.Or(p.MinImageTokens, p.Processor.ImageProcessor.MinImageTokens, uint32(64)) + kv["vision.max_image_tokens"] = cmp.Or(p.MaxImageTokens, p.Processor.ImageProcessor.MaxImageTokens, uint32(256)) + kv["vision.max_pixels_tolerance"] = cmp.Or(p.MaxPixelsTolerance, p.Processor.ImageProcessor.MaxPixelsTol, float32(2.0)) + kv["vision.use_thumbnail"] = boolOr(true, p.UseThumbnail, p.Processor.ImageProcessor.UseThumbnail) + kv["vision.use_image_special_tokens"] = boolOr(true, p.UseImageSpecialTokens) + kv["vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5})) + kv["vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5})) + kv["vision.image_token_id"] = cmp.Or(p.ImageTokenID, uint32(396)) + + setVisionTokenID := func(k, token string) { + if t == nil || t.Vocabulary == nil { + return + } + for i, v := range t.Vocabulary.Tokens { + if v == token { + kv[k] = uint32(i) + return + } + } + } + setVisionTokenID("vision.image_start_token_id", "<|image_start|>") + setVisionTokenID("vision.image_end_token_id", "<|image_end|>") + setVisionTokenID("vision.image_thumbnail_token_id", "<|img_thumbnail|>") + + return kv +} + +func (p *lfm2VLTextModel) Tensors(ts []Tensor) []*ggml.Tensor { + patchSize := int(cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16))) + numChannels := int(cmp.Or(p.VisionConfig.NumChannels, uint32(3))) + + for _, t := range ts { + if t.Name() == "v.patch_embd.weight" { + shape := t.Shape() + if len(shape) == 2 { + inputDim := uint64(numChannels * patchSize * patchSize) + if shape[1] == inputDim { + channels := numChannels + 
patch := patchSize + t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) { + return repackPatchEmbeddingWeight(data, srcShape, channels, patch) + }) + } + } + } + } + + out := p.textModel().Tensors(ts) + for _, t := range out { + if t.Name == "v.patch_embd.weight" && len(t.Shape) == 2 { + t.Shape = []uint64{t.Shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)} + } + } + return out +} + +func (p *lfm2VLTextModel) Replacements() []string { + out := make([]string, 0, 96) + + addText := func(from, to string) { + out = append(out, from, to) + if strings.HasPrefix(from, "model.") { + suffix := strings.TrimPrefix(from, "model.") + out = append(out, + "model.language_model."+suffix, to, + "model.language_model.model."+suffix, to, + ) + } + } + + base := p.textModel().Replacements() + for i := 0; i+1 < len(base); i += 2 { + addText(base[i], base[i+1]) + } + + // Vision tower + multimodal projector tensors (single-file conversion). + out = append(out, + "model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd", + "model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd", + "model.vision_tower.vision_model.encoder.layers", "v.blk", + "model.vision_tower.vision_model.post_layernorm", "v.post_ln", + "model.multi_modal_projector.layer_norm", "mm.layer_norm", + "model.multi_modal_projector.linear_1", "mm.1", + "model.multi_modal_projector.linear_2", "mm.2", + "self_attn.q_proj", "attn_q", + "self_attn.k_proj", "attn_k", + "self_attn.v_proj", "attn_v", + "self_attn.out_proj", "attn_out", + "layer_norm1", "ln1", + "layer_norm2", "ln2", + "mlp.fc1", "ffn_up", + "mlp.fc2", "ffn_down", + ) + + return out +} + +// lfm2VLProjectorModel converts the vision encoder + projector component of LFM2 VL checkpoints. 
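+// It emits a standalone mmproj GGUF ("general.architecture" = "clip"), so
+// Tensors keeps only vision tower ("v.") and projector ("mm.") weights.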
+type lfm2VLProjectorModel struct { + ModelParameters + DownsampleFactor uint32 `json:"downsample_factor"` + ProjectorHiddenDim uint32 `json:"projector_hidden_size"` + VisionModel struct { + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + NumChannels uint32 `json:"num_channels"` + PatchSize uint32 `json:"patch_size"` + LayerNormEpsilon float32 `json:"layer_norm_eps"` + ImageSize uint32 `json:"image_size"` + } `json:"vision_config"` + Processor struct { + ImageProcessor struct { + DownsampleFactor uint32 `json:"downsample_factor"` + TileSize uint32 `json:"tile_size"` + ImageMean []float32 `json:"image_mean"` + ImageStd []float32 `json:"image_std"` + Size struct { + Height uint32 `json:"height"` + Width uint32 `json:"width"` + } `json:"size"` + } `json:"image_processor"` + } +} + +var ( + _ ModelConverter = (*lfm2VLTextModel)(nil) + _ ModelConverter = (*lfm2VLProjectorModel)(nil) + _ moreParser = (*lfm2VLTextModel)(nil) + _ moreParser = (*lfm2VLProjectorModel)(nil) +) + +func (p *lfm2VLProjectorModel) parseMore(fsys fs.FS) error { + bts, err := fs.ReadFile(fsys, "processor_config.json") + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return err + } + + return json.Unmarshal(bts, &p.Processor) +} + +func (p *lfm2VLProjectorModel) imageSize() uint32 { + if p.VisionModel.ImageSize > 0 { + return p.VisionModel.ImageSize + } + + downsample := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2)) + baseSize := cmp.Or( + p.Processor.ImageProcessor.TileSize, + p.Processor.ImageProcessor.Size.Height, + p.Processor.ImageProcessor.Size.Width, + uint32(256), + ) + if downsample == 0 { + return baseSize + } + + return max(uint32(1), baseSize/downsample) +} + +func (p *lfm2VLProjectorModel) KV(_ *Tokenizer) KV { + kv := KV{ + "general.architecture": "clip", + "general.type": "mmproj", + "general.file_type": uint32(1), + "general.quantization_version": uint32(2), + "clip.has_vision_encoder": true, + "clip.projector_type": "lfm2", + "clip.use_gelu": true, + } + + kv["clip.vision.block_count"] = cmp.Or(p.VisionModel.NumHiddenLayers, uint32(27)) + kv["clip.vision.embedding_length"] = cmp.Or(p.VisionModel.HiddenSize, uint32(1152)) + kv["clip.vision.feed_forward_length"] = cmp.Or(p.VisionModel.IntermediateSize, uint32(4304)) + kv["clip.vision.attention.head_count"] = cmp.Or(p.VisionModel.NumAttentionHeads, uint32(16)) + kv["clip.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, float32(1e-6)) + kv["clip.vision.patch_size"] = cmp.Or(p.VisionModel.PatchSize, uint32(16)) + kv["clip.vision.image_size"] = p.imageSize() + kv["clip.vision.projection_dim"] = cmp.Or(p.ProjectorHiddenDim, uint32(2048)) + kv["clip.vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2)) + kv["clip.vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5})) + kv["clip.vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5})) + + return kv +} + +func defaultFloat32Slice(v, fallback []float32) []float32 { + if len(v) > 0 { + return v + } + + return fallback +} + +func (p *lfm2VLProjectorModel) Tensors(ts []Tensor) []*ggml.Tensor { + var out []*ggml.Tensor + + numChannels := cmp.Or(p.VisionModel.NumChannels, uint32(3)) + patchSize 
:= cmp.Or(p.VisionModel.PatchSize, uint32(16)) + + for _, t := range ts { + name := t.Name() + if !(strings.HasPrefix(name, "v.") || strings.HasPrefix(name, "mm.")) { + continue + } + + shape := t.Shape() + if name == "v.patch_embd.weight" && len(shape) == 2 { + inputDim := uint64(numChannels * patchSize * patchSize) + if shape[1] == inputDim { + shape = []uint64{shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)} + channels := int(numChannels) + patch := int(patchSize) + t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) { + return repackPatchEmbeddingWeight(data, srcShape, channels, patch) + }) + } + } + + out = append(out, &ggml.Tensor{ + Name: name, + Kind: t.Kind(), + Shape: slices.Clone(shape), + WriterTo: t, + }) + } + + return out +} + +func (p *lfm2VLProjectorModel) Replacements() []string { + return []string{ + "model.multi_modal_projector.linear_1", "mm.1", + "model.multi_modal_projector.linear_2", "mm.2", + "model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd", + "model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd", + "model.vision_tower.vision_model.encoder.layers", "v.blk", + "self_attn.q_proj", "attn_q", + "self_attn.k_proj", "attn_k", + "self_attn.v_proj", "attn_v", + "self_attn.out_proj", "attn_out", + "layer_norm1", "ln1", + "layer_norm2", "ln2", + "mlp.fc1", "ffn_up", + "mlp.fc2", "ffn_down", + "model.vision_tower.vision_model.post_layernorm", "v.post_ln", + } +} + +func repackPatchEmbeddingWeight(data []float32, srcShape []uint64, channels, patch int) ([]float32, error) { + if len(srcShape) != 2 { + return nil, fmt.Errorf("invalid patch embedding shape rank: %d", len(srcShape)) + } + + outDim := int(srcShape[0]) + flatInputDim := int(srcShape[1]) + expectedInputDim := channels * patch * patch + if flatInputDim != expectedInputDim { + return nil, fmt.Errorf("invalid patch embedding input dim: got %d, want %d", flatInputDim, expectedInputDim) + } + + expectedSize := outDim * flatInputDim + if len(data) != expectedSize { + return nil, fmt.Errorf("invalid patch embedding data size: got %d, want %d", len(data), expectedSize) + } + + repacked := make([]float32, len(data)) + perChannel := patch * patch + + for o := range outDim { + inBase := o * flatInputDim + outBase := o * flatInputDim + + for y := range patch { + for x := range patch { + inPixelBase := inBase + (y*patch+x)*channels + for c := range channels { + src := inPixelBase + c + dst := outBase + c*perChannel + y*patch + x + repacked[dst] = data[src] + } + } + } + } + + return repacked, nil +} diff --git a/convert/convert_lfm2_vl_test.go b/convert/convert_lfm2_vl_test.go new file mode 100644 index 000000000..a2d88561c --- /dev/null +++ b/convert/convert_lfm2_vl_test.go @@ -0,0 +1,249 @@ +package convert + +import ( + "slices" + "strings" + "testing" +) + +func TestLFM2VLTextModelKVUsesTextConfig(t *testing.T) { + p := lfm2VLTextModel{ + TextConfig: lfm2Model{ + ModelParameters: ModelParameters{ModelType: "lfm2", VocabSize: 65536}, + HiddenSize: 2048, + NumHiddenLayers: 16, + MaxPositionEmbeddings: 128000, + IntermediateSize: 12288, + BlockFFDim: 12288, + BlockAutoAdjustFFDim: true, + BlockMultipleOf: 256, + BlockFFNDimMultiplier: 1.0, + NumAttentionHeads: 32, + NumKeyValueHeads: 8, + LayerTypes: []string{"conv", "full_attention"}, + NormEps: 1e-5, + ConvLCache: 3, + }, + DownsampleFactor: 2, + VisionConfig: struct { + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + 
NumAttentionHeads uint32 `json:"num_attention_heads"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + NumChannels uint32 `json:"num_channels"` + PatchSize uint32 `json:"patch_size"` + LayerNormEpsilon float32 `json:"layer_norm_eps"` + }{ + HiddenSize: 1152, + IntermediateSize: 4304, + NumAttentionHeads: 16, + NumHiddenLayers: 27, + NumChannels: 3, + PatchSize: 16, + LayerNormEpsilon: 1e-6, + }, + } + p.Processor.ImageProcessor.TileSize = 512 + p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5} + p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5} + + kv := p.KV(&Tokenizer{ + Vocabulary: &Vocabulary{ + Model: "gpt2", + Tokens: []string{"<|pad|>", "", "<|image_start|>", "<|image_end|>", "<|img_thumbnail|>"}, + }, + }) + + if got, want := kv["general.architecture"], "lfm2"; got != want { + t.Fatalf("general.architecture = %v, want %v", got, want) + } + + if got, want := kv["feed_forward_length"], uint32(8192); got != want { + t.Fatalf("feed_forward_length = %v, want %v", got, want) + } + + if got, want := kv["vision.block_count"], uint32(27); got != want { + t.Fatalf("vision.block_count = %v, want %v", got, want) + } + + if got, want := kv["vision.image_size"], uint32(256); got != want { + t.Fatalf("vision.image_size = %v, want %v", got, want) + } + + if got, want := kv["vision.image_token_id"], uint32(396); got != want { + t.Fatalf("vision.image_token_id = %v, want %v", got, want) + } + + if got, want := kv["vision.image_start_token_id"], uint32(2); got != want { + t.Fatalf("vision.image_start_token_id = %v, want %v", got, want) + } + + if got, want := kv["vision.do_image_splitting"], true; got != want { + t.Fatalf("vision.do_image_splitting = %v, want %v", got, want) + } + if got, want := kv["vision.min_tiles"], uint32(2); got != want { + t.Fatalf("vision.min_tiles = %v, want %v", got, want) + } + if got, want := kv["vision.max_tiles"], uint32(10); got != want { + t.Fatalf("vision.max_tiles = %v, want %v", got, want) + } + if got, want := kv["vision.tile_size"], uint32(512); got != want { + t.Fatalf("vision.tile_size = %v, want %v", got, want) + } + if got, want := kv["vision.use_thumbnail"], true; got != want { + t.Fatalf("vision.use_thumbnail = %v, want %v", got, want) + } + if got, want := kv["vision.use_image_special_tokens"], true; got != want { + t.Fatalf("vision.use_image_special_tokens = %v, want %v", got, want) + } +} + +func TestLFM2VLTextModelTensorsIncludeVision(t *testing.T) { + p := lfm2VLTextModel{} + p.VisionConfig.PatchSize = 16 + p.VisionConfig.NumChannels = 3 + input := []Tensor{ + newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}), + newLFM2StubTensor("model.layers.0.ffn_norm.weight", []uint64{2048}), + newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}), + newLFM2StubTensor("v.blk.0.attn_q.weight", []uint64{1152, 1152}), + newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}), + } + + out := p.Tensors(input) + if len(out) == 0 { + t.Fatal("expected non-empty tensor list") + } + + foundPatch := false + foundVision := false + for _, tns := range out { + if tns.Name == "v.patch_embd.weight" { + foundPatch = true + if !slices.Equal(tns.Shape, []uint64{1152, 3, 16, 16}) { + t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", tns.Shape) + } + } + if strings.HasPrefix(tns.Name, "v.") || strings.HasPrefix(tns.Name, "mm.") { + foundVision = true + } + } + + if !foundPatch { + t.Fatal("expected v.patch_embd.weight in output tensors") + } + if !foundVision { + t.Fatal("expected at least one vision/projector 
tensor in output") + } +} + +func TestLFM2VLTextModelReplacements(t *testing.T) { + p := lfm2VLTextModel{} + r := strings.NewReplacer(p.Replacements()...) + + tests := []struct { + name string + in string + want string + }{ + { + name: "language_model_embed_tokens", + in: "model.language_model.embed_tokens.weight", + want: "token_embd.weight", + }, + { + name: "language_model_layers", + in: "model.language_model.layers.2.self_attn.q_proj.weight", + want: "blk.2.attn_q.weight", + }, + { + name: "nested_language_model_prefix", + in: "model.language_model.model.embedding_norm.weight", + want: "token_embd_norm.weight", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := r.Replace(tt.in); got != tt.want { + t.Fatalf("replacement(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestLFM2VLProjectorKV(t *testing.T) { + p := lfm2VLProjectorModel{ + DownsampleFactor: 2, + ProjectorHiddenDim: 2048, + } + p.VisionModel.NumHiddenLayers = 27 + p.VisionModel.HiddenSize = 1152 + p.VisionModel.IntermediateSize = 4304 + p.VisionModel.NumAttentionHeads = 16 + p.VisionModel.PatchSize = 16 + p.VisionModel.LayerNormEpsilon = 1e-6 + p.Processor.ImageProcessor.TileSize = 512 + p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5} + p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5} + + kv := p.KV(nil) + + if got, want := kv["general.architecture"], "clip"; got != want { + t.Fatalf("general.architecture = %v, want %v", got, want) + } + if got, want := kv["clip.projector_type"], "lfm2"; got != want { + t.Fatalf("clip.projector_type = %v, want %v", got, want) + } + if got, want := kv["clip.vision.image_size"], uint32(256); got != want { + t.Fatalf("clip.vision.image_size = %v, want %v", got, want) + } +} + +func TestLFM2VLProjectorTensorsPatchReshape(t *testing.T) { + p := lfm2VLProjectorModel{} + p.VisionModel.NumChannels = 3 + p.VisionModel.PatchSize = 16 + + input := []Tensor{ + newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}), + newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}), + newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}), + } + + out := p.Tensors(input) + if len(out) != 2 { + t.Fatalf("expected 2 tensors, got %d", len(out)) + } + + var patchShape []uint64 + for _, tns := range out { + if tns.Name == "v.patch_embd.weight" { + patchShape = tns.Shape + break + } + } + + if !slices.Equal(patchShape, []uint64{1152, 3, 16, 16}) { + t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", patchShape) + } +} + +func TestRepackPatchEmbeddingWeight(t *testing.T) { + data := []float32{ + 0, 1, // y=0,x=0 + 2, 3, // y=0,x=1 + 4, 5, // y=1,x=0 + 6, 7, // y=1,x=1 + } + + got, err := repackPatchEmbeddingWeight(data, []uint64{1, 8}, 2, 2) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + want := []float32{0, 2, 4, 6, 1, 3, 5, 7} + if !slices.Equal(got, want) { + t.Fatalf("repacked data = %v, want %v", got, want) + } +} diff --git a/convert/tokenizer.go b/convert/tokenizer.go index 41d0310a0..7d281d2ea 100644 --- a/convert/tokenizer.go +++ b/convert/tokenizer.go @@ -212,8 +212,13 @@ type tokenizer struct { PreTokenizer struct { PreTokenizers []struct { - Type string `json:"type"` - Pattern struct { + Type string `json:"type"` + Behavior string `json:"behavior"` + Invert bool `json:"invert"` + AddPrefixSpace bool `json:"add_prefix_space"` + TrimOffsets bool `json:"trim_offsets"` + UseRegex bool `json:"use_regex"` + Pattern struct { Regex string `json:"Regex"` } `json:"pattern"` } 
`json:"pretokenizers"` diff --git a/convert/tokenizer_test.go b/convert/tokenizer_test.go index 813096fd9..926e323cb 100644 --- a/convert/tokenizer_test.go +++ b/convert/tokenizer_test.go @@ -191,6 +191,84 @@ func TestParseTokenizer(t *testing.T) { Pre: "default", }, }, + { + name: "llama-bpe pretokenizer and control tokens", + fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{ + "tokenizer.json": strings.NewReader(`{ + "added_tokens": [ + {"id": 1, "content": "<|startoftext|>", "special": true}, + {"id": 6, "content": "<|im_start|>", "special": true}, + {"id": 7, "content": "<|im_end|>", "special": true}, + {"id": 8, "content": "<|tool_list_start|>", "special": true}, + {"id": 9, "content": "<|tool_list_end|>", "special": true}, + {"id": 10, "content": "<|tool_call_start|>", "special": true}, + {"id": 11, "content": "<|tool_call_end|>", "special": true}, + {"id": 12, "content": "<|tool_response_start|>", "special": true}, + {"id": 13, "content": "<|tool_response_end|>", "special": true}, + {"id": 396, "content": "", "special": true}, + {"id": 64400, "content": "", "special": true}, + {"id": 64401, "content": "", "special": true} + ], + "model": { + "vocab": { + "<|startoftext|>": 1, + "<|im_start|>": 6, + "<|im_end|>": 7, + "<|tool_list_start|>": 8, + "<|tool_list_end|>": 9, + "<|tool_call_start|>": 10, + "<|tool_call_end|>": 11, + "<|tool_response_start|>": 12, + "<|tool_response_end|>": 13, + "": 396, + "": 64400, + "": 64401 + } + }, + "pre_tokenizer": { + "type": "Sequence", + "pretokenizers": [ + { + "type": "Split", + "pattern": { + "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + }, + "behavior": "Isolated", + "invert": false + }, + { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": false + } + ] + } + }`), + }), + want: &Tokenizer{ + Vocabulary: &Vocabulary{ + Model: "gpt2", + Tokens: []string{ + "<|startoftext|>", + "<|im_start|>", + "<|im_end|>", + "<|tool_list_start|>", + "<|tool_list_end|>", + "<|tool_call_start|>", + "<|tool_call_end|>", + "<|tool_response_start|>", + "<|tool_response_end|>", + "", + "", + "", + }, + Scores: []float32{1, 6, 7, 8, 9, 10, 11, 12, 13, 396, 64400, 64401}, + Types: []int32{3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, + }, + Pre: "llama-bpe", + }, + }, { name: "list string merges", fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{ diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index d75ac19fa..e43e93aab 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -295,6 +295,7 @@ func (kv KV) OllamaEngineRequired() bool { "glm4moelite", "glmocr", "lfm2", + "lfm2moe", }, kv.Architecture()) } @@ -886,6 +887,7 @@ func (f GGML) FlashAttention() bool { "glmocr", "gptoss", "gpt-oss", "lfm2", + "lfm2moe", "mistral3", "nemotron_h", "nemotron_h_moe", "olmo3", diff --git a/model/models/lfm2/cache.go b/model/models/lfm2/cache.go index 7e9d35f5f..7c1185e6e 100644 --- a/model/models/lfm2/cache.go +++ b/model/models/lfm2/cache.go @@ -1,410 +1,44 @@ package lfm2 import ( - "slices" - "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" - "github.com/ollama/ollama/model/input" ) -var _ kvcache.Cache = (*HybridCache)(nil) +var ( + _ kvcache.Cache = (*HybridCache)(nil) + _ kvcache.CheckpointCache = (*HybridCache)(nil) +) -// HybridCache stores: -// - a standard causal KV cache for attention layers -// - a per-sequence recurrent conv state for shortconv layers +// HybridCache adapts the shared recurrent cache 
for LFM2: +// - KV attention cache is handled by the embedded causal cache +// - shortconv recurrent state uses conv slots [dConv, hiddenSize] // -// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1. -// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots]. +// This reuses shared checkpoint/restore logic for prefix mismatch recovery. type HybridCache struct { - kv *kvcache.Causal - - backend ml.Backend - dtype ml.DType - maxSequences int - - hiddenSize int - dConv int - - // slot mapping for recurrent state - slotForSeq map[int]int - refCount []int - freeSlots []int - - // per-layer conv state buffers (allocated lazily) - convCtxs map[int]ml.Context - convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots] - - // current forward batch (derived in StartForward) - curSeqs []int - curSlots []int - curSlotsInput ml.Tensor - curSeqTokens int - - // track if EnsureWritable has been called for this forward pass - writableEnsured bool - // track any error from EnsureWritable to propagate later - writableError error + *kvcache.Recurrent } func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache { - return &HybridCache{ - kv: kvcache.NewCausalCache(shift), - hiddenSize: hiddenSize, - dConv: dConv, - slotForSeq: make(map[int]int), - convCtxs: make(map[int]ml.Context), - convStates: make(map[int]ml.Tensor), - } -} + base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{ + Shift: shift, + ConvDim: dConv, + ConvChannels: hiddenSize, + RecurrentStateSize: 1, // LFM2 uses only conv state; keep a minimal recurrent buffer size. + CheckpointLogPrefix: "lfm2", + }) -func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { - c.backend = backend - c.dtype = dtype - c.maxSequences = maxSequences - - // initialize slot allocator - c.refCount = make([]int, maxSequences) - c.freeSlots = c.freeSlots[:0] - for i := maxSequences - 1; i >= 0; i-- { - c.freeSlots = append(c.freeSlots, i) - } - - c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch) -} - -func (c *HybridCache) Close() { - for _, ctx := range c.convCtxs { - ctx.Close() - } - c.kv.Close() -} - -func (c *HybridCache) SetConfig(config ml.CacheConfig) { - c.kv.SetConfig(config) -} - -func (c *HybridCache) SetLayer(layer int) { - c.kv.SetLayer(layer) -} - -func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { - return c.kv.Get(ctx) -} - -func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) { - c.kv.Put(ctx, key, value) -} - -func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { - if err := c.kv.StartForward(ctx, batch, reserve); err != nil { - return err - } - - // Derive equal-length sequence layout for shortconv. - // LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid. 
- seqCounts := make(map[int]int) - c.curSeqs = c.curSeqs[:0] - for _, s := range batch.Sequences { - if _, ok := seqCounts[s]; !ok { - c.curSeqs = append(c.curSeqs, s) - } - seqCounts[s]++ - } - - if len(c.curSeqs) == 0 { - return nil - } - - nTokens := len(batch.Sequences) - nSeqs := len(c.curSeqs) - want := nTokens / nSeqs - for _, s := range c.curSeqs { - if seqCounts[s] != want { - return kvcache.ErrNotSupported - } - } - - c.curSeqTokens = want - - // When reserving memory for estimation, use fake slot assignments - // without modifying permanent state (slotForSeq, refCount) - if reserve { - c.curSlots = c.curSlots[:0] - slots := make([]int32, nSeqs) - for i := range nSeqs { - c.curSlots = append(c.curSlots, i) - slots[i] = int32(i) - } - c.curSlotsInput = ctx.Input().FromInts(slots, len(slots)) - return nil - } - - // Ensure slots exist for sequences in this batch - c.curSlots = c.curSlots[:0] - var newSlots []int // track newly allocated slots that need zeroing - for _, s := range c.curSeqs { - slot, ok := c.slotForSeq[s] - if !ok { - var err error - slot, err = c.allocSlot() - if err != nil { - return err - } - c.slotForSeq[s] = slot - c.refCount[slot] = 1 - newSlots = append(newSlots, slot) - } - c.curSlots = append(c.curSlots, slot) - } - - // Zero conv state for newly allocated slots to clear stale data from previous sequences - if len(newSlots) > 0 { - c.zeroConvSlots(ctx, newSlots) - } - - // Create a tensor for the current slots - slots := make([]int32, len(c.curSlots)) - for i, v := range c.curSlots { - slots[i] = int32(v) - } - c.curSlotsInput = ctx.Input().FromInts(slots, len(slots)) - - // Reset writable state for new forward pass - c.writableEnsured = false - c.writableError = nil - - return nil -} - -func (c *HybridCache) allocSlot() (int, error) { - if len(c.freeSlots) == 0 { - return 0, kvcache.ErrKvCacheFull - } - slot := c.freeSlots[len(c.freeSlots)-1] - c.freeSlots = c.freeSlots[:len(c.freeSlots)-1] - return slot, nil -} - -func (c *HybridCache) freeSlot(slot int) { - // Bounds check before freeing - if slot >= 0 && slot < c.maxSequences { - c.freeSlots = append(c.freeSlots, slot) - } -} - -// zeroConvSlots zeros the conv state for the given slots across all layers. -// This must be called when recycling slots to prevent stale state from affecting new sequences. -func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) { - if len(slots) == 0 || len(c.convStates) == 0 { - return - } - - // Use input context for creating tensors - inputCtx := ctx.Input() - - // Create slot indices tensor - slotIndices := make([]int32, len(slots)) - for i, s := range slots { - slotIndices[i] = int32(s) - } - slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices)) - - // Create zero tensor for the slots (SetRows requires F32 source) - zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots)) - - // Zero each layer's conv state for these slots - for _, buf := range c.convStates { - ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor)) - } -} - -// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots. -// Returns an error if slot allocation fails. 
-func (c *HybridCache) EnsureWritable(ctx ml.Context) error { - for i, seq := range c.curSeqs { - slot, ok := c.slotForSeq[seq] - if !ok { - continue - } - - // Bounds check - if slot < 0 || slot >= len(c.refCount) { - continue - } - - if c.refCount[slot] <= 1 { - continue - } - - newSlot, err := c.allocSlot() - if err != nil { - return err - } - c.refCount[slot]-- - c.refCount[newSlot] = 1 - c.slotForSeq[seq] = newSlot - c.curSlots[i] = newSlot - - // Copy existing conv state for all initialized layers - for _, buf := range c.convStates { - // buf: [dConv*hiddenSize, maxSlots] - src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1)) - // SetRows requires F32 source - srcF32 := src.Cast(ctx, ml.DTypeF32) - ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1))) - } - } - - // Rebuild current slots tensor - slots := make([]int32, len(c.curSlots)) - for i, v := range c.curSlots { - slots[i] = int32(v) - } - c.curSlotsInput = ctx.Input().FromInts(slots, len(slots)) - - return nil -} - -func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) { - // KV cache shares prefix metadata (no copy) which is correct for prefix reuse. - c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen) - - // For shortconv state we implement copy-on-write: dst shares the same slot as src. - // On the first write to dst, EnsureWritable will create a private slot. - if dstSlot, ok := c.slotForSeq[dstSeq]; ok { - // Bounds check before decrementing - if dstSlot >= 0 && dstSlot < len(c.refCount) { - c.refCount[dstSlot]-- - if c.refCount[dstSlot] <= 0 { - c.refCount[dstSlot] = 0 - c.freeSlot(dstSlot) - } - } - delete(c.slotForSeq, dstSeq) - } - - srcSlot, ok := c.slotForSeq[srcSeq] - if !ok { - // src may not have a slot yet; dst will allocate on demand - return - } - - // Bounds check before incrementing - if srcSlot >= 0 && srcSlot < len(c.refCount) { - c.slotForSeq[dstSeq] = srcSlot - c.refCount[srcSlot]++ - } -} - -func (c *HybridCache) CanResume(seq int, pos int32) bool { - return c.kv.CanResume(seq, pos) -} - -func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error { - if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil { - return err - } - - // For recurrent state, any removal invalidates the state because - // the state at position N depends on all previous positions. - // Drop the slot mapping so it resets on next use. - slot, ok := c.slotForSeq[seq] - if !ok { - return nil - } - - // Bounds check - if slot < 0 || slot >= len(c.refCount) { - delete(c.slotForSeq, seq) - return nil - } - - c.refCount[slot]-- - if c.refCount[slot] <= 0 { - c.refCount[slot] = 0 - c.freeSlot(slot) - } - delete(c.slotForSeq, seq) - - return nil + return &HybridCache{Recurrent: base} } func (c *HybridCache) slotsTensor() ml.Tensor { - return c.curSlotsInput + return c.SlotsTensor() } func (c *HybridCache) seqTokens() int { - return c.curSeqTokens + return c.SeqTokens() } func (c *HybridCache) numSeqs() int { - return len(c.curSeqs) -} - -func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor { - if buf, ok := c.convStates[layer]; ok { - return buf - } - - if _, ok := c.convCtxs[layer]; !ok { - c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer) - } - - buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences) - c.convStates[layer] = buf - return buf -} - -// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs]. -// Returns an error if copy-on-write allocation fails. 
-func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) { - if !c.writableEnsured { - needsWritable := false - for _, seq := range c.curSeqs { - slot, ok := c.slotForSeq[seq] - if !ok { - continue - } - if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 { - needsWritable = true - break - } - } - - if needsWritable { - if err := c.EnsureWritable(ctx); err != nil { - c.writableError = err - } - } - c.writableEnsured = true - } - - if c.writableError != nil { - return nil, c.writableError - } - - buf := c.convBuffer(ctx, layer) - cur := buf.Rows(ctx, c.slotsTensor()) - return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil -} - -// UpdateConvState writes a new conv state for current batch sequences. -// newState must have shape [dConv, hiddenSize, nSeqs]. -func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) { - buf := c.convBuffer(ctx, layer) - src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs()) - // SetRows requires F32 source - srcF32 := src.Cast(ctx, ml.DTypeF32) - ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor())) -} - -// IsSupportedForBatch returns true if the current batch layout supports shortconv. -func (c *HybridCache) IsSupportedForBatch() bool { - return c.curSeqTokens > 0 && len(c.curSeqs) > 0 -} - -// Seqs returns the ordered unique sequences for the current forward pass. -func (c *HybridCache) Seqs() []int { - return slices.Clone(c.curSeqs) + return c.NumSeqs() } diff --git a/model/models/lfm2/cache_test.go b/model/models/lfm2/cache_test.go index f4c493c20..d9f860843 100644 --- a/model/models/lfm2/cache_test.go +++ b/model/models/lfm2/cache_test.go @@ -4,441 +4,39 @@ import ( "testing" "github.com/ollama/ollama/kvcache" - "github.com/ollama/ollama/ml" ) -// TestHybridCache tests verify the slot management logic of HybridCache. -// These tests focus on the recurrent state slot allocation, reference counting, -// and copy-on-write semantics without requiring a full ML backend. +func TestHybridCache_New(t *testing.T) { + cache := NewHybridCache(nil, 512, 2) + if cache == nil { + t.Fatal("expected cache to be created") + } -// createSlotOnlyCache creates a HybridCache with only the slot management -// fields initialized. Used to test slot logic in isolation. 
-func createSlotOnlyCache(maxSequences int) *HybridCache { - return &HybridCache{ - hiddenSize: 256, - dConv: 3, - maxSequences: maxSequences, - refCount: make([]int, maxSequences), - freeSlots: initFreeSlots(maxSequences), - slotForSeq: make(map[int]int), - convCtxs: make(map[int]ml.Context), - convStates: make(map[int]ml.Tensor), + if cache.Recurrent == nil { + t.Fatal("expected embedded recurrent cache to be created") } } -func initFreeSlots(n int) []int { - slots := make([]int, 0, n) - for i := n - 1; i >= 0; i-- { - slots = append(slots, i) - } - return slots -} +func TestHybridCache_ImplementsCheckpointCache(t *testing.T) { + cache := NewHybridCache(nil, 512, 2) -func TestHybridCache_SlotAllocation(t *testing.T) { - cache := createSlotOnlyCache(4) - - // Verify initial state - if len(cache.freeSlots) != 4 { - t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots)) - } - - // Allocate all slots - for range 4 { - slot, err := cache.allocSlot() - if err != nil { - t.Fatalf("allocSlot failed: %v", err) - } - cache.refCount[slot] = 1 - } - - // Should be full now - if len(cache.freeSlots) != 0 { - t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots)) - } - - // Trying to allocate another should fail - _, err := cache.allocSlot() - if err != kvcache.ErrKvCacheFull { - t.Errorf("expected ErrKvCacheFull, got %v", err) + if _, ok := any(cache).(kvcache.CheckpointCache); !ok { + t.Fatal("expected HybridCache to implement CheckpointCache") } } -func TestHybridCache_SlotReuse(t *testing.T) { - cache := createSlotOnlyCache(4) +func TestHybridCache_DefaultBatchState(t *testing.T) { + cache := NewHybridCache(nil, 512, 2) - // Allocate a slot - slot1, _ := cache.allocSlot() - cache.refCount[slot1] = 1 - - // Free it - cache.refCount[slot1] = 0 - cache.freeSlot(slot1) - - // Allocate again - should get the same slot back (LIFO) - slot2, _ := cache.allocSlot() - if slot2 != slot1 { - t.Errorf("expected slot %d to be reused, got %d", slot1, slot2) - } -} - -func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) { - cache := createSlotOnlyCache(4) - - // Allocate slot for seq 1 - slot1, _ := cache.allocSlot() - cache.slotForSeq[1] = slot1 - cache.refCount[slot1] = 1 - - // Simulate sharing slot with seq 2 (copy-on-write style) - cache.slotForSeq[2] = slot1 - cache.refCount[slot1]++ - - // Should share the same slot - if cache.slotForSeq[2] != slot1 { - t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2]) + if got := cache.numSeqs(); got != 0 { + t.Fatalf("expected 0 sequences before StartForward, got %d", got) } - // Ref count should be 2 - if cache.refCount[slot1] != 2 { - t.Errorf("expected refCount 2, got %d", cache.refCount[slot1]) - } -} - -func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) { - cache := createSlotOnlyCache(4) - - // Allocate slot for seq 1 - slot1, _ := cache.allocSlot() - cache.slotForSeq[1] = slot1 - cache.refCount[slot1] = 1 - - // Share with seq 2 - cache.slotForSeq[2] = slot1 - cache.refCount[slot1]++ - - // Unshare seq 2 - cache.refCount[slot1]-- - delete(cache.slotForSeq, 2) - - // Ref count should be back to 1 - if cache.refCount[slot1] != 1 { - t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1]) + if got := cache.seqTokens(); got != 0 { + t.Fatalf("expected 0 sequence tokens before StartForward, got %d", got) } - // Seq 2 should no longer have a slot - if _, ok := cache.slotForSeq[2]; ok { - t.Error("seq 2 should not have a slot after unshare") - } -} - -func 
TestHybridCache_SlotFreeWhenUnused(t *testing.T) { - cache := createSlotOnlyCache(4) - - initialFreeSlots := len(cache.freeSlots) - - // Allocate slot for seq 1 - slot1, _ := cache.allocSlot() - cache.slotForSeq[1] = slot1 - cache.refCount[slot1] = 1 - - // Free the slot when refCount drops to 0 - cache.refCount[slot1]-- - if cache.refCount[slot1] <= 0 { - cache.refCount[slot1] = 0 - cache.freeSlot(slot1) - } - delete(cache.slotForSeq, 1) - - // Slot should be freed - if len(cache.freeSlots) != initialFreeSlots { - t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots)) - } - - // Ref count should be 0 - if cache.refCount[slot1] != 0 { - t.Errorf("expected refCount 0, got %d", cache.refCount[slot1]) - } -} - -func TestHybridCache_SlotOverwrite(t *testing.T) { - cache := createSlotOnlyCache(4) - - // Allocate slots for seq 1 and seq 2 - slot1, _ := cache.allocSlot() - cache.slotForSeq[1] = slot1 - cache.refCount[slot1] = 1 - - slot2, _ := cache.allocSlot() - cache.slotForSeq[2] = slot2 - cache.refCount[slot2] = 1 - - initialFreeSlots := len(cache.freeSlots) - - // Simulate overwriting seq 2's slot with slot1 (sharing) - // First free the old slot - cache.refCount[slot2]-- - if cache.refCount[slot2] <= 0 { - cache.refCount[slot2] = 0 - cache.freeSlot(slot2) - } - // Then share slot1 - cache.slotForSeq[2] = slot1 - cache.refCount[slot1]++ - - // Seq 2 should now share slot1 - if cache.slotForSeq[2] != slot1 { - t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2]) - } - - // Old slot2 should be freed - if len(cache.freeSlots) != initialFreeSlots+1 { - t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots)) - } -} - -func TestHybridCache_BoundsChecking(t *testing.T) { - cache := createSlotOnlyCache(4) - - // Test freeing invalid slot (should not panic) - cache.freeSlot(-1) - cache.freeSlot(100) // out of bounds - - // freeSlot does bounds checking, so invalid slots should be ignored - if len(cache.freeSlots) != 4 { - t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots)) - } -} - -func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) { - cache := createSlotOnlyCache(8) - - // Allocate slot for seq 1 - slot1, _ := cache.allocSlot() - cache.slotForSeq[1] = slot1 - cache.refCount[slot1] = 1 - - // Fork to seq 2, 3, 4 (all share slot1) - for _, seq := range []int{2, 3, 4} { - cache.slotForSeq[seq] = slot1 - cache.refCount[slot1]++ - } - - // Ref count should be 4 - if cache.refCount[slot1] != 4 { - t.Errorf("expected refCount 4, got %d", cache.refCount[slot1]) - } - - // Remove seq 2, 3 - for _, seq := range []int{2, 3} { - delete(cache.slotForSeq, seq) - cache.refCount[slot1]-- - } - - if cache.refCount[slot1] != 2 { - t.Errorf("expected refCount 2, got %d", cache.refCount[slot1]) - } - - // Slot should still be allocated (not in free list) - found := false - for _, s := range cache.freeSlots { - if s == slot1 { - found = true - break - } - } - if found { - t.Error("slot1 should not be in free list yet") - } - - // Remove remaining sequences - for _, seq := range []int{1, 4} { - delete(cache.slotForSeq, seq) - cache.refCount[slot1]-- - } - - if cache.refCount[slot1] != 0 { - t.Errorf("expected refCount 0, got %d", cache.refCount[slot1]) - } -} - -func TestHybridCache_ChainedSharing(t *testing.T) { - cache := createSlotOnlyCache(8) - - // Create seq 1 - slot1, _ := cache.allocSlot() - cache.slotForSeq[1] = slot1 - cache.refCount[slot1] = 1 - - // Share 1 -> 2 - 
cache.slotForSeq[2] = slot1 - cache.refCount[slot1]++ - - // Share 2 -> 3 (should still share slot1) - cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1 - cache.refCount[slot1]++ - - // All should share slot1 - if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 { - t.Error("all sequences should share slot1") - } - - if cache.refCount[slot1] != 3 { - t.Errorf("expected refCount 3, got %d", cache.refCount[slot1]) - } -} - -func TestHybridCache_CacheParameters(t *testing.T) { - cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5 - - if cache.hiddenSize != 512 { - t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize) - } - if cache.dConv != 5 { - t.Errorf("expected dConv 5, got %d", cache.dConv) - } -} - -func TestHybridCache_NumSeqs(t *testing.T) { - cache := createSlotOnlyCache(4) - - // Initially no sequences - if cache.numSeqs() != 0 { - t.Errorf("expected 0 seqs, got %d", cache.numSeqs()) - } - - // Manually set up current batch state - cache.curSeqs = []int{1, 2, 3} - - if cache.numSeqs() != 3 { - t.Errorf("expected 3 seqs, got %d", cache.numSeqs()) - } -} - -func TestHybridCache_SeqTokens(t *testing.T) { - cache := createSlotOnlyCache(4) - - // Initially 0 - if cache.seqTokens() != 0 { - t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens()) - } - - // Manually set up current batch state - cache.curSeqTokens = 16 - - if cache.seqTokens() != 16 { - t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens()) - } -} - -// Test that Seqs returns a clone of curSeqs -func TestHybridCache_Seqs_ReturnsClone(t *testing.T) { - cache := createSlotOnlyCache(4) - - cache.curSeqs = []int{1, 2, 3} - - seqs := cache.Seqs() - - // Modify returned slice - seqs[0] = 999 - - // Original should be unchanged - if cache.curSeqs[0] != 1 { - t.Error("Seqs should return a clone, not the original slice") - } -} - -func TestHybridCache_IsSupportedForBatch(t *testing.T) { - cache := createSlotOnlyCache(4) - - // Initially not supported (no batch set up) if cache.IsSupportedForBatch() { - t.Error("expected IsSupportedForBatch to be false initially") - } - - // Set up a valid batch - cache.curSeqTokens = 1 - cache.curSeqs = []int{1} - - if !cache.IsSupportedForBatch() { - t.Error("expected IsSupportedForBatch to be true with valid batch") - } -} - -func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) { - cache := createSlotOnlyCache(4) - - // zeroConvSlots should handle empty slots without panicking - cache.zeroConvSlots(nil, nil) - cache.zeroConvSlots(nil, []int{}) - - // zeroConvSlots should handle empty convStates without panicking - cache.zeroConvSlots(nil, []int{0, 1, 2}) -} - -func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) { - cache := createSlotOnlyCache(4) - - // Allocate slot for seq 1 - slot1, _ := cache.allocSlot() - cache.slotForSeq[1] = slot1 - cache.refCount[slot1] = 1 - - // Free the slot (simulating sequence removal) - cache.refCount[slot1]-- - cache.freeSlot(slot1) - delete(cache.slotForSeq, 1) - - // Verify slot is in free list - if len(cache.freeSlots) != 4 { - t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots)) - } - - // Allocate for new seq 2 - should get recycled slot - slot2, _ := cache.allocSlot() - if slot2 != slot1 { - t.Errorf("expected recycled slot %d, got %d", slot1, slot2) - } - - // This recycled slot would need zeroing in the real implementation - // The actual zeroing is tested via integration tests since it requires ML context -} - -func 
TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) { - cache := createSlotOnlyCache(4) - - // Simulate the slot allocation flow from StartForward - // When a sequence doesn't have a slot, it gets allocated and tracked as "new" - - newSlots := []int{} - - // Seq 1 doesn't have a slot - allocate and track - seq := 1 - if _, ok := cache.slotForSeq[seq]; !ok { - slot, err := cache.allocSlot() - if err != nil { - t.Fatalf("allocSlot failed: %v", err) - } - cache.slotForSeq[seq] = slot - cache.refCount[slot] = 1 - newSlots = append(newSlots, slot) - } - - // Verify newSlots contains the allocated slot - if len(newSlots) != 1 { - t.Errorf("expected 1 new slot, got %d", len(newSlots)) - } - - // Seq 1 already has a slot - should NOT be tracked as new - newSlots2 := []int{} - if _, ok := cache.slotForSeq[seq]; !ok { - slot, _ := cache.allocSlot() - cache.slotForSeq[seq] = slot - cache.refCount[slot] = 1 - newSlots2 = append(newSlots2, slot) - } - - // Verify no new slots for existing sequence - if len(newSlots2) != 0 { - t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2)) + t.Fatal("expected unsupported batch layout before StartForward") } } diff --git a/model/models/lfm2/model.go b/model/models/lfm2/model.go index 51e40d3c3..fc33b1e36 100644 --- a/model/models/lfm2/model.go +++ b/model/models/lfm2/model.go @@ -1,7 +1,11 @@ package lfm2 import ( + "bytes" "cmp" + "errors" + "fmt" + "image" "math" "github.com/ollama/ollama/fs" @@ -25,8 +29,20 @@ type Options struct { // per-layer head counts (LFM2 alternates attention and recurrent layers) numHeadsByLayer []int numKVHeadsByLayer []int + + // MoE config + numExperts int + numExpertsUsed int + normTopKProb bool + expertWeightsScale float32 + expertGatingFunc uint32 } +const ( + expertGatingFuncSoftmax = uint32(0) + expertGatingFuncSigmoid = uint32(2) +) + func (o Options) headDimValue() int { // Head dim is shared across layers; fall back to first attention layer head count. 
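+	// A zero entry marks a shortconv layer (per the layer_types convention),
+	// so the fallback scans for the first attention layer's head count.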
for _, h := range o.numHeadsByLayer { @@ -67,18 +83,138 @@ type Model struct { OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"` Output *nn.Linear `gguf:"output,alt:token_embd"` + VisionModel *VisionModel `gguf:"v"` + VisionProjector *VisionProjector `gguf:"mm"` + ImageProcessor ImageProcessor + imageTokenID int32 + imageStartToken int32 + imageEndToken int32 + imageThumbnailID int32 + imageRowColIDs map[imageGridPos]int32 + useSpecialTokens bool + projectorOptions VisionProjectorOptions + Options } -func New(c fs.Config) (model.Model, error) { - if c.Uint("expert_count") > 0 { - return nil, model.ErrUnsupportedModel +var _ model.MultimodalProcessor = (*Model)(nil) + +type imageGridPos struct { + row int + col int +} + +type visionEmbeddingLayout struct { + rows int + cols int + hasThumbnail bool +} + +type visionChunkData struct { + tokens int + row int + col int + thumbnail bool + layout *visionEmbeddingLayout +} + +func (m *Model) Validate() error { + if m.TokenEmbedding == nil { + return errors.New("lfm2: missing token_embd tensor") + } + if m.OutputNorm == nil { + return errors.New("lfm2: missing output_norm tensor") + } + if m.Output == nil { + return errors.New("lfm2: missing output tensor") } + for i, layer := range m.Layers { + if layer.AttentionNorm == nil { + return fmt.Errorf("lfm2: missing blk.%d.attn_norm tensor", i) + } + if layer.MLPNorm == nil { + return fmt.Errorf("lfm2: missing blk.%d.ffn_norm tensor", i) + } + switch ff := layer.MLP.(type) { + case nil: + return fmt.Errorf("lfm2: missing blk.%d feed-forward tensors", i) + case *denseMLP: + if ff.Up == nil || ff.Down == nil || ff.Gate == nil { + return fmt.Errorf("lfm2: missing blk.%d dense feed-forward tensors", i) + } + case *sparseMLP: + if ff.Router == nil || ff.Gate == nil || ff.Up == nil || ff.Down == nil { + return fmt.Errorf("lfm2: missing blk.%d sparse feed-forward tensors", i) + } + default: + return fmt.Errorf("lfm2: unsupported feed-forward type at blk.%d", i) + } + + switch op := layer.Operator.(type) { + case *Attention: + if op == nil || op.Query == nil || op.Key == nil || op.Value == nil || op.Output == nil || op.QueryNorm == nil || op.KeyNorm == nil { + return fmt.Errorf("lfm2: missing blk.%d attention tensors", i) + } + case *ShortConv: + if op == nil || op.Conv == nil || op.Conv.Weight == nil || op.InProj == nil || op.OutProj == nil { + return fmt.Errorf("lfm2: missing blk.%d shortconv tensors", i) + } + default: + return fmt.Errorf("lfm2: unsupported operator at blk.%d", i) + } + } + + if m.VisionModel != nil { + if m.VisionModel.PatchEmbedding == nil { + return errors.New("lfm2: missing vision patch embedding tensors") + } + if m.VisionModel.PositionEmbedding == nil { + return errors.New("lfm2: missing vision position embedding tensors") + } + if m.VisionModel.PostLayerNorm == nil { + return errors.New("lfm2: missing vision post layer norm tensors") + } + if len(m.VisionModel.Layers) == 0 { + return errors.New("lfm2: missing vision encoder layers") + } + for i, layer := range m.VisionModel.Layers { + if layer.LayerNorm1 == nil || layer.LayerNorm2 == nil || layer.SelfAttention == nil || layer.MLP == nil { + return fmt.Errorf("lfm2: missing vision layer tensors at v.blk.%d", i) + } + if layer.SelfAttention.Query == nil || layer.SelfAttention.Key == nil || layer.SelfAttention.Value == nil || layer.SelfAttention.Output == nil { + return fmt.Errorf("lfm2: missing vision attention tensors at v.blk.%d", i) + } + if layer.MLP.Up == nil || layer.MLP.Down == nil { + return fmt.Errorf("lfm2: 
missing vision feed-forward tensors at v.blk.%d", i) + } + } + + if m.VisionProjector == nil || m.VisionProjector.Linear1 == nil || m.VisionProjector.Linear2 == nil { + return errors.New("lfm2: missing multimodal projector tensors") + } + } + + return nil +} + +func New(c fs.Config) (model.Model, error) { if c.String("tokenizer.ggml.model") != "gpt2" { return nil, model.ErrUnsupportedTokenizer } + numExperts := int(c.Uint("expert_count")) + isMoE := numExperts > 0 + numExpertsUsed := int(c.Uint("expert_used_count")) + if isMoE { + if numExperts <= 0 { + return nil, fmt.Errorf("lfm2: invalid expert_count=%d", numExperts) + } + if numExpertsUsed <= 0 || numExpertsUsed > numExperts { + return nil, fmt.Errorf("lfm2: invalid expert_used_count=%d for expert_count=%d", numExpertsUsed, numExperts) + } + } + vocabulary := tokenizer.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), @@ -105,8 +241,16 @@ func New(c fs.Config) (model.Model, error) { } m := Model{ - Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...), - Layers: make([]Layer, c.Uint("block_count")), + Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...), + Layers: make([]Layer, c.Uint("block_count")), + ImageProcessor: newImageProcessor(c), + VisionModel: newVisionModel(c), + VisionProjector: &VisionProjector{}, + imageRowColIDs: make(map[imageGridPos]int32), + projectorOptions: VisionProjectorOptions{ + scaleFactor: int(c.Uint("vision.projector.scale_factor", 2)), + useLayerNorm: c.Bool("vision.projector.use_layernorm", false), + }, Options: Options{ hiddenSize: int(c.Uint("embedding_length")), headDim: int(c.Uint("attention.key_length")), @@ -116,9 +260,66 @@ func New(c fs.Config) (model.Model, error) { ropeBase: c.Float("rope.freq_base"), ropeScale: c.Float("rope.scaling.factor", 1), originalContextLength: int(c.Uint("rope.scaling.original_context_length")), + numExperts: numExperts, + numExpertsUsed: numExpertsUsed, + normTopKProb: c.Bool("norm_top_k_prob", true), + expertWeightsScale: c.Float("expert_weights_scale", 1.0), + expertGatingFunc: c.Uint("expert_gating_func", expertGatingFuncSoftmax), }, } + lookupTokenID := func(token string) int32 { + for i, t := range vocabulary.Values { + if t == token { + return int32(i) + } + } + return 0 + } + + resolveTokenID := func(explicitKey, token string, fallback uint32) int32 { + if explicitKey != "" { + if id := c.Uint(explicitKey); id != 0 { + return int32(id) + } + } + if tokenID := lookupTokenID(token); tokenID != 0 { + return tokenID + } + return int32(fallback) + } + + m.imageTokenID = resolveTokenID("vision.image_token_id", "", 396) + m.imageStartToken = resolveTokenID("vision.image_start_token_id", "<|image_start|>", 0) + m.imageEndToken = resolveTokenID("vision.image_end_token_id", "<|image_end|>", 0) + m.imageThumbnailID = resolveTokenID("vision.image_thumbnail_token_id", "<|img_thumbnail|>", 0) + m.useSpecialTokens = c.Bool("vision.use_image_special_tokens", true) + + maxGridTokens := int(c.Uint("vision.max_tiles", 10)) + if maxGridTokens <= 0 { + maxGridTokens = 10 + } + for row := 1; row <= maxGridTokens; row++ { + for col := 1; col <= maxGridTokens; col++ { + token := fmt.Sprintf("<|img_row_%d_col_%d|>", row, col) + if tokenID := lookupTokenID(token); tokenID > 0 { + m.imageRowColIDs[imageGridPos{row: row, col: col}] = tokenID + } + } + } + + if !m.useSpecialTokens { + m.imageStartToken = 0 + m.imageEndToken = 0 + m.imageThumbnailID = 0 + m.imageRowColIDs = map[imageGridPos]int32{} + } + + 
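+	// Text-only checkpoints carry no vision tower; drop the vision modules
+	// so EncodeMultimodal reports ErrNoVisionModel instead of loading
+	// empty tensors.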
if c.Uint("vision.block_count") == 0 { + m.VisionModel = nil + m.VisionProjector = nil + } + type headCounts interface { HeadCount() []uint64 HeadCountKV() []uint64 @@ -133,6 +334,14 @@ func New(c fs.Config) (model.Model, error) { m.numHeadsByLayer = make([]int, len(m.Layers)) m.numKVHeadsByLayer = make([]int, len(m.Layers)) + leadingDenseBlockCount := int(c.Uint("leading_dense_block_count")) + if leadingDenseBlockCount < 0 { + leadingDenseBlockCount = 0 + } + if leadingDenseBlockCount > len(m.Layers) { + leadingDenseBlockCount = len(m.Layers) + } + for i := range m.Layers { m.numHeadsByLayer[i] = int(headCount[i]) m.numKVHeadsByLayer[i] = int(headCountKV[i]) @@ -142,6 +351,12 @@ func New(c fs.Config) (model.Model, error) { } else { m.Layers[i].Operator = &Attention{} } + + if isMoE && i >= leadingDenseBlockCount { + m.Layers[i].MLP = &sparseMLP{} + } else { + m.Layers[i].MLP = &denseMLP{} + } } lCache := int(c.Uint("shortconv.l_cache")) @@ -188,22 +403,77 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, return sa.Output.Forward(ctx, attention) } -type MLP struct { +type FeedForward interface { + Forward(ml.Context, ml.Tensor, *Options) ml.Tensor +} + +type denseMLP struct { Up *nn.Linear `gguf:"ffn_up"` Down *nn.Linear `gguf:"ffn_down"` Gate *nn.Linear `gguf:"ffn_gate"` } -func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { +func (mlp *denseMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } +type sparseMLP struct { + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` + Bias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"` +} + +func (mlp *sparseMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { + // hiddenState: [hidden, tokens] + routerLogits := mlp.Router.Forward(ctx, hiddenState) + + probs := routerLogits.Softmax(ctx) + if opts.expertGatingFunc == expertGatingFuncSigmoid { + probs = routerLogits.Sigmoid(ctx) + } + + selectionProbs := probs + if mlp.Bias != nil { + selectionProbs = selectionProbs.Add(ctx, mlp.Bias) + } + + selectedExperts := selectionProbs.TopK(ctx, opts.numExpertsUsed) + routingWeights := probs.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts) + if opts.normTopKProb { + routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1)) + weightsSum := routingWeights.SumRows(ctx) + weightsSum = weightsSum.Clamp(ctx, 1e-6, float32(math.Inf(1))) + routingWeights = routingWeights.Div(ctx, weightsSum) + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1)) + } + if opts.expertWeightsScale != 1 { + routingWeights = routingWeights.Scale(ctx, float64(opts.expertWeightsScale)) + } + + // Build routing-weights branch early to enable topk-MoE fusion. 
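+	// Evaluating the routing weights ahead of the expert matmuls lets the
+	// backend fuse the softmax/top-k/normalize chain where supported.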
+ ctx.Forward(routingWeights) + + hiddenState3D := hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1)) + experts := mlp.Gate.Forward(ctx, hiddenState3D, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenState3D, selectedExperts)) + experts = mlp.Down.Forward(ctx, experts, selectedExperts) + experts = experts.Mul(ctx, routingWeights) + + nextState := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + nextState = nextState.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))) + } + + return nextState +} + type Layer struct { AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` Operator Operator MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` - MLP *MLP + MLP FeedForward } func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor { @@ -229,10 +499,233 @@ func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tenso return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } +func multimodalTokenCount(mm input.Multimodal) int { + if mm.Tensor != nil { + return mm.Tensor.Dim(1) + } + + switch data := mm.Data.(type) { + case int: + return data + case int32: + return int(data) + case visionChunkData: + return data.tokens + case *visionChunkData: + if data != nil { + return data.tokens + } + } + + return 0 +} + +func multimodalChunkInfo(mm input.Multimodal) visionChunkData { + switch data := mm.Data.(type) { + case visionChunkData: + return data + case *visionChunkData: + if data != nil { + return *data + } + } + + return visionChunkData{ + tokens: multimodalTokenCount(mm), + } +} + +func multimodalLayout(mm []input.Multimodal) visionEmbeddingLayout { + layout := visionEmbeddingLayout{rows: 1, cols: 1} + if len(mm) == 0 { + return layout + } + + first := multimodalChunkInfo(mm[0]) + if first.layout != nil { + return *first.layout + } + + return layout +} + +func (m *Model) imageRowColToken(row, col int) int32 { + if row <= 0 || col <= 0 { + return 0 + } + return m.imageRowColIDs[imageGridPos{row: row, col: col}] +} + +func (m *Model) appendImageChunk(result []*input.Input, chunk input.Multimodal, imageToken int32, hash uint64) ([]*input.Input, error) { + tokenCount := multimodalTokenCount(chunk) + if tokenCount <= 0 { + return nil, errors.New("lfm2: multimodal input has no tokens") + } + + result = append(result, &input.Input{ + Token: imageToken, + Multimodal: []input.Multimodal{chunk}, + MultimodalHash: hash, + SameBatch: tokenCount - 1, + }) + + for range tokenCount - 1 { + result = append(result, &input.Input{Token: imageToken}) + } + + return result, nil +} + +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) { + if m.VisionModel == nil || m.VisionProjector == nil || len(m.VisionModel.Layers) == 0 { + return nil, model.ErrNoVisionModel + } + + img, _, err := image.Decode(bytes.NewReader(multimodalData)) + if err != nil { + return nil, err + } + + processedImages, layout, err := m.ImageProcessor.ProcessImage(img) + if err != nil { + return nil, err + } + + if m.ImageProcessor.patchSize <= 0 { + return nil, errors.New("lfm2: invalid vision patch size") + } + + layoutInfo := &visionEmbeddingLayout{ + rows: layout.rows, + cols: layout.cols, + hasThumbnail: layout.hasThumbnail, + } + + mm := make([]input.Multimodal, 0, len(processedImages)) + for i, processed := range processedImages { + patches := visionPatchGrid{ + Width: 
processed.size.X / m.ImageProcessor.patchSize, + Height: processed.size.Y / m.ImageProcessor.patchSize, + } + if patches.Width == 0 || patches.Height == 0 { + return nil, errors.New("lfm2: invalid resized image dimensions") + } + + pixelValues := ctx.Input().FromFloats(processed.data, processed.size.X, processed.size.Y, m.ImageProcessor.numChannels) + visionOutputs := m.VisionModel.Forward(ctx, pixelValues, patches) + projected := m.VisionProjector.Forward(ctx, visionOutputs, patches, m.projectorOptions) + + chunk := visionChunkData{ + tokens: projected.Dim(1), + row: processed.row, + col: processed.col, + thumbnail: processed.thumbnail, + } + if i == 0 { + chunk.layout = layoutInfo + } + + mm = append(mm, input.Multimodal{ + Tensor: projected, + Data: chunk, + }) + } + + return mm, nil +} + +func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { + var result []*input.Input + + imageToken := m.imageTokenID + if imageToken == 0 { + imageToken = 396 + } + useSpecialTokens := m.useSpecialTokens || m.imageStartToken > 0 || m.imageEndToken > 0 || m.imageThumbnailID > 0 || len(m.imageRowColIDs) > 0 + + for _, inp := range inputs { + if len(inp.Multimodal) == 0 { + result = append(result, inp) + continue + } + + layout := multimodalLayout(inp.Multimodal) + if layout.rows <= 0 { + layout.rows = 1 + } + if layout.cols <= 0 { + layout.cols = 1 + } + tiles := layout.rows * layout.cols + multitile := tiles > 1 + + if useSpecialTokens && m.imageStartToken > 0 { + result = append(result, &input.Input{Token: m.imageStartToken}) + } + + for i, mm := range inp.Multimodal { + chunk := multimodalChunkInfo(mm) + if chunk.tokens <= 0 { + chunk.tokens = multimodalTokenCount(mm) + } + + if multitile && !chunk.thumbnail && chunk.row == 0 && chunk.col == 0 && i < tiles { + chunk.row = i/layout.cols + 1 + chunk.col = i%layout.cols + 1 + } + if multitile && layout.hasThumbnail && i == tiles { + chunk.thumbnail = true + } + + if useSpecialTokens && multitile { + if chunk.thumbnail { + if m.imageThumbnailID > 0 { + result = append(result, &input.Input{Token: m.imageThumbnailID}) + } + } else if marker := m.imageRowColToken(chunk.row, chunk.col); marker > 0 { + result = append(result, &input.Input{Token: marker}) + } + } + + var err error + result, err = m.appendImageChunk(result, input.Multimodal{ + Tensor: mm.Tensor, + Data: chunk, + }, imageToken, inp.MultimodalHash) + if err != nil { + return nil, err + } + } + + if useSpecialTokens && m.imageEndToken > 0 { + result = append(result, &input.Input{Token: m.imageEndToken}) + } + } + + return result, nil +} + func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) + if len(batch.Multimodal) > 0 { + // We splice vision embeddings into token embeddings in-place; duplicate to + // avoid aliasing the raw embedding output graph. 
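+		// Duplicate materializes a private copy, so the Copy calls below
+		// do not mutate the embedding lookup's own output tensor.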
+ hiddenState = hiddenState.Duplicate(ctx) + } + for _, mm := range batch.Multimodal { + offset := mm.Index + for _, multimodal := range mm.Multimodal { + if multimodal.Tensor == nil { + continue + } + + visionOutputs := multimodal.Tensor + ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, offset*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) + offset += visionOutputs.Dim(1) + } + } for i, layer := range m.Layers { m.Cache.SetLayer(i) @@ -251,4 +744,5 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func init() { model.Register("lfm2", New) + model.Register("lfm2moe", New) } diff --git a/model/models/lfm2/model_multimodal_test.go b/model/models/lfm2/model_multimodal_test.go new file mode 100644 index 000000000..a289e1999 --- /dev/null +++ b/model/models/lfm2/model_multimodal_test.go @@ -0,0 +1,160 @@ +package lfm2 + +import ( + "testing" + + "github.com/ollama/ollama/model/input" +) + +func TestPostTokenizeWithSpecialImageTokens(t *testing.T) { + m := &Model{ + imageTokenID: 396, + imageStartToken: 2, + imageEndToken: 3, + useSpecialTokens: true, + } + + in := []*input.Input{ + {Token: 11}, + {Multimodal: []input.Multimodal{{Data: 64}}, MultimodalHash: 123}, + {Token: 12}, + } + + out, err := m.PostTokenize(in) + if err != nil { + t.Fatalf("PostTokenize returned error: %v", err) + } + + if len(out) != 68 { + t.Fatalf("expected 68 tokens, got %d", len(out)) + } + + if out[0].Token != 11 { + t.Fatalf("out[0].Token = %d, want 11", out[0].Token) + } + if out[1].Token != 2 { + t.Fatalf("out[1].Token = %d, want 2", out[1].Token) + } + + firstImage := out[2] + if firstImage.Token != 396 { + t.Fatalf("out[2].Token = %d, want 396", firstImage.Token) + } + if len(firstImage.Multimodal) != 1 { + t.Fatalf("expected multimodal payload on first image token") + } + if firstImage.MultimodalHash != 123 { + t.Fatalf("out[2].MultimodalHash = %d, want 123", firstImage.MultimodalHash) + } + if firstImage.SameBatch != 63 { + t.Fatalf("out[2].SameBatch = %d, want 63", firstImage.SameBatch) + } + + for i := 3; i < 66; i++ { + if out[i].Token != 396 { + t.Fatalf("out[%d].Token = %d, want 396", i, out[i].Token) + } + if len(out[i].Multimodal) != 0 { + t.Fatalf("out[%d] should not carry multimodal payload", i) + } + } + + if out[66].Token != 3 { + t.Fatalf("out[66].Token = %d, want 3", out[66].Token) + } + if out[67].Token != 12 { + t.Fatalf("out[67].Token = %d, want 12", out[67].Token) + } +} + +func TestPostTokenizeWithoutSpecialImageTokens(t *testing.T) { + m := &Model{ + imageTokenID: 777, + useSpecialTokens: false, + } + + in := []*input.Input{ + {Multimodal: []input.Multimodal{{Data: 5}}, MultimodalHash: 9}, + } + + out, err := m.PostTokenize(in) + if err != nil { + t.Fatalf("PostTokenize returned error: %v", err) + } + + if len(out) != 5 { + t.Fatalf("expected 5 tokens, got %d", len(out)) + } + if out[0].Token != 777 || out[0].SameBatch != 4 || len(out[0].Multimodal) != 1 { + t.Fatalf("unexpected first token: %+v", *out[0]) + } + for i := 1; i < 5; i++ { + if out[i].Token != 777 { + t.Fatalf("out[%d].Token = %d, want 777", i, out[i].Token) + } + if len(out[i].Multimodal) != 0 { + t.Fatalf("out[%d] should not carry multimodal payload", i) + } + } +} + +func TestPostTokenizeMultiTileLayoutTokens(t *testing.T) { + m := &Model{ + imageTokenID: 396, + imageStartToken: 498, + imageEndToken: 499, + imageThumbnailID: 497, + imageRowColIDs: map[imageGridPos]int32{ + {row: 1, col: 1}: 397, + {row: 1, col: 2}: 398, + }, + useSpecialTokens: true, + } + + 
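+	// 1x2 grid plus thumbnail: expect the start token, then a row/col marker
+	// and image tokens per tile, a thumbnail marker and its tokens, and the
+	// end token.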
layout := &visionEmbeddingLayout{rows: 1, cols: 2, hasThumbnail: true} + in := []*input.Input{{ + Multimodal: []input.Multimodal{ + {Data: visionChunkData{tokens: 3, row: 1, col: 1, layout: layout}}, + {Data: visionChunkData{tokens: 3, row: 1, col: 2}}, + {Data: visionChunkData{tokens: 2, thumbnail: true}}, + }, + MultimodalHash: 1, + }} + + out, err := m.PostTokenize(in) + if err != nil { + t.Fatalf("PostTokenize returned error: %v", err) + } + + got := make([]int32, len(out)) + for i := range out { + got[i] = out[i].Token + } + + want := []int32{ + 498, // <|image_start|> + 397, // <|img_row_1_col_1|> + 396, 396, 396, + 398, // <|img_row_1_col_2|> + 396, 396, 396, + 497, // <|img_thumbnail|> + 396, 396, + 499, // <|image_end|> + } + + if len(got) != len(want) { + t.Fatalf("len(out) = %d, want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("out[%d].Token = %d, want %d", i, got[i], want[i]) + } + } + + if len(out[2].Multimodal) != 1 || len(out[6].Multimodal) != 1 || len(out[10].Multimodal) != 1 { + t.Fatalf("expected multimodal payload on first token of each chunk") + } + if out[2].SameBatch != 2 || out[6].SameBatch != 2 || out[10].SameBatch != 1 { + t.Fatalf("unexpected SameBatch values: [%d %d %d]", out[2].SameBatch, out[6].SameBatch, out[10].SameBatch) + } +} diff --git a/model/models/lfm2/model_vision.go b/model/models/lfm2/model_vision.go new file mode 100644 index 000000000..233f6b0ee --- /dev/null +++ b/model/models/lfm2/model_vision.go @@ -0,0 +1,184 @@ +package lfm2 + +import ( + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +const lfm2VisionBatchSize = 1 + +type visionPatchGrid struct { + Width int + Height int +} + +type VisionSelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output,alt:attn_out"` +} + +func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor { + headDim := opts.hiddenSize / opts.numHeads + + query := sa.Query.Forward(ctx, hiddenState) + key := sa.Key.Forward(ctx, hiddenState) + value := sa.Value.Forward(ctx, hiddenState) + + query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), lfm2VisionBatchSize) + key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), lfm2VisionBatchSize) + value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), lfm2VisionBatchSize) + + attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil) + attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), lfm2VisionBatchSize) + return sa.Output.Forward(ctx, attention) +} + +type VisionMLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor { + return mlp.Down.Forward(ctx, mlp.Up.Forward(ctx, hiddenState).GELU(ctx)) +} + +type VisionEncoderLayer struct { + LayerNorm1 *nn.LayerNorm `gguf:"ln1"` + SelfAttention *VisionSelfAttention + + LayerNorm2 *nn.LayerNorm `gguf:"ln2"` + MLP *VisionMLP +} + +func (l *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor { + residual := hiddenState + + hiddenState = l.LayerNorm1.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.SelfAttention.Forward(ctx, hiddenState, opts) + hiddenState = hiddenState.Add(ctx, residual) + + residual = hiddenState + hiddenState = 
l.LayerNorm2.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.MLP.Forward(ctx, hiddenState) + return hiddenState.Add(ctx, residual) +} + +type VisionModelOptions struct { + hiddenSize, numHeads int + imageSize, patchSize int + eps float32 +} + +type VisionModel struct { + PatchEmbedding *nn.Conv2D `gguf:"patch_embd"` + PositionEmbedding *nn.Embedding `gguf:"position_embd"` + PostLayerNorm *nn.LayerNorm `gguf:"post_ln"` + + Layers []VisionEncoderLayer `gguf:"blk"` + + *VisionModelOptions +} + +func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, patches visionPatchGrid) ml.Tensor { + numPatches := patches.Width * patches.Height + + hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1) + hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize) + hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + + if m.PositionEmbedding != nil { + posTokens := m.PositionEmbedding.Weight.Dim(1) + source := int(math.Sqrt(float64(posTokens))) + + var positionEmbeddings ml.Tensor + if source > 0 && source*source == posTokens && (source != patches.Width || source != patches.Height) { + // SigLIP2 NAFlex-style position interpolation for variable image sizes. + positionIDs := ctx.Arange(0, float32(posTokens), 1, ml.DTypeI32) + positionEmbeddings = m.PositionEmbedding.Forward(ctx, positionIDs) + positionEmbeddings = positionEmbeddings.Reshape(ctx, -1, source, source) + positionEmbeddings = positionEmbeddings.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx) + positionEmbeddings = positionEmbeddings.Interpolate(ctx, [4]int{ + patches.Width, + patches.Height, + hiddenState.Dim(0), + 1, + }, ml.SamplingModeBilinear) + positionEmbeddings = positionEmbeddings.Permute(ctx, 1, 2, 0, 3) + positionEmbeddings = positionEmbeddings.Contiguous(ctx, -1, patches.Width*patches.Height) + } else { + positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeI32) + positionEmbeddings = m.PositionEmbedding.Forward(ctx, positionIDs) + } + + hiddenState = hiddenState.Add(ctx, positionEmbeddings) + } + + for _, layer := range m.Layers { + hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions) + } + + return m.PostLayerNorm.Forward(ctx, hiddenState, m.eps) +} + +func newVisionModel(c fs.Config) *VisionModel { + return &VisionModel{ + Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")), + VisionModelOptions: &VisionModelOptions{ + hiddenSize: int(c.Uint("vision.embedding_length", 1152)), + numHeads: int(c.Uint("vision.attention.head_count", 16)), + imageSize: int(c.Uint("vision.image_size", 256)), + patchSize: int(c.Uint("vision.patch_size", 16)), + eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6), + }, + } +} + +type VisionProjector struct { + LayerNorm *nn.LayerNorm `gguf:"layer_norm"` + Linear1 *nn.Linear `gguf:"1"` + Linear2 *nn.Linear `gguf:"2"` +} + +type VisionProjectorOptions struct { + scaleFactor int + useLayerNorm bool +} + +func (p *VisionProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, patches visionPatchGrid, opts VisionProjectorOptions) ml.Tensor { + hiddenSize := visionOutputs.Dim(0) + featureMap := visionOutputs + + merge := max(opts.scaleFactor, 1) + if merge > 1 { + width := patches.Width + height := patches.Height + + featureMap = featureMap.Reshape(ctx, hiddenSize, width, height) + + // Match llama.cpp patch merger: pad spatial dims to merge factor. 
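+		// e.g. merge=2 on a 13x9 patch grid pads to 14x10 so both
+		// spatial dims divide evenly by the merge factor.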
+ padWidth := (merge - width%merge) % merge + padHeight := (merge - height%merge) % merge + if padWidth != 0 || padHeight != 0 { + featureMap = featureMap.Pad(ctx, 0, padWidth, padHeight, 0) + width += padWidth + height += padHeight + } + + featureMap = featureMap.Reshape(ctx, hiddenSize*merge, width/merge, height) + featureMap = featureMap.Permute(ctx, 0, 2, 1).Contiguous(ctx, hiddenSize*merge*merge, height/merge, width/merge) + featureMap = featureMap.Permute(ctx, 0, 2, 1).Contiguous(ctx) + featureMap = featureMap.Contiguous(ctx, featureMap.Dim(0), featureMap.Dim(1)*featureMap.Dim(2)) + } + + if opts.useLayerNorm && p.LayerNorm != nil { + featureMap = p.LayerNorm.Forward(ctx, featureMap, 1e-5) + } + + featureMap = p.Linear1.Forward(ctx, featureMap).GELU(ctx) + return p.Linear2.Forward(ctx, featureMap) +} diff --git a/model/models/lfm2/process_image.go b/model/models/lfm2/process_image.go new file mode 100644 index 000000000..7ca65190a --- /dev/null +++ b/model/models/lfm2/process_image.go @@ -0,0 +1,260 @@ +package lfm2 + +import ( + "image" + stdimage "image/draw" + "math" + "slices" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/model/imageproc" +) + +type ImageProcessor struct { + imageSize, patchSize, numChannels int + downsampleFactor int + imageMean, imageStd [3]float32 + + doImageSplitting bool + minTiles int + maxTiles int + useThumbnail bool + tileSize int + + minImageTokens int + maxImageTokens int + maxPixelsTolerance float64 +} + +type processedVisionImage struct { + data []float32 + size image.Point + row int + col int + thumbnail bool +} + +type processedVisionLayout struct { + rows int + cols int + hasThumbnail bool +} + +func newImageProcessor(c fs.Config) ImageProcessor { + mean := c.Floats("vision.image_mean") + std := c.Floats("vision.image_std") + + processor := ImageProcessor{ + imageSize: int(c.Uint("vision.image_size", 256)), + patchSize: int(c.Uint("vision.patch_size", 16)), + numChannels: int(c.Uint("vision.num_channels", 3)), + downsampleFactor: int(c.Uint("vision.projector.scale_factor", 2)), + imageMean: [3]float32{0.5, 0.5, 0.5}, + imageStd: [3]float32{0.5, 0.5, 0.5}, + doImageSplitting: c.Bool("vision.do_image_splitting", true), + minTiles: int(c.Uint("vision.min_tiles", 2)), + maxTiles: int(c.Uint("vision.max_tiles", 10)), + useThumbnail: c.Bool("vision.use_thumbnail", true), + tileSize: int(c.Uint("vision.tile_size", 512)), + minImageTokens: int(c.Uint("vision.min_image_tokens", 64)), + maxImageTokens: int(c.Uint("vision.max_image_tokens", 256)), + maxPixelsTolerance: float64(c.Float("vision.max_pixels_tolerance", 2.0)), + } + + if len(mean) >= 3 { + processor.imageMean = [3]float32{mean[0], mean[1], mean[2]} + } + if len(std) >= 3 { + processor.imageStd = [3]float32{std[0], std[1], std[2]} + } + + // Keep defaults aligned with HF unless explicitly configured. 
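+	// Clamp zero or inconsistent metadata values back to usable defaults.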
+ if processor.downsampleFactor <= 0 { + processor.downsampleFactor = 2 + } + if processor.patchSize <= 0 { + processor.patchSize = 16 + } + if processor.tileSize <= 0 { + processor.tileSize = 512 + } + if processor.minTiles <= 0 { + processor.minTiles = 2 + } + if processor.maxTiles < processor.minTiles { + processor.maxTiles = processor.minTiles + } + if processor.minImageTokens <= 0 { + processor.minImageTokens = 64 + } + if processor.maxImageTokens < processor.minImageTokens { + processor.maxImageTokens = processor.minImageTokens + } + if processor.maxPixelsTolerance <= 0 { + processor.maxPixelsTolerance = 2.0 + } + + return processor +} + +func (p ImageProcessor) ProcessImage(img image.Image) ([]processedVisionImage, processedVisionLayout, error) { + img = imageproc.Composite(img) + + orig := img.Bounds().Size() + resizedWidth, resizedHeight := p.smartResize(orig.Y, orig.X) + + layout := processedVisionLayout{rows: 1, cols: 1} + if p.shouldSplit(orig.Y, orig.X) { + gridWidth, gridHeight, targetWidth, targetHeight := p.gridLayout(orig.Y, orig.X) + layout.rows = gridHeight + layout.cols = gridWidth + layout.hasThumbnail = p.useThumbnail && gridWidth*gridHeight != 1 + + resized := imageproc.Resize(img, image.Point{X: targetWidth, Y: targetHeight}, imageproc.ResizeBilinear) + images := make([]processedVisionImage, 0, gridWidth*gridHeight+1) + for row := range gridHeight { + for col := range gridWidth { + rect := image.Rect( + col*p.tileSize, + row*p.tileSize, + (col+1)*p.tileSize, + (row+1)*p.tileSize, + ) + tile := cropImage(resized, rect) + images = append(images, processedVisionImage{ + data: imageproc.Normalize(tile, p.imageMean, p.imageStd, true, true), + size: tile.Bounds().Size(), + row: row + 1, + col: col + 1, + }) + } + } + + if layout.hasThumbnail { + thumbnail := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear) + images = append(images, processedVisionImage{ + data: imageproc.Normalize(thumbnail, p.imageMean, p.imageStd, true, true), + size: thumbnail.Bounds().Size(), + thumbnail: true, + }) + } + + return images, layout, nil + } + + single := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear) + return []processedVisionImage{{ + data: imageproc.Normalize(single, p.imageMean, p.imageStd, true, true), + size: single.Bounds().Size(), + }}, layout, nil +} + +func (p ImageProcessor) shouldSplit(height, width int) bool { + if !p.doImageSplitting || p.minTiles == 1 && p.maxTiles == 1 { + return false + } + + totalFactor := p.patchSize * p.downsampleFactor + hBar := max(p.patchSize, roundByFactor(height, totalFactor)) + wBar := max(p.patchSize, roundByFactor(width, totalFactor)) + + limit := float64(p.maxImageTokens * p.patchSize * p.patchSize * p.downsampleFactor * p.downsampleFactor) + limit *= p.maxPixelsTolerance + + return float64(hBar*wBar) > limit +} + +func (p ImageProcessor) smartResize(height, width int) (int, int) { + totalFactor := p.patchSize * p.downsampleFactor + minPixels := p.minImageTokens * p.patchSize * p.patchSize * p.downsampleFactor * p.downsampleFactor + maxPixels := p.maxImageTokens * p.patchSize * p.patchSize * p.downsampleFactor * p.downsampleFactor + + hBar := max(totalFactor, roundByFactor(height, totalFactor)) + wBar := max(totalFactor, roundByFactor(width, totalFactor)) + + if hBar*wBar > maxPixels { + beta := math.Sqrt(float64(height*width) / float64(maxPixels)) + hBar = max(totalFactor, int(math.Floor(float64(height)/beta/float64(totalFactor)))*totalFactor) + 
wBar = max(totalFactor, int(math.Floor(float64(width)/beta/float64(totalFactor)))*totalFactor) + } else if hBar*wBar < minPixels { + beta := math.Sqrt(float64(minPixels) / float64(height*width)) + hBar = int(math.Ceil(float64(height)*beta/float64(totalFactor))) * totalFactor + wBar = int(math.Ceil(float64(width)*beta/float64(totalFactor))) * totalFactor + } + + return wBar, hBar +} + +func (p ImageProcessor) gridLayout(height, width int) (gridWidth, gridHeight, targetWidth, targetHeight int) { + aspectRatio := float64(width) / float64(height) + targetRatios := p.targetRatios() + bestRatio := clipImageSize{width: 1, height: 1} + bestRatioDiff := math.MaxFloat64 + area := float64(width * height) + + for _, ratio := range targetRatios { + targetAspect := float64(ratio.width) / float64(ratio.height) + ratioDiff := math.Abs(aspectRatio - targetAspect) + + if ratioDiff < bestRatioDiff { + bestRatioDiff = ratioDiff + bestRatio = ratio + continue + } + + if ratioDiff == bestRatioDiff { + targetArea := float64(p.tileSize * p.tileSize * ratio.width * ratio.height) + if area > 0.5*targetArea { + bestRatio = ratio + } + } + } + + return bestRatio.width, bestRatio.height, p.tileSize * bestRatio.width, p.tileSize * bestRatio.height +} + +type clipImageSize struct { + width int + height int +} + +func (p ImageProcessor) targetRatios() []clipImageSize { + targetRatios := make([]clipImageSize, 0, p.maxTiles*p.maxTiles) + for n := p.minTiles; n <= p.maxTiles; n++ { + for w := 1; w <= n; w++ { + for h := 1; h <= n; h++ { + if w*h < p.minTiles || w*h > p.maxTiles { + continue + } + targetRatios = append(targetRatios, clipImageSize{width: w, height: h}) + } + } + } + + unique := targetRatios[:0] + for _, ratio := range targetRatios { + if slices.Contains(unique, ratio) { + continue + } + unique = append(unique, ratio) + } + + slices.SortFunc(unique, func(a, b clipImageSize) int { + return a.width*a.height - b.width*b.height + }) + + return unique +} + +func roundByFactor(number, factor int) int { + if factor <= 0 { + return number + } + return int(math.RoundToEven(float64(number)/float64(factor))) * factor +} + +func cropImage(img image.Image, rect image.Rectangle) image.Image { + dst := image.NewRGBA(image.Rect(0, 0, rect.Dx(), rect.Dy())) + stdimage.Draw(dst, dst.Bounds(), img, rect.Min, stdimage.Src) + return dst +} diff --git a/model/models/lfm2/process_image_test.go b/model/models/lfm2/process_image_test.go new file mode 100644 index 000000000..51f65920f --- /dev/null +++ b/model/models/lfm2/process_image_test.go @@ -0,0 +1,105 @@ +package lfm2 + +import ( + "image" + "image/color" + "testing" +) + +func TestProcessImageSingleTile(t *testing.T) { + p := ImageProcessor{ + patchSize: 16, + downsampleFactor: 2, + numChannels: 3, + imageMean: [3]float32{0.5, 0.5, 0.5}, + imageStd: [3]float32{0.5, 0.5, 0.5}, + doImageSplitting: true, + minTiles: 2, + maxTiles: 10, + useThumbnail: true, + tileSize: 512, + minImageTokens: 64, + maxImageTokens: 256, + maxPixelsTolerance: 2.0, + } + + img := image.NewRGBA(image.Rect(0, 0, 320, 320)) + out, layout, err := p.ProcessImage(img) + if err != nil { + t.Fatalf("ProcessImage returned error: %v", err) + } + + if layout.rows != 1 || layout.cols != 1 || layout.hasThumbnail { + t.Fatalf("layout = %+v, want rows=1 cols=1 hasThumbnail=false", layout) + } + if len(out) != 1 { + t.Fatalf("len(out) = %d, want 1", len(out)) + } + if out[0].size != (image.Point{X: 320, Y: 320}) { + t.Fatalf("single image size = %+v, want 320x320", out[0].size) + } + if out[0].thumbnail { + 
t.Fatalf("single image should not be marked as thumbnail") + } +} + +func TestProcessImageDynamicTiling(t *testing.T) { + p := ImageProcessor{ + patchSize: 16, + downsampleFactor: 2, + numChannels: 3, + imageMean: [3]float32{0.5, 0.5, 0.5}, + imageStd: [3]float32{0.5, 0.5, 0.5}, + doImageSplitting: true, + minTiles: 2, + maxTiles: 10, + useThumbnail: true, + tileSize: 512, + minImageTokens: 64, + maxImageTokens: 256, + maxPixelsTolerance: 2.0, + } + + // Wide image that should trigger multi-tile splitting. + img := image.NewRGBA(image.Rect(0, 0, 3000, 1000)) + fill := color.RGBA{R: 120, G: 90, B: 60, A: 255} + for y := range 1000 { + for x := range 3000 { + img.Set(x, y, fill) + } + } + + out, layout, err := p.ProcessImage(img) + if err != nil { + t.Fatalf("ProcessImage returned error: %v", err) + } + + if layout.rows*layout.cols <= 1 { + t.Fatalf("expected multi-tile layout, got %+v", layout) + } + if !layout.hasThumbnail { + t.Fatalf("expected thumbnail for multi-tile layout") + } + + wantLen := layout.rows*layout.cols + 1 + if len(out) != wantLen { + t.Fatalf("len(out) = %d, want %d", len(out), wantLen) + } + + for i := range layout.rows * layout.cols { + if out[i].size != (image.Point{X: 512, Y: 512}) { + t.Fatalf("tile[%d] size = %+v, want 512x512", i, out[i].size) + } + if out[i].thumbnail { + t.Fatalf("tile[%d] should not be marked as thumbnail", i) + } + } + + thumb := out[len(out)-1] + if !thumb.thumbnail { + t.Fatalf("last chunk should be thumbnail") + } + if thumb.size.X%32 != 0 || thumb.size.Y%32 != 0 { + t.Fatalf("thumbnail size = %+v, want dimensions aligned to 32", thumb.size) + } +} diff --git a/model/parsers/lfm2.go b/model/parsers/lfm2.go index 4aade6926..43f926d8a 100644 --- a/model/parsers/lfm2.go +++ b/model/parsers/lfm2.go @@ -32,6 +32,8 @@ type LFM2Parser struct { hasThinkingSupport bool needsThinkingLeadingTrim bool // trim leading whitespace after tag needsContentLeadingTrim bool // trim leading whitespace after tag + toolNames map[string]struct{} + hasTools bool } func (p *LFM2Parser) HasToolSupport() bool { @@ -63,6 +65,13 @@ func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.T } func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { + p.toolNames = make(map[string]struct{}, len(tools)) + p.hasTools = len(tools) > 0 + for _, tool := range tools { + if tool.Function.Name != "" { + p.toolNames[tool.Function.Name] = struct{}{} + } + } p.setInitialState(lastMessage, thinkValue) return tools } @@ -105,9 +114,33 @@ func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string, } } + // Fallback for models that emit bare tool calls without <|tool_call_*|> wrappers. + if done && len(toolCalls) == 0 && p.hasTools { + candidate := strings.TrimSpace(contentSb.String()) + if fallbackCalls, parseErr := p.parseToolCallsContent(candidate); parseErr == nil && p.toolCallsAllowed(fallbackCalls) { + contentSb.Reset() + toolCalls = append(toolCalls, fallbackCalls...) 
+ } + } + return contentSb.String(), thinkingSb.String(), toolCalls, nil } +func (p *LFM2Parser) toolCallsAllowed(calls []api.ToolCall) bool { + if len(calls) == 0 { + return false + } + if len(p.toolNames) == 0 { + return true + } + for _, call := range calls { + if _, ok := p.toolNames[call.Function.Name]; !ok { + return false + } + } + return true +} + func (p *LFM2Parser) parseEvents() []lfm2Event { var all []lfm2Event @@ -269,36 +302,16 @@ func (p *LFM2Parser) eat() ([]lfm2Event, bool) { return events, false } -// parseToolCallsContent parses one or more tool calls from content -// Supports JSON format and Python-style format including multiple calls: [func1(...),func2(...)] +// parseToolCallsContent parses one or more Python-style tool calls. +// Example: [func1(arg='v'), func2(x=1)] func (p *LFM2Parser) parseToolCallsContent(content string) ([]api.ToolCall, error) { content = strings.TrimSpace(content) - // Try JSON format first: {"name": "func", "arguments": {...}} - var parsed struct { - Name string `json:"name"` - Arguments json.RawMessage `json:"arguments"` - } + // Be tolerant of malformed outputs that include wrapper tags without proper pairing. + content = strings.TrimSpace(strings.TrimPrefix(content, lfm2ToolCallStartTag)) + content = strings.TrimSpace(strings.TrimSuffix(content, lfm2ToolCallEndTag)) - if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" { - var args api.ToolCallFunctionArguments - if len(parsed.Arguments) > 0 { - if err := json.Unmarshal(parsed.Arguments, &args); err != nil { - return nil, err - } - } else { - args = api.NewToolCallFunctionArguments() - } - - return []api.ToolCall{{ - Function: api.ToolCallFunction{ - Name: parsed.Name, - Arguments: args, - }, - }}, nil - } - - // Try Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1') + // Parse Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1') return p.parsePythonStyleToolCalls(content) } @@ -417,21 +430,16 @@ func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error) // parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2" func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error { - // Simple state machine to parse key='value' pairs - // Handles: command='ls', flag="-la", count=42, enabled=true - var key string i := 0 - for i < len(argsStr) { - // Skip whitespace - for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') { + // Skip separators and whitespace. 
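+		// Consuming commas here also tolerates stray separators
+		// between key=value pairs.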
+ for i < len(argsStr) && (argsStr[i] == ',' || unicode.IsSpace(rune(argsStr[i]))) { i++ } if i >= len(argsStr) { break } - // Parse key keyStart := i for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' { i++ @@ -439,60 +447,238 @@ func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error if i >= len(argsStr) || argsStr[i] != '=' { return errors.New("invalid argument: expected '='") } - key = strings.TrimSpace(argsStr[keyStart:i]) + + key := strings.TrimSpace(argsStr[keyStart:i]) + if key == "" { + return errors.New("invalid argument: empty key") + } i++ // skip '=' - // Skip whitespace after = - for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') { + for i < len(argsStr) && unicode.IsSpace(rune(argsStr[i])) { i++ } - - // Parse value - var value string - if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') { - // Quoted string - quote := argsStr[i] - i++ - valueStart := i - for i < len(argsStr) && argsStr[i] != quote { - if argsStr[i] == '\\' && i+1 < len(argsStr) { - i += 2 // skip escaped char - } else { - i++ - } - } - value = argsStr[valueStart:i] - if i < len(argsStr) { - i++ // skip closing quote - } - args.Set(key, value) - } else { - // Unquoted value (number, bool, etc) - valueStart := i - for i < len(argsStr) && argsStr[i] != ',' { - i++ - } - value = strings.TrimSpace(argsStr[valueStart:i]) - - // Try to parse as number or bool - if v, err := strconv.ParseInt(value, 10, 64); err == nil { - args.Set(key, v) - } else if v, err := strconv.ParseFloat(value, 64); err == nil { - args.Set(key, v) - } else if value == "true" { - args.Set(key, true) - } else if value == "false" { - args.Set(key, false) - } else { - args.Set(key, value) - } + if i >= len(argsStr) { + return errors.New("invalid argument: missing value") } - // Skip comma and whitespace - for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') { + value, next, err := parsePythonArgValue(argsStr, i) + if err != nil { + return err + } + args.Set(key, value) + i = next + + // Optional trailing comma before next key/value. + if i < len(argsStr) && argsStr[i] == ',' { i++ } } return nil } + +func parsePythonArgValue(s string, i int) (any, int, error) { + if i >= len(s) { + return nil, i, errors.New("invalid argument: missing value") + } + + // Quoted string literal. + if s[i] == '\'' || s[i] == '"' { + quote := s[i] + i++ + start := i + for i < len(s) { + if s[i] == '\\' && i+1 < len(s) { + i += 2 + continue + } + if s[i] == quote { + value := s[start:i] + i++ + return value, i, nil + } + i++ + } + return nil, i, errors.New("invalid argument: unterminated string") + } + + // Unquoted literal. Consume until top-level comma. 
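+	// Depth counters keep commas inside nested lists/dicts, e.g.
+	// items=['a', 'b'], from terminating the value early.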
+ start := i + depthParen, depthSquare, depthCurly := 0, 0, 0 + inString := false + var quote byte + escaped := false + + for i < len(s) { + ch := s[i] + if inString { + if escaped { + escaped = false + } else if ch == '\\' { + escaped = true + } else if ch == quote { + inString = false + } + i++ + continue + } + + switch ch { + case '\'', '"': + inString = true + quote = ch + case '(': + depthParen++ + case ')': + if depthParen > 0 { + depthParen-- + } + case '[': + depthSquare++ + case ']': + if depthSquare > 0 { + depthSquare-- + } + case '{': + depthCurly++ + case '}': + if depthCurly > 0 { + depthCurly-- + } + case ',': + if depthParen == 0 && depthSquare == 0 && depthCurly == 0 { + token := strings.TrimSpace(s[start:i]) + value, err := parsePythonLiteral(token) + return value, i, err + } + } + i++ + } + + token := strings.TrimSpace(s[start:i]) + value, err := parsePythonLiteral(token) + return value, i, err +} + +func parsePythonLiteral(token string) (any, error) { + switch token { + case "": + return "", nil + case "true", "True": + return true, nil + case "false", "False": + return false, nil + case "null", "None": + return nil, nil + } + + if v, err := strconv.ParseInt(token, 10, 64); err == nil { + return v, nil + } + if v, err := strconv.ParseFloat(token, 64); err == nil { + return v, nil + } + + if strings.HasPrefix(token, "[") || strings.HasPrefix(token, "{") { + var parsed any + if err := json.Unmarshal([]byte(token), &parsed); err == nil { + return parsed, nil + } + + if converted, err := pythonLiteralToJSON(token); err == nil { + if err := json.Unmarshal([]byte(converted), &parsed); err == nil { + return parsed, nil + } + } + } + + return token, nil +} + +func pythonLiteralToJSON(s string) (string, error) { + var out strings.Builder + out.Grow(len(s) + len(s)/8) + + inString := false + var quote byte + escaped := false + + for i := 0; i < len(s); i++ { + ch := s[i] + + if inString { + if escaped { + out.WriteByte(ch) + escaped = false + continue + } + + if ch == '\\' { + out.WriteByte(ch) + escaped = true + continue + } + + if ch == quote { + out.WriteByte('"') + inString = false + continue + } + + if quote == '\'' && ch == '"' { + out.WriteString(`\"`) + continue + } + + out.WriteByte(ch) + continue + } + + if ch == '\'' || ch == '"' { + inString = true + quote = ch + escaped = false + out.WriteByte('"') + continue + } + + // Replace Python identifiers with JSON equivalents when outside strings. 
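+		// True/False/None become true/false/null; all other identifiers
+		// are copied through unchanged.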
+		if isIdentStart(ch) {
+			j := i + 1
+			for j < len(s) && isIdentPart(s[j]) {
+				j++
+			}
+
+			ident := s[i:j]
+			switch ident {
+			case "True":
+				out.WriteString("true")
+			case "False":
+				out.WriteString("false")
+			case "None":
+				out.WriteString("null")
+			default:
+				out.WriteString(ident)
+			}
+
+			i = j - 1
+			continue
+		}
+
+		out.WriteByte(ch)
+	}
+
+	if inString {
+		return "", errors.New("unterminated string")
+	}
+
+	return out.String(), nil
+}
+
+func isIdentStart(b byte) bool {
+	return (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || b == '_'
+}
+
+func isIdentPart(b byte) bool {
+	return isIdentStart(b) || (b >= '0' && b <= '9')
+}
diff --git a/model/parsers/lfm2_test.go b/model/parsers/lfm2_test.go
index 3e139b811..c353424b4 100644
--- a/model/parsers/lfm2_test.go
+++ b/model/parsers/lfm2_test.go
@@ -39,7 +39,7 @@ func TestLFM2Parser(t *testing.T) {
 		},
 		{
 			name: "tool_call_simple",
-			input: "I'll check the weather.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
+			input: "I'll check the weather.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>",
 			expectedContent: "I'll check the weather.",
 			expectedCalls: []api.ToolCall{
 				{
@@ -55,7 +55,7 @@
 		},
 		{
 			name: "multiple_tool_calls",
-			input: "Getting weather for both cities.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|><|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"London\"}}<|tool_call_end|>",
+			input: "Getting weather for both cities.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|><|tool_call_start|>[get_weather(location=\"London\")]<|tool_call_end|>",
 			expectedContent: "Getting weather for both cities.",
 			expectedCalls: []api.ToolCall{
 				{
@@ -79,7 +79,7 @@
 		},
 		{
 			name: "complex_tool_arguments",
-			input: "Processing data.<|tool_call_start|>{\"name\":\"process_data\",\"arguments\":{\"items\":[\"item1\",\"item2\"],\"config\":{\"enabled\":true,\"threshold\":0.95}}}<|tool_call_end|>",
+			input: "Processing data.<|tool_call_start|>[process_data(items=['item1','item2'], config={'enabled': True, 'threshold': 0.95})]<|tool_call_end|>",
 			expectedContent: "Processing data.",
 			expectedCalls: []api.ToolCall{
 				{
@@ -96,7 +96,7 @@
 		},
 		{
 			name: "thinking_with_tool_call",
-			input: "<think>Let me check the weather...</think>I'll get that for you.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
+			input: "<think>Let me check the weather...</think>I'll get that for you.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>",
 			expectedThinking: "Let me check the weather...",
 			expectedContent: "I'll get that for you.",
 			expectedCalls: []api.ToolCall{
@@ -144,16 +144,16 @@
 			hasThinking: true,
 		},
 		{
-			name: "tool_call_with_unicode_args",
-			input: "Searching for information.<|tool_call_start|>{\"name\":\"search\",\"arguments\":{\"query\":\"北京天气\",\"language\":\"中文\"}}<|tool_call_end|>",
+			name: "tool_call_with_text_args",
+			input: "Searching for information.<|tool_call_start|>[search(query='beijing weather', language='zh')]<|tool_call_end|>",
 			expectedContent: "Searching for information.",
 			expectedCalls: []api.ToolCall{
 				{
 					Function: api.ToolCallFunction{
 						Name: "search",
 						Arguments: testArgs(map[string]any{
-							"query": "北京天气",
-							"language": "中文",
+							"query": "beijing weather",
+							"language": "zh",
 						}),
 					},
 				},
@@ -169,7 +169,7 @@ func TestLFM2Parser(t *testing.T) {
 		},
 		{
 			name: "empty_tool_call_args",
-			input: "Pinging server.<|tool_call_start|>{\"name\":\"ping\",\"arguments\":{}}<|tool_call_end|>",
+			input: "Pinging server.<|tool_call_start|>[ping()]<|tool_call_end|>",
 			expectedContent: "Pinging server.",
 			expectedCalls: []api.ToolCall{
 				{
@@ -353,7 +353,7 @@ func TestLFM2Parser_Streaming(t *testing.T) {
 		},
 		{
 			name: "streaming_tool_call",
-			chunks: []string{"I'll check weather.", "<|tool_call_start|>", "{\"name\":\"get_weather\",", "\"arguments\":{\"location\":\"Paris\"}}", "<|tool_call_end|>"},
+			chunks: []string{"I'll check weather.", "<|tool_call_start|>", "[get_weather(", "location=\"Paris\")]", "<|tool_call_end|>"},
 			expectedContent: "I'll check weather.",
 			expectedCalls: []api.ToolCall{
 				{
@@ -381,16 +381,16 @@ func TestLFM2Parser_Streaming(t *testing.T) {
 			hasThinking: false,
 		},
 		{
-			name: "streaming_tool_call_with_split_json",
-			chunks: []string{"Processing.", "<|tool_call_start|>{\"name\":\"calc\",\"arguments\":{\"x\":", "42,\"y\":", "24}}<|tool_call_end|>"},
+			name: "streaming_tool_call_with_split_python",
+			chunks: []string{"Processing.", "<|tool_call_start|>", "[calc(", "x=42, ", "y=24)]", "<|tool_call_end|>"},
 			expectedContent: "Processing.",
 			expectedCalls: []api.ToolCall{
 				{
 					Function: api.ToolCallFunction{
 						Name: "calc",
 						Arguments: testArgs(map[string]any{
-							"x": float64(42),
-							"y": float64(24),
+							"x": int64(42),
+							"y": int64(24),
 						}),
 					},
 				},
@@ -516,8 +516,8 @@ func TestLFM2Parser_parseToolCallContent(t *testing.T) {
 		expectError bool
 	}{
 		{
-			name: "valid_tool_call",
-			content: `{"name":"get_weather","arguments":{"location":"Paris"}}`,
+			name: "python_style_single_call",
+			content: `get_weather(location="Paris")`,
 			expected: api.ToolCall{
 				Function: api.ToolCallFunction{
 					Name: "get_weather",
@@ -528,21 +528,33 @@
 			},
 		},
 		{
-			name: "complex_arguments",
-			content: `{"name":"process_data","arguments":{"items":["a","b"],"config":{"enabled":true}}}`,
+			name: "python_style_with_brackets",
+			content: `[get_weather(location="Paris")]`,
 			expected: api.ToolCall{
 				Function: api.ToolCallFunction{
-					Name: "process_data",
+					Name: "get_weather",
 					Arguments: testArgs(map[string]any{
-						"items": []interface{}{"a", "b"},
-						"config": map[string]interface{}{"enabled": true},
+						"location": "Paris",
 					}),
 				},
 			},
 		},
 		{
-			name: "empty_arguments",
-			content: `{"name":"ping","arguments":{}}`,
+			name: "python_style_complex_arguments",
+			content: `process(items=['a', 'b'], config={'enabled': True})`,
+			expected: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "process",
+					Arguments: testArgs(map[string]any{
+						"items": []any{"a", "b"},
+						"config": map[string]any{"enabled": true},
+					}),
+				},
+			},
+		},
+		{
+			name: "python_style_empty_arguments",
+			content: `ping()`,
 			expected: api.ToolCall{
 				Function: api.ToolCallFunction{
 					Name: "ping",
@@ -551,44 +563,13 @@
 			},
 		},
 		{
-			name: "unicode_in_tool_name",
-			content: `{"name":"获取天气","arguments":{"城市":"北京"}}`,
-			expected: api.ToolCall{
-				Function: api.ToolCallFunction{
-					Name: "获取天气",
-					Arguments: testArgs(map[string]any{
-						"城市": "北京",
-					}),
-				},
-			},
-		},
-		{
-			name: "numeric_arguments",
-			content: `{"name":"calculate","arguments":{"x":3.14,"y":42,"enabled":true}}`,
-			expected: api.ToolCall{
-				Function: api.ToolCallFunction{
-					Name: "calculate",
-					Arguments: testArgs(map[string]any{
-						"x": 3.14,
-						"y": float64(42),
-						"enabled": true,
-					}),
-				},
-			},
-		},
-		{
-			name: "invalid_json",
-			content: `{invalid json}`,
+			name: "missing_parenthesis",
+			content: `get_weather location="Paris")`,
 			expectError: true,
 		},
 		{
-			name: "missing_name",
-			content: `{"arguments":{"arg":"value"}}`,
-			expectError: true,
-		},
-		{
-			name: "empty_name",
-			content: `{"name":"","arguments":{"arg":"value"}}`,
+			name: "invalid_argument_format",
+			content: `bash(command)`,
 			expectError: true,
 		},
 	}
@@ -645,6 +626,24 @@ func TestLFM2Parser_parseToolCallsContent(t *testing.T) {
 				},
 			},
 		},
+		{
+			name: "python_style_complex_literals",
+			content: `[AskUserQuestion(question="What's up?", headers=['Hello!', 'How can I help you?'], options=['Debugging help', 'Code writing assistance'], multiSelect=False, metadata={'priority': 1, 'active': True})]`,
+			expected: []api.ToolCall{
+				{
+					Function: api.ToolCallFunction{
+						Name: "AskUserQuestion",
+						Arguments: testArgs(map[string]any{
+							"question": "What's up?",
+							"headers": []any{"Hello!", "How can I help you?"},
+							"options": []any{"Debugging help", "Code writing assistance"},
+							"multiSelect": false,
+							"metadata": map[string]any{"priority": float64(1), "active": true},
+						}),
+					},
+				},
+			},
+		},
 		{
 			name: "single_python_style_call",
 			content: `bash(command='ls -la')`,
@@ -673,6 +672,34 @@
 				},
 			},
 		},
+		{
+			name: "single_call_with_orphan_end_tag",
+			content: `[bash(command='ls')]<|tool_call_end|>`,
+			expected: []api.ToolCall{
+				{
+					Function: api.ToolCallFunction{
+						Name: "bash",
+						Arguments: testArgs(map[string]any{
+							"command": "ls",
+						}),
+					},
+				},
+			},
+		},
+		{
+			name: "single_call_with_wrapper_tags",
+			content: `<|tool_call_start|>[bash(command='pwd')]<|tool_call_end|>`,
+			expected: []api.ToolCall{
+				{
+					Function: api.ToolCallFunction{
+						Name: "bash",
+						Arguments: testArgs(map[string]any{
+							"command": "pwd",
+						}),
+					},
+				},
+			},
+		},
 		{
 			name: "multiple_different_functions",
 			content: `[get_weather(location='Paris'),search(query='news')]`,
@@ -1086,3 +1113,106 @@ func TestLFM2Parser_EdgeCases(t *testing.T) {
 		})
 	}
 }
+
+func TestLFM2Parser_BareToolCallFallback(t *testing.T) {
+	parser := &LFM2Parser{}
+	tools := []api.Tool{
+		{
+			Type: "function",
+			Function: api.ToolFunction{
+				Name: "get_weather",
+			},
+		},
+	}
+	parser.Init(tools, nil, &api.ThinkValue{Value: false})
+
+	content, thinking, calls, err := parser.Add(`[get_weather(location="Paris")]`, true)
+	if err != nil {
+		t.Fatalf("Add() error = %v", err)
+	}
+
+	if content != "" {
+		t.Fatalf("expected empty content, got %q", content)
+	}
+	if thinking != "" {
+		t.Fatalf("expected empty thinking, got %q", thinking)
+	}
+	if len(calls) != 1 {
+		t.Fatalf("expected 1 tool call, got %d", len(calls))
+	}
+	if calls[0].Function.Name != "get_weather" {
+		t.Fatalf("expected tool name get_weather, got %q", calls[0].Function.Name)
+	}
+}
+
+func TestLFM2Parser_BareUnknownToolCallDoesNotParse(t *testing.T) {
+	parser := &LFM2Parser{}
+	tools := []api.Tool{
+		{
+			Type: "function",
+			Function: api.ToolFunction{
+				Name: "get_weather",
+			},
+		},
+	}
+	parser.Init(tools, nil, &api.ThinkValue{Value: false})
+
+	input := `[unknown_tool(location="Paris")]`
+	content, _, calls, err := parser.Add(input, true)
+	if err != nil {
+		t.Fatalf("Add() error = %v", err)
+	}
+
+	if content != input {
+		t.Fatalf("expected content to be preserved, got %q", content)
+	}
+	if len(calls) != 0 {
+		t.Fatalf("expected no tool calls, got %d", len(calls))
+	}
+}
+
+func TestLFM2Parser_ImagePlaceholdersPreserved(t *testing.T) {
+	tests := []struct {
+		name string
+		input string
+	}{
+		{
+			name: "indexed_img_placeholder",
+			input: "[img-0]describe this image",
+		},
+		{
+			name: "template_image_placeholder",
+			input: "<image>describe this image",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			parser := &LFM2Parser{}
+			tools := []api.Tool{
+				{
+					Type: "function",
+					Function: api.ToolFunction{
+						Name: "bash",
+					},
+				},
+			}
+			parser.Init(tools, nil, &api.ThinkValue{Value: false})
+
+			content, thinking, calls, err := parser.Add(tt.input, true)
+			if err != nil {
+				t.Fatalf("Add() error = %v", err)
+			}
+
+			if content != tt.input {
+				t.Fatalf("expected content %q, got %q", tt.input, content)
+			}
+			if thinking != "" {
+				t.Fatalf("expected empty thinking, got %q", thinking)
+			}
+			if len(calls) != 0 {
+				t.Fatalf("expected no tool calls, got %d", len(calls))
			}
+		})
+	}
+}
diff --git a/model/parsers/parsers_test.go b/model/parsers/parsers_test.go
index 15c2f664f..2cda2a64c 100644
--- a/model/parsers/parsers_test.go
+++ b/model/parsers/parsers_test.go
@@ -57,6 +57,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
 		{"qwen3"},
 		{"qwen3-thinking"},
 		{"qwen3-coder"},
+		{"lfm2"},
+		{"lfm2-thinking"},
 		{"harmony"},
 	}
diff --git a/model/renderers/lfm2.go b/model/renderers/lfm2.go
index 5c046835f..1db023c49 100644
--- a/model/renderers/lfm2.go
+++ b/model/renderers/lfm2.go
@@ -1,7 +1,9 @@
 package renderers
 
 import (
+	"bytes"
 	"encoding/json"
+	"sort"
 	"strings"
 
 	"github.com/ollama/ollama/api"
@@ -9,18 +11,218 @@ type LFM2Renderer struct {
 	IsThinking bool
+	useImgTags bool
+}
+
+const lfm2BOSToken = "<|startoftext|>"
+
+const (
+	lfm2ToolListStartTag = "<|tool_list_start|>"
+	lfm2ToolListEndTag = "<|tool_list_end|>"
+	lfm2ToolCallStartTag = "<|tool_call_start|>"
+	lfm2ToolCallEndTag = "<|tool_call_end|>"
+	lfm2ToolResponseStartTag = "<|tool_response_start|>"
+	lfm2ToolResponseEndTag = "<|tool_response_end|>"
+)
+
+func lfm2RenderSystemContent(content any) string {
+	switch v := content.(type) {
+	case string:
+		return v
+	case []any:
+		var sb strings.Builder
+		for _, item := range v {
+			obj, ok := item.(map[string]any)
+			if !ok {
+				continue
+			}
+
+			if itemType, _ := obj["type"].(string); itemType == "text" {
+				if text, ok := obj["text"].(string); ok {
+					sb.WriteString(text)
+				}
+			}
+		}
+		return sb.String()
+	default:
+		return ""
+	}
+}
+
+func lfm2JSON(v any) string {
+	var buf bytes.Buffer
+	enc := json.NewEncoder(&buf)
+	enc.SetEscapeHTML(false)
+	if err := enc.Encode(v); err != nil {
+		fallback, _ := json.Marshal(v)
+		return string(fallback)
+	}
+
+	encoded := bytes.TrimSuffix(buf.Bytes(), []byte{'\n'})
+
+	// HF `tojson` defaults to `json.dumps(..., separators=None)`, which inserts
+	// a space after commas and colons.
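+	// Mirror that spacing here: {"a":1,"b":[2,3]} becomes {"a": 1, "b": [2, 3]}.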
+	var out strings.Builder
+	out.Grow(len(encoded) + len(encoded)/8)
+
+	inString := false
+	escaped := false
+	for i, b := range encoded {
+		out.WriteByte(b)
+
+		if inString {
+			if escaped {
+				escaped = false
+				continue
+			}
+			if b == '\\' {
+				escaped = true
+				continue
+			}
+			if b == '"' {
+				inString = false
+			}
+			continue
+		}
+
+		if b == '"' {
+			inString = true
+			continue
+		}
+
+		if (b == ':' || b == ',') && i+1 < len(encoded) {
+			next := encoded[i+1]
+			if next != ' ' && next != '\n' && next != '\r' && next != '\t' {
+				out.WriteByte(' ')
+			}
+		}
+	}
+
+	return out.String()
+}
+
+func lfm2ImagePlaceholder(useImgTags bool) string {
+	if useImgTags {
+		return "[img]"
+	}
+
+	return "<image>"
+}
+
+func lfm2RenderContent(content any, useImgTags bool) string {
+	switch v := content.(type) {
+	case string:
+		return v
+	case []any:
+		var sb strings.Builder
+		for _, item := range v {
+			obj, ok := item.(map[string]any)
+			if !ok {
+				sb.WriteString(lfm2JSON(item))
+				continue
+			}
+
+			itemType, _ := obj["type"].(string)
+			switch itemType {
+			case "image":
+				sb.WriteString(lfm2ImagePlaceholder(useImgTags))
+			case "text":
+				if text, ok := obj["text"].(string); ok {
+					sb.WriteString(text)
+				} else {
+					sb.WriteString(lfm2JSON(item))
+				}
+			default:
+				sb.WriteString(lfm2JSON(item))
+			}
+		}
+		return sb.String()
+	default:
+		return lfm2JSON(content)
+	}
+}
+
+func lfm2ToolSchema(tool api.Tool) any {
+	if tool.Function.Name == "" {
+		return tool
+	}
+
+	// LFM2 templates are typically fed function-schema objects (name/description/parameters).
+	return tool.Function
+}
+
+func lfm2ToolCallArgument(v any) string {
+	return lfm2JSON(v)
+}
+
+func lfm2RenderToolCalls(calls []api.ToolCall) string {
+	var sb strings.Builder
+
+	sb.WriteString(lfm2ToolCallStartTag)
+	sb.WriteString("[")
+	for i, tc := range calls {
+		if i > 0 {
+			sb.WriteString(",")
+		}
+
+		sb.WriteString(tc.Function.Name)
+		sb.WriteString("(")
+
+		keys := make([]string, 0, tc.Function.Arguments.Len())
+		for key := range tc.Function.Arguments.All() {
+			keys = append(keys, key)
+		}
+		sort.Strings(keys)
+
+		for j, key := range keys {
+			if j > 0 {
+				sb.WriteString(",")
+			}
+			value, _ := tc.Function.Arguments.Get(key)
+			sb.WriteString(key)
+			sb.WriteString("=")
+			sb.WriteString(lfm2ToolCallArgument(value))
+		}
+
+		sb.WriteString(")")
+	}
+	sb.WriteString("]")
+	sb.WriteString(lfm2ToolCallEndTag)
+
+	return sb.String()
+}
+
+func (r *LFM2Renderer) renderMessageContent(message api.Message) string {
+	content := lfm2RenderContent(message.Content, r.useImgTags)
+	if len(message.Images) == 0 {
+		return content
+	}
+
+	// chatPrompt may already have inserted [img] / [img-n] / <image> placeholders.
+	if strings.Contains(content, "[img]") || strings.Contains(content, "[img-") || strings.Contains(content, "<image>") {
+		return content
+	}
+
+	var sb strings.Builder
+	placeholder := lfm2ImagePlaceholder(r.useImgTags)
+	for range message.Images {
+		sb.WriteString(placeholder)
+	}
+	sb.WriteString(content)
+	return sb.String()
 }
 
 func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
 	var sb strings.Builder
-	// Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer
+	// Emit the BOS token here so output matches the reference chat template,
+	// and follow Liquid's tool-use formatting for the LFM2 tool wrappers below.
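+	// A rendered prompt thus has the overall shape:
+	//   <|startoftext|><|im_start|>system\n...<|im_end|>\n<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n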
+	sb.WriteString(lfm2BOSToken)
 
 	// Extract first system message if present (to combine with tools)
 	var firstSystemContent string
 	startIdx := 0
 	if len(messages) > 0 && messages[0].Role == "system" {
-		firstSystemContent = messages[0].Content
+		firstSystemContent = lfm2RenderSystemContent(messages[0].Content)
 		startIdx = 1
 	}
@@ -29,18 +231,17 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
 		if firstSystemContent != "" {
 			firstSystemContent += "\n"
 		}
-		firstSystemContent += "List of tools: ["
+		firstSystemContent += "List of tools: "
+		firstSystemContent += lfm2ToolListStartTag
+		firstSystemContent += "["
 		for i, tool := range tools {
-			toolJSON, err := json.Marshal(tool)
-			if err != nil {
-				return "", err
-			}
-			firstSystemContent += string(toolJSON)
+			firstSystemContent += lfm2JSON(lfm2ToolSchema(tool))
 			if i < len(tools)-1 {
 				firstSystemContent += ", "
 			}
 		}
 		firstSystemContent += "]"
+		firstSystemContent += lfm2ToolListEndTag
 	}
 
 	// Output first system block if it has content
@@ -50,6 +251,8 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
 		sb.WriteString("<|im_end|>\n")
 	}
 
+	keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
+
 	// Find the index of the last assistant message for thinking stripping
 	lastAssistantIndex := -1
 	for i := len(messages) - 1; i >= startIdx; i-- {
@@ -59,85 +262,47 @@
 		}
 	}
 
-	// Track whether we need to add generation prompt
-	needsGenerationPrompt := len(messages) > 0
-
 	for i := startIdx; i < len(messages); i++ {
 		message := messages[i]
-		switch message.Role {
-		case "system":
-			// Additional system messages (after the first) are rendered normally
-			sb.WriteString("<|im_start|>system\n")
-			sb.WriteString(message.Content)
-			sb.WriteString("<|im_end|>\n")
+		lastMessage := i == len(messages)-1
+		prefill := lastMessage && message.Role == "assistant"
 
-		case "user":
-			sb.WriteString("<|im_start|>user\n")
-			sb.WriteString(message.Content)
-			sb.WriteString("<|im_end|>\n")
-			needsGenerationPrompt = true
+		sb.WriteString("<|im_start|>")
+		sb.WriteString(message.Role)
+		sb.WriteString("\n")
 
-		case "assistant":
-			sb.WriteString("<|im_start|>assistant\n")
-
-			// Check if this is the last assistant message
-			isLastAssistant := i == lastAssistantIndex
-
-			// Process content (may need thinking stripped)
-			content := message.Content
-
-			// Handle thinking tags in assistant content
-			keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
-			if strings.Contains(content, "</think>") {
-				parts := strings.SplitN(content, "</think>", 2)
-				if len(parts) > 1 {
-					if !isLastAssistant && !keepPastThinking {
-						// Strip thinking entirely for past assistant messages
-						content = strings.TrimSpace(parts[1])
-					} else {
-						// Preserve thinking but trim whitespace after </think>
-						content = parts[0] + "</think>" + strings.TrimLeft(parts[1], " \t\n\r")
-					}
-				}
+		content := r.renderMessageContent(message)
+		if message.Role == "assistant" && !keepPastThinking && i != lastAssistantIndex {
+			if idx := strings.LastIndex(content, "</think>"); idx >= 0 {
+				content = strings.TrimSpace(content[idx+len("</think>"):])
 			}
-
-			if len(message.ToolCalls) > 0 {
-				// Assistant with tool calls - write content first (if any after stripping)
-				if content != "" {
-					sb.WriteString(content)
-				}
-
-				for _, toolCall := range message.ToolCalls {
-					sb.WriteString("<|tool_call_start|>")
-					toolCallJSON := map[string]any{
-						"name": toolCall.Function.Name,
-						"arguments": toolCall.Function.Arguments,
-					}
-					callJSON, _ := json.Marshal(toolCallJSON)
-					sb.WriteString(string(callJSON))
-					sb.WriteString("<|tool_call_end|>")
-				}
+		}
+		if message.Role == "assistant" && len(message.ToolCalls) > 0 && !strings.Contains(content, lfm2ToolCallStartTag) {
+			if strings.TrimSpace(content) == "" {
+				content = lfm2RenderToolCalls(message.ToolCalls) + content
 			} else {
-				sb.WriteString(content)
+				content = lfm2RenderToolCalls(message.ToolCalls) + "\n" + content
 			}
+		}
+		if message.Role == "tool" && !strings.Contains(content, lfm2ToolResponseStartTag) {
+			content = lfm2ToolResponseStartTag + content + lfm2ToolResponseEndTag
+		}
+		sb.WriteString(content)
 
+		if !prefill {
 			sb.WriteString("<|im_end|>\n")
-			needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true
-
-		case "tool":
-			// Tool responses are rendered as plain messages per the chat template
-			sb.WriteString("<|im_start|>tool\n")
-			sb.WriteString(message.Content)
-			sb.WriteString("<|im_end|>\n")
-			needsGenerationPrompt = true
 		}
 	}
 
-	// Add generation prompt
+	needsGenerationPrompt := true
+	if len(messages) > 0 && messages[len(messages)-1].Role == "assistant" {
+		needsGenerationPrompt = false
+	}
+
 	if needsGenerationPrompt {
+		// RenderWithRenderer uses add_generation_prompt=true for chat rendering,
+		// unless we're prefilling a trailing assistant message.
 		sb.WriteString("<|im_start|>assistant\n")
-		// Note: Model is a "thinking-only" model - it will output <think> itself
-		// We don't add <think> tag to the prompt
 	}
 
 	return sb.String(), nil
diff --git a/model/renderers/lfm2_test.go b/model/renderers/lfm2_test.go
index 9eb07eea3..724858c14 100644
--- a/model/renderers/lfm2_test.go
+++ b/model/renderers/lfm2_test.go
@@ -8,73 +8,136 @@ import (
 	"github.com/ollama/ollama/api"
 )
 
-func TestLFM2Renderer(t *testing.T) {
+func TestLFM2Renderer_ChatTemplateParity(t *testing.T) {
 	tests := []struct {
 		name string
+		renderer *LFM2Renderer
 		messages []api.Message
 		tools []api.Tool
 		thinkValue *api.ThinkValue
 		expected string
 	}{
 		{
-			name: "basic user message",
+			name: "user_only",
+			renderer: &LFM2Renderer{IsThinking: false},
 			messages: []api.Message{
-				{Role: "user", Content: "Hello!"},
+				{Role: "user", Content: "Hello"},
 			},
 			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
+			expected: "<|startoftext|><|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
 		},
 		{
-			name: "basic with system message",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a helpful assistant."},
-				{Role: "user", Content: "Hello!"},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			name: "multiple system messages rendered separately",
-			messages: []api.Message{
-				{Role: "system", Content: "First instruction."},
-				{Role: "system", Content: "Second instruction."},
-				{Role: "user", Content: "Hello!"},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			name: "multi-turn conversation",
-			messages: []api.Message{
-				{Role: "user", Content: "What is 2+2?"},
-				{Role: "assistant", Content: "The answer is 4."},
-				{Role: "user", Content: "Thanks!"},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			name: "only system message",
+			name: "system_and_user",
+			renderer: &LFM2Renderer{IsThinking: false},
 			messages: []api.Message{
 				{Role: "system", Content: "You are helpful."},
+				{Role: "user", Content: "Hi"},
 			},
 			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n",
+			expected: "<|startoftext|><|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
 		},
 		{
-			// When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false)
-			name: "user-assistant-user: last assistant preserves thinking",
+			name: "tools_without_system",
+			renderer: &LFM2Renderer{IsThinking: false},
 			messages: []api.Message{
-				{Role: "user", Content: "Q1"},
-				{Role: "assistant", Content: "<think>reasoning</think>A1"},
-				{Role: "user", Content: "Q2"},
+				{Role: "user", Content: "Use tools"},
+			},
+			tools: []api.Tool{
+				{
+					Type: "function",
+					Function: api.ToolFunction{
+						Name: "get_weather",
+						Parameters: api.ToolFunctionParameters{
+							Type: "object",
+						},
+					},
+				},
 			},
 			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reasoning</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n",
+			expected: "<|startoftext|><|im_start|>system\nList of tools: <|tool_list_start|>[{\"name\": \"get_weather\", \"parameters\": {\"type\": \"object\", \"properties\": null}}]<|tool_list_end|><|im_end|>\n" +
+				"<|im_start|>user\nUse tools<|im_end|>\n<|im_start|>assistant\n",
 		},
 		{
-			// With two assistants, first is stripped (not last), second preserved (is last)
-			name: "multi-turn thinking: first stripped, second preserved",
+			name: "first_system_combined_with_tools",
+			renderer: &LFM2Renderer{IsThinking: false},
+			messages: []api.Message{
+				{Role: "system", Content: "Follow instructions."},
+				{Role: "user", Content: "Do work"},
+			},
+			tools: []api.Tool{
+				{
+					Type: "function",
+					Function: api.ToolFunction{
+						Name: "tool_a",
+						Parameters: api.ToolFunctionParameters{
+							Type: "object",
+						},
+					},
+				},
+				{
+					Type: "function",
+					Function: api.ToolFunction{
+						Name: "tool_b",
+						Parameters: api.ToolFunctionParameters{
+							Type: "object",
+						},
+					},
+				},
+			},
+			thinkValue: &api.ThinkValue{Value: false},
+			expected: "<|startoftext|><|im_start|>system\nFollow instructions.\nList of tools: <|tool_list_start|>[{\"name\": \"tool_a\", \"parameters\": {\"type\": \"object\", \"properties\": null}}, {\"name\": \"tool_b\", \"parameters\": {\"type\": \"object\", \"properties\": null}}]<|tool_list_end|><|im_end|>\n" +
+				"<|im_start|>user\nDo work<|im_end|>\n<|im_start|>assistant\n",
+		},
+		{
+			name: "assistant_tool_calls_and_tool_responses_are_rendered",
+			renderer: &LFM2Renderer{IsThinking: false},
+			messages: []api.Message{
+				{Role: "user", Content: "Call a tool"},
+				{
+					Role: "assistant",
+					Content: "",
+					ToolCalls: []api.ToolCall{
+						{
+							Function: api.ToolCallFunction{
+								Name: "get_weather",
+								Arguments: testArgs(map[string]any{
+									"location": "Paris",
+								}),
+							},
+						},
+					},
+				},
+				{Role: "tool", Content: "22C"},
+			},
+			thinkValue: &api.ThinkValue{Value: false},
+			expected: "<|startoftext|><|im_start|>user\nCall a tool<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|><|im_end|>\n<|im_start|>tool\n<|tool_response_start|>22C<|tool_response_end|><|im_end|>\n<|im_start|>assistant\n",
+		},
+		{
+			name: "assistant_tool_calls_with_content_preserves_both",
+			renderer: &LFM2Renderer{IsThinking: false},
+			messages: []api.Message{
+				{Role: "user", Content: "Call a tool"},
+				{
+					Role: "assistant",
+					Content: "Checking now.",
+					ToolCalls: []api.ToolCall{
+						{
+							Function: api.ToolCallFunction{
+								Name: "get_weather",
+								Arguments: testArgs(map[string]any{
+									"location": "Paris",
+								}),
+							},
+						},
+					},
+				},
+			},
+			thinkValue: &api.ThinkValue{Value: false},
+			expected: "<|startoftext|><|im_start|>user\nCall a tool<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>\nChecking now.",
+		},
+		{
+			name: "thinking_strips_non_last_assistant_when_disabled",
+			renderer: &LFM2Renderer{IsThinking: true},
 			messages: []api.Message{
 				{Role: "user", Content: "Q1"},
 				{Role: "assistant", Content: "<think>reason1</think>A1"},
@@ -82,11 +145,11 @@
 				{Role: "assistant", Content: "<think>reason2</think>A2"},
 			},
 			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
+			expected: "<|startoftext|><|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2",
 		},
 		{
-			// With thinking enabled (keep_past_thinking=true), both preserved
-			name: "multi-turn thinking: both preserved when thinking enabled",
+			name: "thinking_preserves_past_assistant_when_enabled",
+			renderer: &LFM2Renderer{IsThinking: true},
 			messages: []api.Message{
 				{Role: "user", Content: "Q1"},
 				{Role: "assistant", Content: "<think>reason1</think>A1"},
@@ -94,334 +157,137 @@
 				{Role: "assistant", Content: "<think>reason2</think>A2"},
 			},
 			thinkValue: &api.ThinkValue{Value: true},
-			expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
+			expected: "<|startoftext|><|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2",
 		},
 		{
-			name: "assistant with tool calls",
+			name: "arbitrary_roles_are_rendered_verbatim",
+			renderer: &LFM2Renderer{IsThinking: false},
 			messages: []api.Message{
-				{Role: "user", Content: "What's the weather?"},
-				{
-					Role: "assistant",
-					ToolCalls: []api.ToolCall{
-						{
-							Function: api.ToolCallFunction{
-								Name: "get_weather",
-								Arguments: testArgs(map[string]any{
-									"location": "Paris",
-								}),
-							},
-						},
-					},
-				},
+				{Role: "developer", Content: "Do X"},
+				{Role: "user", Content: "Hi"},
 			},
 			thinkValue: &api.ThinkValue{Value: false},
-			expected: `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
+			expected: "<|startoftext|><|im_start|>developer\nDo X<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
 		},
 		{
-			name: "assistant with content and tool calls",
+			name: "empty_messages_still_add_generation_prompt",
+			renderer: &LFM2Renderer{IsThinking: false},
+			messages: nil,
+			thinkValue: &api.ThinkValue{Value: false},
+			expected: "<|startoftext|><|im_start|>assistant\n",
+		},
+		{
+			name: "assistant_prefill_no_generation_prompt",
+			renderer: &LFM2Renderer{IsThinking: false},
 			messages: []api.Message{
-				{Role: "user", Content: "What's the weather in Paris?"},
-				{
-					Role: "assistant",
-					Content: "Let me check.",
-					ToolCalls: []api.ToolCall{
-						{
-							Function: api.ToolCallFunction{
-								Name: "get_weather",
-								Arguments: testArgs(map[string]any{
-									"location": "Paris",
-								}),
-							},
-						},
-					},
-				},
+				{Role: "user", Content: "Hi"},
+				{Role: "assistant", Content: "Hello"},
 			},
 			thinkValue: &api.ThinkValue{Value: false},
-			expected: `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
-		},
-		{
-			name: "tool response",
-			messages: []api.Message{
-				{Role: "user", Content: "What's the weather?"},
-				{Role: "assistant", Content: "Let me check."},
-				{Role: "tool", Content: "22C, Sunny"},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			name: "multiple tool calls",
-			messages: []api.Message{
-				{Role: "user", Content: "Get weather for Paris and London"},
-				{
-					Role: "assistant",
-					ToolCalls: []api.ToolCall{
-						{
-							Function: api.ToolCallFunction{
-								Name: "get_weather",
-								Arguments: testArgs(map[string]any{
-									"location": "Paris",
-								}),
-							},
-						},
-						{
-							Function: api.ToolCallFunction{
-								Name: "get_weather",
-								Arguments: testArgs(map[string]any{
-									"location": "London",
-								}),
-							},
-						},
-					},
-				},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
-		},
-		{
-			name: "tools definitions with system message",
-			messages: []api.Message{
-				{Role: "system", Content: "You are helpful."},
-				{Role: "user", Content: "What's the weather?"},
-			},
-			tools: []api.Tool{
-				{
-					Type: "function",
-					Function: api.ToolFunction{
-						Name: "get_weather",
-						Description: "Get current weather",
-						Parameters: api.ToolFunctionParameters{
-							Type: "object",
-							Properties: testPropsMap(map[string]api.ToolProperty{
-								"location": {
-									Type: api.PropertyType{"string"},
-									Description: "City name",
-								},
-							}),
-							Required: []string{"location"},
-						},
-					},
-				},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
-		},
-		{
-			name: "tools definitions without system message",
-			messages: []api.Message{
-				{Role: "user", Content: "What's the weather?"},
-			},
-			tools: []api.Tool{
-				{
-					Type: "function",
-					Function: api.ToolFunction{
-						Name: "get_weather",
-						Description: "Get current weather",
-						Parameters: api.ToolFunctionParameters{
-							Type: "object",
-							Properties: testPropsMap(map[string]api.ToolProperty{
-								"location": {
-									Type: api.PropertyType{"string"},
-									Description: "City name",
-								},
-							}),
-							Required: []string{"location"},
-						},
-					},
-				},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
-		},
-		{
-			name: "multiple tools without system message",
-			messages: []api.Message{
-				{Role: "user", Content: "Hello"},
-			},
-			tools: []api.Tool{
-				{
-					Type: "function",
-					Function: api.ToolFunction{
-						Name: "get_weather",
-						Description: "Get weather",
-					},
-				},
-				{
-					Type: "function",
-					Function: api.ToolFunction{
-						Name: "get_time",
-						Description: "Get time",
-					},
-				},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>system\nList of tools: [{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather\",\"parameters\":{\"type\":\"\",\"properties\":null}}}, {\"type\":\"function\",\"function\":{\"name\":\"get_time\",\"description\":\"Get time\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			name: "user-tool sequence",
-			messages: []api.Message{
-				{Role: "user", Content: "Check weather"},
-				{Role: "tool", Content: "22C"},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			name: "full tool call cycle",
-			messages: []api.Message{
-				{Role: "user", Content: "Check weather"},
-				{Role: "assistant", Content: "Let me check"},
-				{Role: "tool", Content: "22C"},
-				{Role: "assistant", Content: "It's 22C"},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			name: "unicode content",
-			messages: []api.Message{
-				{Role: "user", Content: "你好世界! مرحبا 🌍"},
-				{Role: "assistant", Content: "Hello! 👋"},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			name: "newlines in content",
-			messages: []api.Message{
-				{Role: "user", Content: "Line 1\nLine 2\n\nLine 4"},
-				{Role: "assistant", Content: "Response with\nmultiple\nlines"},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			name: "empty assistant content",
-			messages: []api.Message{
-				{Role: "user", Content: "Hello"},
-				{Role: "assistant", Content: ""},
-				{Role: "user", Content: "OK"},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			// Generation prompt does NOT include <think> - model outputs it
-			name: "generation prompt has no think tag",
-			messages: []api.Message{
-				{Role: "user", Content: "Think hard"},
-			},
-			thinkValue: &api.ThinkValue{Value: true},
-			expected: "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			// Interleaved: thinking before tool call - last assistant preserves thinking
-			name: "thinking before tool call (last assistant)",
-			messages: []api.Message{
-				{Role: "user", Content: "What's the weather?"},
-				{
-					Role: "assistant",
-					Content: "<think>I need to check the weather</think>",
-					ToolCalls: []api.ToolCall{
-						{
-							Function: api.ToolCallFunction{
-								Name: "get_weather",
-								Arguments: testArgs(map[string]any{
-									"location": "Paris",
-								}),
-							},
-						},
-					},
-				},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>I need to check the weather</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			// Two assistants with tool calls - first has thinking stripped
-			name: "two assistants with tools: first thinking stripped",
-			messages: []api.Message{
-				{Role: "user", Content: "What's the weather?"},
-				{
-					Role: "assistant",
-					Content: "<think>checking</think>",
-					ToolCalls: []api.ToolCall{
-						{
-							Function: api.ToolCallFunction{
-								Name: "get_weather",
-								Arguments: testArgs(map[string]any{
-									"location": "Paris",
-								}),
-							},
-						},
-					},
-				},
-				{Role: "tool", Content: "22C"},
-				{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			// Two assistants with tools - both preserved when thinking enabled
-			name: "two assistants with tools: both preserved when thinking enabled",
-			messages: []api.Message{
-				{Role: "user", Content: "What's the weather?"},
-				{
-					Role: "assistant",
-					Content: "<think>checking</think>",
-					ToolCalls: []api.ToolCall{
-						{
-							Function: api.ToolCallFunction{
-								Name: "get_weather",
-								Arguments: testArgs(map[string]any{
-									"location": "Paris",
-								}),
-							},
-						},
-					},
-				},
-				{Role: "tool", Content: "22C"},
-				{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
-			},
-			thinkValue: &api.ThinkValue{Value: true},
-			expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>checking</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
-		},
-		{
-			// Content before thinking before tool call
-			name: "content then thinking then tool call",
-			messages: []api.Message{
-				{Role: "user", Content: "What's the weather?"},
-				{
-					Role: "assistant",
-					Content: "Let me check.<think>Using weather API</think>",
-					ToolCalls: []api.ToolCall{
-						{
-							Function: api.ToolCallFunction{
-								Name: "get_weather",
-								Arguments: testArgs(map[string]any{
-									"location": "Paris",
-								}),
-							},
-						},
-					},
-				},
-			},
-			thinkValue: &api.ThinkValue{Value: false},
-			expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<think>Using weather API</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
+			expected: "<|startoftext|><|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello",
 		},
 	}
 
-	renderer := &LFM2Renderer{IsThinking: true}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
+			rendered, err := tt.renderer.Render(tt.messages, tt.tools, tt.thinkValue)
 			if err != nil {
 				t.Fatalf("Render() error = %v", err)
 			}
 
 			if diff := cmp.Diff(tt.expected, rendered); diff != "" {
-				t.Errorf("Render() mismatch (-want +got):\n%s", diff)
+				t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
 			}
 		})
 	}
 }
+
+func TestLFM2Renderer_Images(t *testing.T) {
+	tests := []struct {
+		name string
+		renderer *LFM2Renderer
+		message api.Message
+		expected string
+	}{
+		{
+			name: "single_image_default_placeholder",
+			renderer: &LFM2Renderer{},
+			message: api.Message{
+				Role: "user",
+				Content: "Describe this image.",
+				Images: []api.ImageData{api.ImageData("img1")},
+			},
+			expected: "<|startoftext|><|im_start|>user\n<image>Describe this image.<|im_end|>\n<|im_start|>assistant\n",
+		},
+		{
+			name: "multiple_images_default_placeholder",
+			renderer: &LFM2Renderer{},
+			message: api.Message{
+				Role: "user",
+				Content: "Describe these images.",
+				Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
+			},
+			expected: "<|startoftext|><|im_start|>user\n<image><image>Describe these images.<|im_end|>\n<|im_start|>assistant\n",
+		},
+		{
+			name: "single_image_img_tag_placeholder",
+			renderer: &LFM2Renderer{useImgTags: true},
+			message: api.Message{
+				Role: "user",
+				Content: "Describe this image.",
+				Images: []api.ImageData{api.ImageData("img1")},
+			},
+			expected: "<|startoftext|><|im_start|>user\n[img]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
+		},
+		{
+			name: "existing_indexed_img_placeholder_not_duplicated",
+			renderer: &LFM2Renderer{useImgTags: true},
+			message: api.Message{
+				Role: "user",
+				Content: "[img-0]Describe this image.",
+				Images: []api.ImageData{api.ImageData("img1")},
+			},
+			expected: "<|startoftext|><|im_start|>user\n[img-0]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
+		},
+		{
+			name: "existing_template_image_placeholder_not_duplicated",
+			renderer: &LFM2Renderer{},
+			message: api.Message{
+				Role: "user",
+				Content: "<image>Describe this image.",
+				Images: []api.ImageData{api.ImageData("img1")},
+			},
+			expected: "<|startoftext|><|im_start|>user\n<image>Describe this image.<|im_end|>\n<|im_start|>assistant\n",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := tt.renderer.Render([]api.Message{tt.message}, nil, nil)
+			if err != nil {
+				t.Fatalf("Render() error = %v", err)
+			}
+			if diff := cmp.Diff(tt.expected, got); diff != "" {
+				t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+func TestLFM2Renderer_JSONFormatting(t *testing.T) {
+	tool := api.Tool{
+		Type: "function",
+		Function: api.ToolFunction{
+			Name: "echo",
+			Description: "",
+			Parameters: api.ToolFunctionParameters{
+				Type: "object",
+			},
+		},
+	}
+
+	got := lfm2JSON(tool)
+	want := "{\"type\": \"function\", \"function\": {\"name\": \"echo\", \"description\": \"\", \"parameters\": {\"type\": \"object\", \"properties\": null}}}"
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Fatalf("lfm2JSON mismatch (-want +got):\n%s", diff)
+	}
+}
diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go
index baa0bc8c4..82d263e5e 100644
--- a/model/renderers/renderer.go
+++ b/model/renderers/renderer.go
@@ -85,9 +85,9 @@ func rendererForName(name string) Renderer {
 	case "glm-ocr":
 		return &GlmOcrRenderer{}
 	case "lfm2":
-		return &LFM2Renderer{IsThinking: false}
+		return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
 	case "lfm2-thinking":
-		return &LFM2Renderer{IsThinking: true}
+		return &LFM2Renderer{IsThinking: true, useImgTags: RenderImgTags}
 	default:
 		return nil
 	}
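
Illustrative usage (a sketch, not part of the diff): with the helpers added above, a tool call round-trips through the new Python-style wire format. testArgs here stands in for the helper the tests use to build api.ToolCallFunctionArguments:

	calls := []api.ToolCall{{
		Function: api.ToolCallFunction{
			Name:      "get_weather",
			Arguments: testArgs(map[string]any{"location": "Paris"}),
		},
	}}
	out := lfm2RenderToolCalls(calls)
	// out == `<|tool_call_start|>[get_weather(location="Paris")]<|tool_call_end|>`
	// The parser side (parsePythonArgs and friends) reads the same span back
	// into an identical api.ToolCall.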