mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 07:16:38 -05:00
model: improvements to LFM architectures (#14368)
This commit is contained in:
@@ -1,410 +1,44 @@
|
||||
package lfm2
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var _ kvcache.Cache = (*HybridCache)(nil)
|
||||
var (
|
||||
_ kvcache.Cache = (*HybridCache)(nil)
|
||||
_ kvcache.CheckpointCache = (*HybridCache)(nil)
|
||||
)
|
||||
|
||||
// HybridCache stores:
|
||||
// - a standard causal KV cache for attention layers
|
||||
// - a per-sequence recurrent conv state for shortconv layers
|
||||
// HybridCache adapts the shared recurrent cache for LFM2:
|
||||
// - KV attention cache is handled by the embedded causal cache
|
||||
// - shortconv recurrent state uses conv slots [dConv, hiddenSize]
|
||||
//
|
||||
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
|
||||
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
|
||||
// This reuses shared checkpoint/restore logic for prefix mismatch recovery.
|
||||
type HybridCache struct {
|
||||
kv *kvcache.Causal
|
||||
|
||||
backend ml.Backend
|
||||
dtype ml.DType
|
||||
maxSequences int
|
||||
|
||||
hiddenSize int
|
||||
dConv int
|
||||
|
||||
// slot mapping for recurrent state
|
||||
slotForSeq map[int]int
|
||||
refCount []int
|
||||
freeSlots []int
|
||||
|
||||
// per-layer conv state buffers (allocated lazily)
|
||||
convCtxs map[int]ml.Context
|
||||
convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]
|
||||
|
||||
// current forward batch (derived in StartForward)
|
||||
curSeqs []int
|
||||
curSlots []int
|
||||
curSlotsInput ml.Tensor
|
||||
curSeqTokens int
|
||||
|
||||
// track if EnsureWritable has been called for this forward pass
|
||||
writableEnsured bool
|
||||
// track any error from EnsureWritable to propagate later
|
||||
writableError error
|
||||
*kvcache.Recurrent
|
||||
}
|
||||
|
||||
func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
|
||||
return &HybridCache{
|
||||
kv: kvcache.NewCausalCache(shift),
|
||||
hiddenSize: hiddenSize,
|
||||
dConv: dConv,
|
||||
slotForSeq: make(map[int]int),
|
||||
convCtxs: make(map[int]ml.Context),
|
||||
convStates: make(map[int]ml.Tensor),
|
||||
}
|
||||
}
|
||||
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
|
||||
Shift: shift,
|
||||
ConvDim: dConv,
|
||||
ConvChannels: hiddenSize,
|
||||
RecurrentStateSize: 1, // LFM2 uses only conv state; keep a minimal recurrent buffer size.
|
||||
CheckpointLogPrefix: "lfm2",
|
||||
})
|
||||
|
||||
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.backend = backend
|
||||
c.dtype = dtype
|
||||
c.maxSequences = maxSequences
|
||||
|
||||
// initialize slot allocator
|
||||
c.refCount = make([]int, maxSequences)
|
||||
c.freeSlots = c.freeSlots[:0]
|
||||
for i := maxSequences - 1; i >= 0; i-- {
|
||||
c.freeSlots = append(c.freeSlots, i)
|
||||
}
|
||||
|
||||
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Close() {
|
||||
for _, ctx := range c.convCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
c.kv.Close()
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
|
||||
c.kv.SetConfig(config)
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetLayer(layer int) {
|
||||
c.kv.SetLayer(layer)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return c.kv.Get(ctx)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.kv.Put(ctx, key, value)
|
||||
}
|
||||
|
||||
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Derive equal-length sequence layout for shortconv.
|
||||
// LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
|
||||
seqCounts := make(map[int]int)
|
||||
c.curSeqs = c.curSeqs[:0]
|
||||
for _, s := range batch.Sequences {
|
||||
if _, ok := seqCounts[s]; !ok {
|
||||
c.curSeqs = append(c.curSeqs, s)
|
||||
}
|
||||
seqCounts[s]++
|
||||
}
|
||||
|
||||
if len(c.curSeqs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
nTokens := len(batch.Sequences)
|
||||
nSeqs := len(c.curSeqs)
|
||||
want := nTokens / nSeqs
|
||||
for _, s := range c.curSeqs {
|
||||
if seqCounts[s] != want {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
}
|
||||
|
||||
c.curSeqTokens = want
|
||||
|
||||
// When reserving memory for estimation, use fake slot assignments
|
||||
// without modifying permanent state (slotForSeq, refCount)
|
||||
if reserve {
|
||||
c.curSlots = c.curSlots[:0]
|
||||
slots := make([]int32, nSeqs)
|
||||
for i := range nSeqs {
|
||||
c.curSlots = append(c.curSlots, i)
|
||||
slots[i] = int32(i)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure slots exist for sequences in this batch
|
||||
c.curSlots = c.curSlots[:0]
|
||||
var newSlots []int // track newly allocated slots that need zeroing
|
||||
for _, s := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[s]
|
||||
if !ok {
|
||||
var err error
|
||||
slot, err = c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.slotForSeq[s] = slot
|
||||
c.refCount[slot] = 1
|
||||
newSlots = append(newSlots, slot)
|
||||
}
|
||||
c.curSlots = append(c.curSlots, slot)
|
||||
}
|
||||
|
||||
// Zero conv state for newly allocated slots to clear stale data from previous sequences
|
||||
if len(newSlots) > 0 {
|
||||
c.zeroConvSlots(ctx, newSlots)
|
||||
}
|
||||
|
||||
// Create a tensor for the current slots
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
// Reset writable state for new forward pass
|
||||
c.writableEnsured = false
|
||||
c.writableError = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) allocSlot() (int, error) {
|
||||
if len(c.freeSlots) == 0 {
|
||||
return 0, kvcache.ErrKvCacheFull
|
||||
}
|
||||
slot := c.freeSlots[len(c.freeSlots)-1]
|
||||
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
|
||||
return slot, nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) freeSlot(slot int) {
|
||||
// Bounds check before freeing
|
||||
if slot >= 0 && slot < c.maxSequences {
|
||||
c.freeSlots = append(c.freeSlots, slot)
|
||||
}
|
||||
}
|
||||
|
||||
// zeroConvSlots zeros the conv state for the given slots across all layers.
|
||||
// This must be called when recycling slots to prevent stale state from affecting new sequences.
|
||||
func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) {
|
||||
if len(slots) == 0 || len(c.convStates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Use input context for creating tensors
|
||||
inputCtx := ctx.Input()
|
||||
|
||||
// Create slot indices tensor
|
||||
slotIndices := make([]int32, len(slots))
|
||||
for i, s := range slots {
|
||||
slotIndices[i] = int32(s)
|
||||
}
|
||||
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
|
||||
|
||||
// Create zero tensor for the slots (SetRows requires F32 source)
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots))
|
||||
|
||||
// Zero each layer's conv state for these slots
|
||||
for _, buf := range c.convStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
|
||||
// Returns an error if slot allocation fails.
|
||||
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
|
||||
for i, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Bounds check
|
||||
if slot < 0 || slot >= len(c.refCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.refCount[slot] <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
c.curSlots[i] = newSlot
|
||||
|
||||
// Copy existing conv state for all initialized layers
|
||||
for _, buf := range c.convStates {
|
||||
// buf: [dConv*hiddenSize, maxSlots]
|
||||
src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
|
||||
// SetRows requires F32 source
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild current slots tensor
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
||||
// KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
|
||||
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
||||
|
||||
// For shortconv state we implement copy-on-write: dst shares the same slot as src.
|
||||
// On the first write to dst, EnsureWritable will create a private slot.
|
||||
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
|
||||
// Bounds check before decrementing
|
||||
if dstSlot >= 0 && dstSlot < len(c.refCount) {
|
||||
c.refCount[dstSlot]--
|
||||
if c.refCount[dstSlot] <= 0 {
|
||||
c.refCount[dstSlot] = 0
|
||||
c.freeSlot(dstSlot)
|
||||
}
|
||||
}
|
||||
delete(c.slotForSeq, dstSeq)
|
||||
}
|
||||
|
||||
srcSlot, ok := c.slotForSeq[srcSeq]
|
||||
if !ok {
|
||||
// src may not have a slot yet; dst will allocate on demand
|
||||
return
|
||||
}
|
||||
|
||||
// Bounds check before incrementing
|
||||
if srcSlot >= 0 && srcSlot < len(c.refCount) {
|
||||
c.slotForSeq[dstSeq] = srcSlot
|
||||
c.refCount[srcSlot]++
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) CanResume(seq int, pos int32) bool {
|
||||
return c.kv.CanResume(seq, pos)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For recurrent state, any removal invalidates the state because
|
||||
// the state at position N depends on all previous positions.
|
||||
// Drop the slot mapping so it resets on next use.
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Bounds check
|
||||
if slot < 0 || slot >= len(c.refCount) {
|
||||
delete(c.slotForSeq, seq)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.refCount[slot]--
|
||||
if c.refCount[slot] <= 0 {
|
||||
c.refCount[slot] = 0
|
||||
c.freeSlot(slot)
|
||||
}
|
||||
delete(c.slotForSeq, seq)
|
||||
|
||||
return nil
|
||||
return &HybridCache{Recurrent: base}
|
||||
}
|
||||
|
||||
func (c *HybridCache) slotsTensor() ml.Tensor {
|
||||
return c.curSlotsInput
|
||||
return c.SlotsTensor()
|
||||
}
|
||||
|
||||
func (c *HybridCache) seqTokens() int {
|
||||
return c.curSeqTokens
|
||||
return c.SeqTokens()
|
||||
}
|
||||
|
||||
func (c *HybridCache) numSeqs() int {
|
||||
return len(c.curSeqs)
|
||||
}
|
||||
|
||||
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
|
||||
if buf, ok := c.convStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.convCtxs[layer]; !ok {
|
||||
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
|
||||
c.convStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
|
||||
// Returns an error if copy-on-write allocation fails.
|
||||
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
|
||||
if !c.writableEnsured {
|
||||
needsWritable := false
|
||||
for _, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
|
||||
needsWritable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsWritable {
|
||||
if err := c.EnsureWritable(ctx); err != nil {
|
||||
c.writableError = err
|
||||
}
|
||||
}
|
||||
c.writableEnsured = true
|
||||
}
|
||||
|
||||
if c.writableError != nil {
|
||||
return nil, c.writableError
|
||||
}
|
||||
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
cur := buf.Rows(ctx, c.slotsTensor())
|
||||
return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateConvState writes a new conv state for current batch sequences.
|
||||
// newState must have shape [dConv, hiddenSize, nSeqs].
|
||||
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
|
||||
// SetRows requires F32 source
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
||||
}
|
||||
|
||||
// IsSupportedForBatch returns true if the current batch layout supports shortconv.
|
||||
func (c *HybridCache) IsSupportedForBatch() bool {
|
||||
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
|
||||
}
|
||||
|
||||
// Seqs returns the ordered unique sequences for the current forward pass.
|
||||
func (c *HybridCache) Seqs() []int {
|
||||
return slices.Clone(c.curSeqs)
|
||||
return c.NumSeqs()
|
||||
}
|
||||
|
||||
@@ -4,441 +4,39 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// TestHybridCache tests verify the slot management logic of HybridCache.
|
||||
// These tests focus on the recurrent state slot allocation, reference counting,
|
||||
// and copy-on-write semantics without requiring a full ML backend.
|
||||
func TestHybridCache_New(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 512, 2)
|
||||
if cache == nil {
|
||||
t.Fatal("expected cache to be created")
|
||||
}
|
||||
|
||||
// createSlotOnlyCache creates a HybridCache with only the slot management
|
||||
// fields initialized. Used to test slot logic in isolation.
|
||||
func createSlotOnlyCache(maxSequences int) *HybridCache {
|
||||
return &HybridCache{
|
||||
hiddenSize: 256,
|
||||
dConv: 3,
|
||||
maxSequences: maxSequences,
|
||||
refCount: make([]int, maxSequences),
|
||||
freeSlots: initFreeSlots(maxSequences),
|
||||
slotForSeq: make(map[int]int),
|
||||
convCtxs: make(map[int]ml.Context),
|
||||
convStates: make(map[int]ml.Tensor),
|
||||
if cache.Recurrent == nil {
|
||||
t.Fatal("expected embedded recurrent cache to be created")
|
||||
}
|
||||
}
|
||||
|
||||
func initFreeSlots(n int) []int {
|
||||
slots := make([]int, 0, n)
|
||||
for i := n - 1; i >= 0; i-- {
|
||||
slots = append(slots, i)
|
||||
}
|
||||
return slots
|
||||
}
|
||||
func TestHybridCache_ImplementsCheckpointCache(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 512, 2)
|
||||
|
||||
func TestHybridCache_SlotAllocation(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Verify initial state
|
||||
if len(cache.freeSlots) != 4 {
|
||||
t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots))
|
||||
}
|
||||
|
||||
// Allocate all slots
|
||||
for range 4 {
|
||||
slot, err := cache.allocSlot()
|
||||
if err != nil {
|
||||
t.Fatalf("allocSlot failed: %v", err)
|
||||
}
|
||||
cache.refCount[slot] = 1
|
||||
}
|
||||
|
||||
// Should be full now
|
||||
if len(cache.freeSlots) != 0 {
|
||||
t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots))
|
||||
}
|
||||
|
||||
// Trying to allocate another should fail
|
||||
_, err := cache.allocSlot()
|
||||
if err != kvcache.ErrKvCacheFull {
|
||||
t.Errorf("expected ErrKvCacheFull, got %v", err)
|
||||
if _, ok := any(cache).(kvcache.CheckpointCache); !ok {
|
||||
t.Fatal("expected HybridCache to implement CheckpointCache")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotReuse(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
func TestHybridCache_DefaultBatchState(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 512, 2)
|
||||
|
||||
// Allocate a slot
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Free it
|
||||
cache.refCount[slot1] = 0
|
||||
cache.freeSlot(slot1)
|
||||
|
||||
// Allocate again - should get the same slot back (LIFO)
|
||||
slot2, _ := cache.allocSlot()
|
||||
if slot2 != slot1 {
|
||||
t.Errorf("expected slot %d to be reused, got %d", slot1, slot2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Simulate sharing slot with seq 2 (copy-on-write style)
|
||||
cache.slotForSeq[2] = slot1
|
||||
cache.refCount[slot1]++
|
||||
|
||||
// Should share the same slot
|
||||
if cache.slotForSeq[2] != slot1 {
|
||||
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
|
||||
if got := cache.numSeqs(); got != 0 {
|
||||
t.Fatalf("expected 0 sequences before StartForward, got %d", got)
|
||||
}
|
||||
|
||||
// Ref count should be 2
|
||||
if cache.refCount[slot1] != 2 {
|
||||
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Share with seq 2
|
||||
cache.slotForSeq[2] = slot1
|
||||
cache.refCount[slot1]++
|
||||
|
||||
// Unshare seq 2
|
||||
cache.refCount[slot1]--
|
||||
delete(cache.slotForSeq, 2)
|
||||
|
||||
// Ref count should be back to 1
|
||||
if cache.refCount[slot1] != 1 {
|
||||
t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1])
|
||||
if got := cache.seqTokens(); got != 0 {
|
||||
t.Fatalf("expected 0 sequence tokens before StartForward, got %d", got)
|
||||
}
|
||||
|
||||
// Seq 2 should no longer have a slot
|
||||
if _, ok := cache.slotForSeq[2]; ok {
|
||||
t.Error("seq 2 should not have a slot after unshare")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotFreeWhenUnused(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
initialFreeSlots := len(cache.freeSlots)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Free the slot when refCount drops to 0
|
||||
cache.refCount[slot1]--
|
||||
if cache.refCount[slot1] <= 0 {
|
||||
cache.refCount[slot1] = 0
|
||||
cache.freeSlot(slot1)
|
||||
}
|
||||
delete(cache.slotForSeq, 1)
|
||||
|
||||
// Slot should be freed
|
||||
if len(cache.freeSlots) != initialFreeSlots {
|
||||
t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots))
|
||||
}
|
||||
|
||||
// Ref count should be 0
|
||||
if cache.refCount[slot1] != 0 {
|
||||
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotOverwrite(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Allocate slots for seq 1 and seq 2
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
slot2, _ := cache.allocSlot()
|
||||
cache.slotForSeq[2] = slot2
|
||||
cache.refCount[slot2] = 1
|
||||
|
||||
initialFreeSlots := len(cache.freeSlots)
|
||||
|
||||
// Simulate overwriting seq 2's slot with slot1 (sharing)
|
||||
// First free the old slot
|
||||
cache.refCount[slot2]--
|
||||
if cache.refCount[slot2] <= 0 {
|
||||
cache.refCount[slot2] = 0
|
||||
cache.freeSlot(slot2)
|
||||
}
|
||||
// Then share slot1
|
||||
cache.slotForSeq[2] = slot1
|
||||
cache.refCount[slot1]++
|
||||
|
||||
// Seq 2 should now share slot1
|
||||
if cache.slotForSeq[2] != slot1 {
|
||||
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
|
||||
}
|
||||
|
||||
// Old slot2 should be freed
|
||||
if len(cache.freeSlots) != initialFreeSlots+1 {
|
||||
t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_BoundsChecking(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Test freeing invalid slot (should not panic)
|
||||
cache.freeSlot(-1)
|
||||
cache.freeSlot(100) // out of bounds
|
||||
|
||||
// freeSlot does bounds checking, so invalid slots should be ignored
|
||||
if len(cache.freeSlots) != 4 {
|
||||
t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) {
|
||||
cache := createSlotOnlyCache(8)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Fork to seq 2, 3, 4 (all share slot1)
|
||||
for _, seq := range []int{2, 3, 4} {
|
||||
cache.slotForSeq[seq] = slot1
|
||||
cache.refCount[slot1]++
|
||||
}
|
||||
|
||||
// Ref count should be 4
|
||||
if cache.refCount[slot1] != 4 {
|
||||
t.Errorf("expected refCount 4, got %d", cache.refCount[slot1])
|
||||
}
|
||||
|
||||
// Remove seq 2, 3
|
||||
for _, seq := range []int{2, 3} {
|
||||
delete(cache.slotForSeq, seq)
|
||||
cache.refCount[slot1]--
|
||||
}
|
||||
|
||||
if cache.refCount[slot1] != 2 {
|
||||
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
|
||||
}
|
||||
|
||||
// Slot should still be allocated (not in free list)
|
||||
found := false
|
||||
for _, s := range cache.freeSlots {
|
||||
if s == slot1 {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
t.Error("slot1 should not be in free list yet")
|
||||
}
|
||||
|
||||
// Remove remaining sequences
|
||||
for _, seq := range []int{1, 4} {
|
||||
delete(cache.slotForSeq, seq)
|
||||
cache.refCount[slot1]--
|
||||
}
|
||||
|
||||
if cache.refCount[slot1] != 0 {
|
||||
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_ChainedSharing(t *testing.T) {
|
||||
cache := createSlotOnlyCache(8)
|
||||
|
||||
// Create seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Share 1 -> 2
|
||||
cache.slotForSeq[2] = slot1
|
||||
cache.refCount[slot1]++
|
||||
|
||||
// Share 2 -> 3 (should still share slot1)
|
||||
cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1
|
||||
cache.refCount[slot1]++
|
||||
|
||||
// All should share slot1
|
||||
if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 {
|
||||
t.Error("all sequences should share slot1")
|
||||
}
|
||||
|
||||
if cache.refCount[slot1] != 3 {
|
||||
t.Errorf("expected refCount 3, got %d", cache.refCount[slot1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_CacheParameters(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5
|
||||
|
||||
if cache.hiddenSize != 512 {
|
||||
t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize)
|
||||
}
|
||||
if cache.dConv != 5 {
|
||||
t.Errorf("expected dConv 5, got %d", cache.dConv)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_NumSeqs(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Initially no sequences
|
||||
if cache.numSeqs() != 0 {
|
||||
t.Errorf("expected 0 seqs, got %d", cache.numSeqs())
|
||||
}
|
||||
|
||||
// Manually set up current batch state
|
||||
cache.curSeqs = []int{1, 2, 3}
|
||||
|
||||
if cache.numSeqs() != 3 {
|
||||
t.Errorf("expected 3 seqs, got %d", cache.numSeqs())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SeqTokens(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Initially 0
|
||||
if cache.seqTokens() != 0 {
|
||||
t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens())
|
||||
}
|
||||
|
||||
// Manually set up current batch state
|
||||
cache.curSeqTokens = 16
|
||||
|
||||
if cache.seqTokens() != 16 {
|
||||
t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens())
|
||||
}
|
||||
}
|
||||
|
||||
// Test that Seqs returns a clone of curSeqs
|
||||
func TestHybridCache_Seqs_ReturnsClone(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
cache.curSeqs = []int{1, 2, 3}
|
||||
|
||||
seqs := cache.Seqs()
|
||||
|
||||
// Modify returned slice
|
||||
seqs[0] = 999
|
||||
|
||||
// Original should be unchanged
|
||||
if cache.curSeqs[0] != 1 {
|
||||
t.Error("Seqs should return a clone, not the original slice")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_IsSupportedForBatch(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Initially not supported (no batch set up)
|
||||
if cache.IsSupportedForBatch() {
|
||||
t.Error("expected IsSupportedForBatch to be false initially")
|
||||
}
|
||||
|
||||
// Set up a valid batch
|
||||
cache.curSeqTokens = 1
|
||||
cache.curSeqs = []int{1}
|
||||
|
||||
if !cache.IsSupportedForBatch() {
|
||||
t.Error("expected IsSupportedForBatch to be true with valid batch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// zeroConvSlots should handle empty slots without panicking
|
||||
cache.zeroConvSlots(nil, nil)
|
||||
cache.zeroConvSlots(nil, []int{})
|
||||
|
||||
// zeroConvSlots should handle empty convStates without panicking
|
||||
cache.zeroConvSlots(nil, []int{0, 1, 2})
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Free the slot (simulating sequence removal)
|
||||
cache.refCount[slot1]--
|
||||
cache.freeSlot(slot1)
|
||||
delete(cache.slotForSeq, 1)
|
||||
|
||||
// Verify slot is in free list
|
||||
if len(cache.freeSlots) != 4 {
|
||||
t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots))
|
||||
}
|
||||
|
||||
// Allocate for new seq 2 - should get recycled slot
|
||||
slot2, _ := cache.allocSlot()
|
||||
if slot2 != slot1 {
|
||||
t.Errorf("expected recycled slot %d, got %d", slot1, slot2)
|
||||
}
|
||||
|
||||
// This recycled slot would need zeroing in the real implementation
|
||||
// The actual zeroing is tested via integration tests since it requires ML context
|
||||
}
|
||||
|
||||
func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Simulate the slot allocation flow from StartForward
|
||||
// When a sequence doesn't have a slot, it gets allocated and tracked as "new"
|
||||
|
||||
newSlots := []int{}
|
||||
|
||||
// Seq 1 doesn't have a slot - allocate and track
|
||||
seq := 1
|
||||
if _, ok := cache.slotForSeq[seq]; !ok {
|
||||
slot, err := cache.allocSlot()
|
||||
if err != nil {
|
||||
t.Fatalf("allocSlot failed: %v", err)
|
||||
}
|
||||
cache.slotForSeq[seq] = slot
|
||||
cache.refCount[slot] = 1
|
||||
newSlots = append(newSlots, slot)
|
||||
}
|
||||
|
||||
// Verify newSlots contains the allocated slot
|
||||
if len(newSlots) != 1 {
|
||||
t.Errorf("expected 1 new slot, got %d", len(newSlots))
|
||||
}
|
||||
|
||||
// Seq 1 already has a slot - should NOT be tracked as new
|
||||
newSlots2 := []int{}
|
||||
if _, ok := cache.slotForSeq[seq]; !ok {
|
||||
slot, _ := cache.allocSlot()
|
||||
cache.slotForSeq[seq] = slot
|
||||
cache.refCount[slot] = 1
|
||||
newSlots2 = append(newSlots2, slot)
|
||||
}
|
||||
|
||||
// Verify no new slots for existing sequence
|
||||
if len(newSlots2) != 0 {
|
||||
t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2))
|
||||
t.Fatal("expected unsupported batch layout before StartForward")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package lfm2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
@@ -25,8 +29,20 @@ type Options struct {
|
||||
// per-layer head counts (LFM2 alternates attention and recurrent layers)
|
||||
numHeadsByLayer []int
|
||||
numKVHeadsByLayer []int
|
||||
|
||||
// MoE config
|
||||
numExperts int
|
||||
numExpertsUsed int
|
||||
normTopKProb bool
|
||||
expertWeightsScale float32
|
||||
expertGatingFunc uint32
|
||||
}
|
||||
|
||||
const (
|
||||
expertGatingFuncSoftmax = uint32(0)
|
||||
expertGatingFuncSigmoid = uint32(2)
|
||||
)
|
||||
|
||||
func (o Options) headDimValue() int {
|
||||
// Head dim is shared across layers; fall back to first attention layer head count.
|
||||
for _, h := range o.numHeadsByLayer {
|
||||
@@ -67,18 +83,138 @@ type Model struct {
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
VisionModel *VisionModel `gguf:"v"`
|
||||
VisionProjector *VisionProjector `gguf:"mm"`
|
||||
ImageProcessor ImageProcessor
|
||||
imageTokenID int32
|
||||
imageStartToken int32
|
||||
imageEndToken int32
|
||||
imageThumbnailID int32
|
||||
imageRowColIDs map[imageGridPos]int32
|
||||
useSpecialTokens bool
|
||||
projectorOptions VisionProjectorOptions
|
||||
|
||||
Options
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
if c.Uint("expert_count") > 0 {
|
||||
return nil, model.ErrUnsupportedModel
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
type imageGridPos struct {
|
||||
row int
|
||||
col int
|
||||
}
|
||||
|
||||
type visionEmbeddingLayout struct {
|
||||
rows int
|
||||
cols int
|
||||
hasThumbnail bool
|
||||
}
|
||||
|
||||
type visionChunkData struct {
|
||||
tokens int
|
||||
row int
|
||||
col int
|
||||
thumbnail bool
|
||||
layout *visionEmbeddingLayout
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
if m.TokenEmbedding == nil {
|
||||
return errors.New("lfm2: missing token_embd tensor")
|
||||
}
|
||||
if m.OutputNorm == nil {
|
||||
return errors.New("lfm2: missing output_norm tensor")
|
||||
}
|
||||
if m.Output == nil {
|
||||
return errors.New("lfm2: missing output tensor")
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
if layer.AttentionNorm == nil {
|
||||
return fmt.Errorf("lfm2: missing blk.%d.attn_norm tensor", i)
|
||||
}
|
||||
if layer.MLPNorm == nil {
|
||||
return fmt.Errorf("lfm2: missing blk.%d.ffn_norm tensor", i)
|
||||
}
|
||||
switch ff := layer.MLP.(type) {
|
||||
case nil:
|
||||
return fmt.Errorf("lfm2: missing blk.%d feed-forward tensors", i)
|
||||
case *denseMLP:
|
||||
if ff.Up == nil || ff.Down == nil || ff.Gate == nil {
|
||||
return fmt.Errorf("lfm2: missing blk.%d dense feed-forward tensors", i)
|
||||
}
|
||||
case *sparseMLP:
|
||||
if ff.Router == nil || ff.Gate == nil || ff.Up == nil || ff.Down == nil {
|
||||
return fmt.Errorf("lfm2: missing blk.%d sparse feed-forward tensors", i)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("lfm2: unsupported feed-forward type at blk.%d", i)
|
||||
}
|
||||
|
||||
switch op := layer.Operator.(type) {
|
||||
case *Attention:
|
||||
if op == nil || op.Query == nil || op.Key == nil || op.Value == nil || op.Output == nil || op.QueryNorm == nil || op.KeyNorm == nil {
|
||||
return fmt.Errorf("lfm2: missing blk.%d attention tensors", i)
|
||||
}
|
||||
case *ShortConv:
|
||||
if op == nil || op.Conv == nil || op.Conv.Weight == nil || op.InProj == nil || op.OutProj == nil {
|
||||
return fmt.Errorf("lfm2: missing blk.%d shortconv tensors", i)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("lfm2: unsupported operator at blk.%d", i)
|
||||
}
|
||||
}
|
||||
|
||||
if m.VisionModel != nil {
|
||||
if m.VisionModel.PatchEmbedding == nil {
|
||||
return errors.New("lfm2: missing vision patch embedding tensors")
|
||||
}
|
||||
if m.VisionModel.PositionEmbedding == nil {
|
||||
return errors.New("lfm2: missing vision position embedding tensors")
|
||||
}
|
||||
if m.VisionModel.PostLayerNorm == nil {
|
||||
return errors.New("lfm2: missing vision post layer norm tensors")
|
||||
}
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return errors.New("lfm2: missing vision encoder layers")
|
||||
}
|
||||
for i, layer := range m.VisionModel.Layers {
|
||||
if layer.LayerNorm1 == nil || layer.LayerNorm2 == nil || layer.SelfAttention == nil || layer.MLP == nil {
|
||||
return fmt.Errorf("lfm2: missing vision layer tensors at v.blk.%d", i)
|
||||
}
|
||||
if layer.SelfAttention.Query == nil || layer.SelfAttention.Key == nil || layer.SelfAttention.Value == nil || layer.SelfAttention.Output == nil {
|
||||
return fmt.Errorf("lfm2: missing vision attention tensors at v.blk.%d", i)
|
||||
}
|
||||
if layer.MLP.Up == nil || layer.MLP.Down == nil {
|
||||
return fmt.Errorf("lfm2: missing vision feed-forward tensors at v.blk.%d", i)
|
||||
}
|
||||
}
|
||||
|
||||
if m.VisionProjector == nil || m.VisionProjector.Linear1 == nil || m.VisionProjector.Linear2 == nil {
|
||||
return errors.New("lfm2: missing multimodal projector tensors")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
if c.String("tokenizer.ggml.model") != "gpt2" {
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
numExperts := int(c.Uint("expert_count"))
|
||||
isMoE := numExperts > 0
|
||||
numExpertsUsed := int(c.Uint("expert_used_count"))
|
||||
if isMoE {
|
||||
if numExperts <= 0 {
|
||||
return nil, fmt.Errorf("lfm2: invalid expert_count=%d", numExperts)
|
||||
}
|
||||
if numExpertsUsed <= 0 || numExpertsUsed > numExperts {
|
||||
return nil, fmt.Errorf("lfm2: invalid expert_used_count=%d for expert_count=%d", numExpertsUsed, numExperts)
|
||||
}
|
||||
}
|
||||
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
@@ -105,8 +241,16 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
VisionProjector: &VisionProjector{},
|
||||
imageRowColIDs: make(map[imageGridPos]int32),
|
||||
projectorOptions: VisionProjectorOptions{
|
||||
scaleFactor: int(c.Uint("vision.projector.scale_factor", 2)),
|
||||
useLayerNorm: c.Bool("vision.projector.use_layernorm", false),
|
||||
},
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
@@ -116,9 +260,66 @@ func New(c fs.Config) (model.Model, error) {
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||
numExperts: numExperts,
|
||||
numExpertsUsed: numExpertsUsed,
|
||||
normTopKProb: c.Bool("norm_top_k_prob", true),
|
||||
expertWeightsScale: c.Float("expert_weights_scale", 1.0),
|
||||
expertGatingFunc: c.Uint("expert_gating_func", expertGatingFuncSoftmax),
|
||||
},
|
||||
}
|
||||
|
||||
lookupTokenID := func(token string) int32 {
|
||||
for i, t := range vocabulary.Values {
|
||||
if t == token {
|
||||
return int32(i)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
resolveTokenID := func(explicitKey, token string, fallback uint32) int32 {
|
||||
if explicitKey != "" {
|
||||
if id := c.Uint(explicitKey); id != 0 {
|
||||
return int32(id)
|
||||
}
|
||||
}
|
||||
if tokenID := lookupTokenID(token); tokenID != 0 {
|
||||
return tokenID
|
||||
}
|
||||
return int32(fallback)
|
||||
}
|
||||
|
||||
m.imageTokenID = resolveTokenID("vision.image_token_id", "<image>", 396)
|
||||
m.imageStartToken = resolveTokenID("vision.image_start_token_id", "<|image_start|>", 0)
|
||||
m.imageEndToken = resolveTokenID("vision.image_end_token_id", "<|image_end|>", 0)
|
||||
m.imageThumbnailID = resolveTokenID("vision.image_thumbnail_token_id", "<|img_thumbnail|>", 0)
|
||||
m.useSpecialTokens = c.Bool("vision.use_image_special_tokens", true)
|
||||
|
||||
maxGridTokens := int(c.Uint("vision.max_tiles", 10))
|
||||
if maxGridTokens <= 0 {
|
||||
maxGridTokens = 10
|
||||
}
|
||||
for row := 1; row <= maxGridTokens; row++ {
|
||||
for col := 1; col <= maxGridTokens; col++ {
|
||||
token := fmt.Sprintf("<|img_row_%d_col_%d|>", row, col)
|
||||
if tokenID := lookupTokenID(token); tokenID > 0 {
|
||||
m.imageRowColIDs[imageGridPos{row: row, col: col}] = tokenID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !m.useSpecialTokens {
|
||||
m.imageStartToken = 0
|
||||
m.imageEndToken = 0
|
||||
m.imageThumbnailID = 0
|
||||
m.imageRowColIDs = map[imageGridPos]int32{}
|
||||
}
|
||||
|
||||
if c.Uint("vision.block_count") == 0 {
|
||||
m.VisionModel = nil
|
||||
m.VisionProjector = nil
|
||||
}
|
||||
|
||||
type headCounts interface {
|
||||
HeadCount() []uint64
|
||||
HeadCountKV() []uint64
|
||||
@@ -133,6 +334,14 @@ func New(c fs.Config) (model.Model, error) {
|
||||
|
||||
m.numHeadsByLayer = make([]int, len(m.Layers))
|
||||
m.numKVHeadsByLayer = make([]int, len(m.Layers))
|
||||
leadingDenseBlockCount := int(c.Uint("leading_dense_block_count"))
|
||||
if leadingDenseBlockCount < 0 {
|
||||
leadingDenseBlockCount = 0
|
||||
}
|
||||
if leadingDenseBlockCount > len(m.Layers) {
|
||||
leadingDenseBlockCount = len(m.Layers)
|
||||
}
|
||||
|
||||
for i := range m.Layers {
|
||||
m.numHeadsByLayer[i] = int(headCount[i])
|
||||
m.numKVHeadsByLayer[i] = int(headCountKV[i])
|
||||
@@ -142,6 +351,12 @@ func New(c fs.Config) (model.Model, error) {
|
||||
} else {
|
||||
m.Layers[i].Operator = &Attention{}
|
||||
}
|
||||
|
||||
if isMoE && i >= leadingDenseBlockCount {
|
||||
m.Layers[i].MLP = &sparseMLP{}
|
||||
} else {
|
||||
m.Layers[i].MLP = &denseMLP{}
|
||||
}
|
||||
}
|
||||
|
||||
lCache := int(c.Uint("shortconv.l_cache"))
|
||||
@@ -188,22 +403,77 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
type FeedForward interface {
|
||||
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
|
||||
}
|
||||
|
||||
type denseMLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
|
||||
func (mlp *denseMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
type sparseMLP struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
Bias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
|
||||
}
|
||||
|
||||
func (mlp *sparseMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
|
||||
// hiddenState: [hidden, tokens]
|
||||
routerLogits := mlp.Router.Forward(ctx, hiddenState)
|
||||
|
||||
probs := routerLogits.Softmax(ctx)
|
||||
if opts.expertGatingFunc == expertGatingFuncSigmoid {
|
||||
probs = routerLogits.Sigmoid(ctx)
|
||||
}
|
||||
|
||||
selectionProbs := probs
|
||||
if mlp.Bias != nil {
|
||||
selectionProbs = selectionProbs.Add(ctx, mlp.Bias)
|
||||
}
|
||||
|
||||
selectedExperts := selectionProbs.TopK(ctx, opts.numExpertsUsed)
|
||||
routingWeights := probs.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
|
||||
if opts.normTopKProb {
|
||||
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
|
||||
weightsSum := routingWeights.SumRows(ctx)
|
||||
weightsSum = weightsSum.Clamp(ctx, 1e-6, float32(math.Inf(1)))
|
||||
routingWeights = routingWeights.Div(ctx, weightsSum)
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))
|
||||
}
|
||||
if opts.expertWeightsScale != 1 {
|
||||
routingWeights = routingWeights.Scale(ctx, float64(opts.expertWeightsScale))
|
||||
}
|
||||
|
||||
// Build routing-weights branch early to enable topk-MoE fusion.
|
||||
ctx.Forward(routingWeights)
|
||||
|
||||
hiddenState3D := hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))
|
||||
experts := mlp.Gate.Forward(ctx, hiddenState3D, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenState3D, selectedExperts))
|
||||
experts = mlp.Down.Forward(ctx, experts, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
nextState := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
nextState = nextState.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
|
||||
return nextState
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
Operator Operator
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *MLP
|
||||
MLP FeedForward
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor {
|
||||
@@ -229,10 +499,233 @@ func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tenso
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
func multimodalTokenCount(mm input.Multimodal) int {
|
||||
if mm.Tensor != nil {
|
||||
return mm.Tensor.Dim(1)
|
||||
}
|
||||
|
||||
switch data := mm.Data.(type) {
|
||||
case int:
|
||||
return data
|
||||
case int32:
|
||||
return int(data)
|
||||
case visionChunkData:
|
||||
return data.tokens
|
||||
case *visionChunkData:
|
||||
if data != nil {
|
||||
return data.tokens
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func multimodalChunkInfo(mm input.Multimodal) visionChunkData {
|
||||
switch data := mm.Data.(type) {
|
||||
case visionChunkData:
|
||||
return data
|
||||
case *visionChunkData:
|
||||
if data != nil {
|
||||
return *data
|
||||
}
|
||||
}
|
||||
|
||||
return visionChunkData{
|
||||
tokens: multimodalTokenCount(mm),
|
||||
}
|
||||
}
|
||||
|
||||
func multimodalLayout(mm []input.Multimodal) visionEmbeddingLayout {
|
||||
layout := visionEmbeddingLayout{rows: 1, cols: 1}
|
||||
if len(mm) == 0 {
|
||||
return layout
|
||||
}
|
||||
|
||||
first := multimodalChunkInfo(mm[0])
|
||||
if first.layout != nil {
|
||||
return *first.layout
|
||||
}
|
||||
|
||||
return layout
|
||||
}
|
||||
|
||||
func (m *Model) imageRowColToken(row, col int) int32 {
|
||||
if row <= 0 || col <= 0 {
|
||||
return 0
|
||||
}
|
||||
return m.imageRowColIDs[imageGridPos{row: row, col: col}]
|
||||
}
|
||||
|
||||
func (m *Model) appendImageChunk(result []*input.Input, chunk input.Multimodal, imageToken int32, hash uint64) ([]*input.Input, error) {
|
||||
tokenCount := multimodalTokenCount(chunk)
|
||||
if tokenCount <= 0 {
|
||||
return nil, errors.New("lfm2: multimodal input has no tokens")
|
||||
}
|
||||
|
||||
result = append(result, &input.Input{
|
||||
Token: imageToken,
|
||||
Multimodal: []input.Multimodal{chunk},
|
||||
MultimodalHash: hash,
|
||||
SameBatch: tokenCount - 1,
|
||||
})
|
||||
|
||||
for range tokenCount - 1 {
|
||||
result = append(result, &input.Input{Token: imageToken})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
if m.VisionModel == nil || m.VisionProjector == nil || len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
processedImages, layout, err := m.ImageProcessor.ProcessImage(img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if m.ImageProcessor.patchSize <= 0 {
|
||||
return nil, errors.New("lfm2: invalid vision patch size")
|
||||
}
|
||||
|
||||
layoutInfo := &visionEmbeddingLayout{
|
||||
rows: layout.rows,
|
||||
cols: layout.cols,
|
||||
hasThumbnail: layout.hasThumbnail,
|
||||
}
|
||||
|
||||
mm := make([]input.Multimodal, 0, len(processedImages))
|
||||
for i, processed := range processedImages {
|
||||
patches := visionPatchGrid{
|
||||
Width: processed.size.X / m.ImageProcessor.patchSize,
|
||||
Height: processed.size.Y / m.ImageProcessor.patchSize,
|
||||
}
|
||||
if patches.Width == 0 || patches.Height == 0 {
|
||||
return nil, errors.New("lfm2: invalid resized image dimensions")
|
||||
}
|
||||
|
||||
pixelValues := ctx.Input().FromFloats(processed.data, processed.size.X, processed.size.Y, m.ImageProcessor.numChannels)
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, patches)
|
||||
projected := m.VisionProjector.Forward(ctx, visionOutputs, patches, m.projectorOptions)
|
||||
|
||||
chunk := visionChunkData{
|
||||
tokens: projected.Dim(1),
|
||||
row: processed.row,
|
||||
col: processed.col,
|
||||
thumbnail: processed.thumbnail,
|
||||
}
|
||||
if i == 0 {
|
||||
chunk.layout = layoutInfo
|
||||
}
|
||||
|
||||
mm = append(mm, input.Multimodal{
|
||||
Tensor: projected,
|
||||
Data: chunk,
|
||||
})
|
||||
}
|
||||
|
||||
return mm, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
imageToken := m.imageTokenID
|
||||
if imageToken == 0 {
|
||||
imageToken = 396
|
||||
}
|
||||
useSpecialTokens := m.useSpecialTokens || m.imageStartToken > 0 || m.imageEndToken > 0 || m.imageThumbnailID > 0 || len(m.imageRowColIDs) > 0
|
||||
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
continue
|
||||
}
|
||||
|
||||
layout := multimodalLayout(inp.Multimodal)
|
||||
if layout.rows <= 0 {
|
||||
layout.rows = 1
|
||||
}
|
||||
if layout.cols <= 0 {
|
||||
layout.cols = 1
|
||||
}
|
||||
tiles := layout.rows * layout.cols
|
||||
multitile := tiles > 1
|
||||
|
||||
if useSpecialTokens && m.imageStartToken > 0 {
|
||||
result = append(result, &input.Input{Token: m.imageStartToken})
|
||||
}
|
||||
|
||||
for i, mm := range inp.Multimodal {
|
||||
chunk := multimodalChunkInfo(mm)
|
||||
if chunk.tokens <= 0 {
|
||||
chunk.tokens = multimodalTokenCount(mm)
|
||||
}
|
||||
|
||||
if multitile && !chunk.thumbnail && chunk.row == 0 && chunk.col == 0 && i < tiles {
|
||||
chunk.row = i/layout.cols + 1
|
||||
chunk.col = i%layout.cols + 1
|
||||
}
|
||||
if multitile && layout.hasThumbnail && i == tiles {
|
||||
chunk.thumbnail = true
|
||||
}
|
||||
|
||||
if useSpecialTokens && multitile {
|
||||
if chunk.thumbnail {
|
||||
if m.imageThumbnailID > 0 {
|
||||
result = append(result, &input.Input{Token: m.imageThumbnailID})
|
||||
}
|
||||
} else if marker := m.imageRowColToken(chunk.row, chunk.col); marker > 0 {
|
||||
result = append(result, &input.Input{Token: marker})
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
result, err = m.appendImageChunk(result, input.Multimodal{
|
||||
Tensor: mm.Tensor,
|
||||
Data: chunk,
|
||||
}, imageToken, inp.MultimodalHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if useSpecialTokens && m.imageEndToken > 0 {
|
||||
result = append(result, &input.Input{Token: m.imageEndToken})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
if len(batch.Multimodal) > 0 {
|
||||
// We splice vision embeddings into token embeddings in-place; duplicate to
|
||||
// avoid aliasing the raw embedding output graph.
|
||||
hiddenState = hiddenState.Duplicate(ctx)
|
||||
}
|
||||
for _, mm := range batch.Multimodal {
|
||||
offset := mm.Index
|
||||
for _, multimodal := range mm.Multimodal {
|
||||
if multimodal.Tensor == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
visionOutputs := multimodal.Tensor
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, offset*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
offset += visionOutputs.Dim(1)
|
||||
}
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
@@ -251,4 +744,5 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
func init() {
|
||||
model.Register("lfm2", New)
|
||||
model.Register("lfm2moe", New)
|
||||
}
|
||||
|
||||
160
model/models/lfm2/model_multimodal_test.go
Normal file
160
model/models/lfm2/model_multimodal_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package lfm2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
func TestPostTokenizeWithSpecialImageTokens(t *testing.T) {
|
||||
m := &Model{
|
||||
imageTokenID: 396,
|
||||
imageStartToken: 2,
|
||||
imageEndToken: 3,
|
||||
useSpecialTokens: true,
|
||||
}
|
||||
|
||||
in := []*input.Input{
|
||||
{Token: 11},
|
||||
{Multimodal: []input.Multimodal{{Data: 64}}, MultimodalHash: 123},
|
||||
{Token: 12},
|
||||
}
|
||||
|
||||
out, err := m.PostTokenize(in)
|
||||
if err != nil {
|
||||
t.Fatalf("PostTokenize returned error: %v", err)
|
||||
}
|
||||
|
||||
if len(out) != 68 {
|
||||
t.Fatalf("expected 68 tokens, got %d", len(out))
|
||||
}
|
||||
|
||||
if out[0].Token != 11 {
|
||||
t.Fatalf("out[0].Token = %d, want 11", out[0].Token)
|
||||
}
|
||||
if out[1].Token != 2 {
|
||||
t.Fatalf("out[1].Token = %d, want 2", out[1].Token)
|
||||
}
|
||||
|
||||
firstImage := out[2]
|
||||
if firstImage.Token != 396 {
|
||||
t.Fatalf("out[2].Token = %d, want 396", firstImage.Token)
|
||||
}
|
||||
if len(firstImage.Multimodal) != 1 {
|
||||
t.Fatalf("expected multimodal payload on first image token")
|
||||
}
|
||||
if firstImage.MultimodalHash != 123 {
|
||||
t.Fatalf("out[2].MultimodalHash = %d, want 123", firstImage.MultimodalHash)
|
||||
}
|
||||
if firstImage.SameBatch != 63 {
|
||||
t.Fatalf("out[2].SameBatch = %d, want 63", firstImage.SameBatch)
|
||||
}
|
||||
|
||||
for i := 3; i < 66; i++ {
|
||||
if out[i].Token != 396 {
|
||||
t.Fatalf("out[%d].Token = %d, want 396", i, out[i].Token)
|
||||
}
|
||||
if len(out[i].Multimodal) != 0 {
|
||||
t.Fatalf("out[%d] should not carry multimodal payload", i)
|
||||
}
|
||||
}
|
||||
|
||||
if out[66].Token != 3 {
|
||||
t.Fatalf("out[66].Token = %d, want 3", out[66].Token)
|
||||
}
|
||||
if out[67].Token != 12 {
|
||||
t.Fatalf("out[67].Token = %d, want 12", out[67].Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostTokenizeWithoutSpecialImageTokens(t *testing.T) {
|
||||
m := &Model{
|
||||
imageTokenID: 777,
|
||||
useSpecialTokens: false,
|
||||
}
|
||||
|
||||
in := []*input.Input{
|
||||
{Multimodal: []input.Multimodal{{Data: 5}}, MultimodalHash: 9},
|
||||
}
|
||||
|
||||
out, err := m.PostTokenize(in)
|
||||
if err != nil {
|
||||
t.Fatalf("PostTokenize returned error: %v", err)
|
||||
}
|
||||
|
||||
if len(out) != 5 {
|
||||
t.Fatalf("expected 5 tokens, got %d", len(out))
|
||||
}
|
||||
if out[0].Token != 777 || out[0].SameBatch != 4 || len(out[0].Multimodal) != 1 {
|
||||
t.Fatalf("unexpected first token: %+v", *out[0])
|
||||
}
|
||||
for i := 1; i < 5; i++ {
|
||||
if out[i].Token != 777 {
|
||||
t.Fatalf("out[%d].Token = %d, want 777", i, out[i].Token)
|
||||
}
|
||||
if len(out[i].Multimodal) != 0 {
|
||||
t.Fatalf("out[%d] should not carry multimodal payload", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostTokenizeMultiTileLayoutTokens(t *testing.T) {
|
||||
m := &Model{
|
||||
imageTokenID: 396,
|
||||
imageStartToken: 498,
|
||||
imageEndToken: 499,
|
||||
imageThumbnailID: 497,
|
||||
imageRowColIDs: map[imageGridPos]int32{
|
||||
{row: 1, col: 1}: 397,
|
||||
{row: 1, col: 2}: 398,
|
||||
},
|
||||
useSpecialTokens: true,
|
||||
}
|
||||
|
||||
layout := &visionEmbeddingLayout{rows: 1, cols: 2, hasThumbnail: true}
|
||||
in := []*input.Input{{
|
||||
Multimodal: []input.Multimodal{
|
||||
{Data: visionChunkData{tokens: 3, row: 1, col: 1, layout: layout}},
|
||||
{Data: visionChunkData{tokens: 3, row: 1, col: 2}},
|
||||
{Data: visionChunkData{tokens: 2, thumbnail: true}},
|
||||
},
|
||||
MultimodalHash: 1,
|
||||
}}
|
||||
|
||||
out, err := m.PostTokenize(in)
|
||||
if err != nil {
|
||||
t.Fatalf("PostTokenize returned error: %v", err)
|
||||
}
|
||||
|
||||
got := make([]int32, len(out))
|
||||
for i := range out {
|
||||
got[i] = out[i].Token
|
||||
}
|
||||
|
||||
want := []int32{
|
||||
498, // <|image_start|>
|
||||
397, // <|img_row_1_col_1|>
|
||||
396, 396, 396,
|
||||
398, // <|img_row_1_col_2|>
|
||||
396, 396, 396,
|
||||
497, // <|img_thumbnail|>
|
||||
396, 396,
|
||||
499, // <|image_end|>
|
||||
}
|
||||
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("len(out) = %d, want %d", len(got), len(want))
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("out[%d].Token = %d, want %d", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(out[2].Multimodal) != 1 || len(out[6].Multimodal) != 1 || len(out[10].Multimodal) != 1 {
|
||||
t.Fatalf("expected multimodal payload on first token of each chunk")
|
||||
}
|
||||
if out[2].SameBatch != 2 || out[6].SameBatch != 2 || out[10].SameBatch != 1 {
|
||||
t.Fatalf("unexpected SameBatch values: [%d %d %d]", out[2].SameBatch, out[6].SameBatch, out[10].SameBatch)
|
||||
}
|
||||
}
|
||||
184
model/models/lfm2/model_vision.go
Normal file
184
model/models/lfm2/model_vision.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package lfm2
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
const lfm2VisionBatchSize = 1
|
||||
|
||||
type visionPatchGrid struct {
|
||||
Width int
|
||||
Height int
|
||||
}
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output,alt:attn_out"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), lfm2VisionBatchSize)
|
||||
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), lfm2VisionBatchSize)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), lfm2VisionBatchSize)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), lfm2VisionBatchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
return mlp.Down.Forward(ctx, mlp.Up.Forward(ctx, hiddenState).GELU(ctx))
|
||||
}
|
||||
|
||||
type VisionEncoderLayer struct {
|
||||
LayerNorm1 *nn.LayerNorm `gguf:"ln1"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
|
||||
LayerNorm2 *nn.LayerNorm `gguf:"ln2"`
|
||||
MLP *VisionMLP
|
||||
}
|
||||
|
||||
func (l *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, opts)
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
|
||||
residual = hiddenState
|
||||
hiddenState = l.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.MLP.Forward(ctx, hiddenState)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
type VisionModelOptions struct {
|
||||
hiddenSize, numHeads int
|
||||
imageSize, patchSize int
|
||||
eps float32
|
||||
}
|
||||
|
||||
type VisionModel struct {
|
||||
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
|
||||
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
|
||||
PostLayerNorm *nn.LayerNorm `gguf:"post_ln"`
|
||||
|
||||
Layers []VisionEncoderLayer `gguf:"blk"`
|
||||
|
||||
*VisionModelOptions
|
||||
}
|
||||
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, patches visionPatchGrid) ml.Tensor {
|
||||
numPatches := patches.Width * patches.Height
|
||||
|
||||
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
||||
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
if m.PositionEmbedding != nil {
|
||||
posTokens := m.PositionEmbedding.Weight.Dim(1)
|
||||
source := int(math.Sqrt(float64(posTokens)))
|
||||
|
||||
var positionEmbeddings ml.Tensor
|
||||
if source > 0 && source*source == posTokens && (source != patches.Width || source != patches.Height) {
|
||||
// SigLIP2 NAFlex-style position interpolation for variable image sizes.
|
||||
positionIDs := ctx.Arange(0, float32(posTokens), 1, ml.DTypeI32)
|
||||
positionEmbeddings = m.PositionEmbedding.Forward(ctx, positionIDs)
|
||||
positionEmbeddings = positionEmbeddings.Reshape(ctx, -1, source, source)
|
||||
positionEmbeddings = positionEmbeddings.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
positionEmbeddings = positionEmbeddings.Interpolate(ctx, [4]int{
|
||||
patches.Width,
|
||||
patches.Height,
|
||||
hiddenState.Dim(0),
|
||||
1,
|
||||
}, ml.SamplingModeBilinear)
|
||||
positionEmbeddings = positionEmbeddings.Permute(ctx, 1, 2, 0, 3)
|
||||
positionEmbeddings = positionEmbeddings.Contiguous(ctx, -1, patches.Width*patches.Height)
|
||||
} else {
|
||||
positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeI32)
|
||||
positionEmbeddings = m.PositionEmbedding.Forward(ctx, positionIDs)
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, positionEmbeddings)
|
||||
}
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
|
||||
}
|
||||
|
||||
return m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
|
||||
}
|
||||
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
return &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: int(c.Uint("vision.embedding_length", 1152)),
|
||||
numHeads: int(c.Uint("vision.attention.head_count", 16)),
|
||||
imageSize: int(c.Uint("vision.image_size", 256)),
|
||||
patchSize: int(c.Uint("vision.patch_size", 16)),
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type VisionProjector struct {
|
||||
LayerNorm *nn.LayerNorm `gguf:"layer_norm"`
|
||||
Linear1 *nn.Linear `gguf:"1"`
|
||||
Linear2 *nn.Linear `gguf:"2"`
|
||||
}
|
||||
|
||||
type VisionProjectorOptions struct {
|
||||
scaleFactor int
|
||||
useLayerNorm bool
|
||||
}
|
||||
|
||||
func (p *VisionProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, patches visionPatchGrid, opts VisionProjectorOptions) ml.Tensor {
|
||||
hiddenSize := visionOutputs.Dim(0)
|
||||
featureMap := visionOutputs
|
||||
|
||||
merge := max(opts.scaleFactor, 1)
|
||||
if merge > 1 {
|
||||
width := patches.Width
|
||||
height := patches.Height
|
||||
|
||||
featureMap = featureMap.Reshape(ctx, hiddenSize, width, height)
|
||||
|
||||
// Match llama.cpp patch merger: pad spatial dims to merge factor.
|
||||
padWidth := (merge - width%merge) % merge
|
||||
padHeight := (merge - height%merge) % merge
|
||||
if padWidth != 0 || padHeight != 0 {
|
||||
featureMap = featureMap.Pad(ctx, 0, padWidth, padHeight, 0)
|
||||
width += padWidth
|
||||
height += padHeight
|
||||
}
|
||||
|
||||
featureMap = featureMap.Reshape(ctx, hiddenSize*merge, width/merge, height)
|
||||
featureMap = featureMap.Permute(ctx, 0, 2, 1).Contiguous(ctx, hiddenSize*merge*merge, height/merge, width/merge)
|
||||
featureMap = featureMap.Permute(ctx, 0, 2, 1).Contiguous(ctx)
|
||||
featureMap = featureMap.Contiguous(ctx, featureMap.Dim(0), featureMap.Dim(1)*featureMap.Dim(2))
|
||||
}
|
||||
|
||||
if opts.useLayerNorm && p.LayerNorm != nil {
|
||||
featureMap = p.LayerNorm.Forward(ctx, featureMap, 1e-5)
|
||||
}
|
||||
|
||||
featureMap = p.Linear1.Forward(ctx, featureMap).GELU(ctx)
|
||||
return p.Linear2.Forward(ctx, featureMap)
|
||||
}
|
||||
260
model/models/lfm2/process_image.go
Normal file
260
model/models/lfm2/process_image.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package lfm2
|
||||
|
||||
import (
|
||||
"image"
|
||||
stdimage "image/draw"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
type ImageProcessor struct {
|
||||
imageSize, patchSize, numChannels int
|
||||
downsampleFactor int
|
||||
imageMean, imageStd [3]float32
|
||||
|
||||
doImageSplitting bool
|
||||
minTiles int
|
||||
maxTiles int
|
||||
useThumbnail bool
|
||||
tileSize int
|
||||
|
||||
minImageTokens int
|
||||
maxImageTokens int
|
||||
maxPixelsTolerance float64
|
||||
}
|
||||
|
||||
type processedVisionImage struct {
|
||||
data []float32
|
||||
size image.Point
|
||||
row int
|
||||
col int
|
||||
thumbnail bool
|
||||
}
|
||||
|
||||
type processedVisionLayout struct {
|
||||
rows int
|
||||
cols int
|
||||
hasThumbnail bool
|
||||
}
|
||||
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
mean := c.Floats("vision.image_mean")
|
||||
std := c.Floats("vision.image_std")
|
||||
|
||||
processor := ImageProcessor{
|
||||
imageSize: int(c.Uint("vision.image_size", 256)),
|
||||
patchSize: int(c.Uint("vision.patch_size", 16)),
|
||||
numChannels: int(c.Uint("vision.num_channels", 3)),
|
||||
downsampleFactor: int(c.Uint("vision.projector.scale_factor", 2)),
|
||||
imageMean: [3]float32{0.5, 0.5, 0.5},
|
||||
imageStd: [3]float32{0.5, 0.5, 0.5},
|
||||
doImageSplitting: c.Bool("vision.do_image_splitting", true),
|
||||
minTiles: int(c.Uint("vision.min_tiles", 2)),
|
||||
maxTiles: int(c.Uint("vision.max_tiles", 10)),
|
||||
useThumbnail: c.Bool("vision.use_thumbnail", true),
|
||||
tileSize: int(c.Uint("vision.tile_size", 512)),
|
||||
minImageTokens: int(c.Uint("vision.min_image_tokens", 64)),
|
||||
maxImageTokens: int(c.Uint("vision.max_image_tokens", 256)),
|
||||
maxPixelsTolerance: float64(c.Float("vision.max_pixels_tolerance", 2.0)),
|
||||
}
|
||||
|
||||
if len(mean) >= 3 {
|
||||
processor.imageMean = [3]float32{mean[0], mean[1], mean[2]}
|
||||
}
|
||||
if len(std) >= 3 {
|
||||
processor.imageStd = [3]float32{std[0], std[1], std[2]}
|
||||
}
|
||||
|
||||
// Keep defaults aligned with HF unless explicitly configured.
|
||||
if processor.downsampleFactor <= 0 {
|
||||
processor.downsampleFactor = 2
|
||||
}
|
||||
if processor.patchSize <= 0 {
|
||||
processor.patchSize = 16
|
||||
}
|
||||
if processor.tileSize <= 0 {
|
||||
processor.tileSize = 512
|
||||
}
|
||||
if processor.minTiles <= 0 {
|
||||
processor.minTiles = 2
|
||||
}
|
||||
if processor.maxTiles < processor.minTiles {
|
||||
processor.maxTiles = processor.minTiles
|
||||
}
|
||||
if processor.minImageTokens <= 0 {
|
||||
processor.minImageTokens = 64
|
||||
}
|
||||
if processor.maxImageTokens < processor.minImageTokens {
|
||||
processor.maxImageTokens = processor.minImageTokens
|
||||
}
|
||||
if processor.maxPixelsTolerance <= 0 {
|
||||
processor.maxPixelsTolerance = 2.0
|
||||
}
|
||||
|
||||
return processor
|
||||
}
|
||||
|
||||
func (p ImageProcessor) ProcessImage(img image.Image) ([]processedVisionImage, processedVisionLayout, error) {
|
||||
img = imageproc.Composite(img)
|
||||
|
||||
orig := img.Bounds().Size()
|
||||
resizedWidth, resizedHeight := p.smartResize(orig.Y, orig.X)
|
||||
|
||||
layout := processedVisionLayout{rows: 1, cols: 1}
|
||||
if p.shouldSplit(orig.Y, orig.X) {
|
||||
gridWidth, gridHeight, targetWidth, targetHeight := p.gridLayout(orig.Y, orig.X)
|
||||
layout.rows = gridHeight
|
||||
layout.cols = gridWidth
|
||||
layout.hasThumbnail = p.useThumbnail && gridWidth*gridHeight != 1
|
||||
|
||||
resized := imageproc.Resize(img, image.Point{X: targetWidth, Y: targetHeight}, imageproc.ResizeBilinear)
|
||||
images := make([]processedVisionImage, 0, gridWidth*gridHeight+1)
|
||||
for row := range gridHeight {
|
||||
for col := range gridWidth {
|
||||
rect := image.Rect(
|
||||
col*p.tileSize,
|
||||
row*p.tileSize,
|
||||
(col+1)*p.tileSize,
|
||||
(row+1)*p.tileSize,
|
||||
)
|
||||
tile := cropImage(resized, rect)
|
||||
images = append(images, processedVisionImage{
|
||||
data: imageproc.Normalize(tile, p.imageMean, p.imageStd, true, true),
|
||||
size: tile.Bounds().Size(),
|
||||
row: row + 1,
|
||||
col: col + 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if layout.hasThumbnail {
|
||||
thumbnail := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
|
||||
images = append(images, processedVisionImage{
|
||||
data: imageproc.Normalize(thumbnail, p.imageMean, p.imageStd, true, true),
|
||||
size: thumbnail.Bounds().Size(),
|
||||
thumbnail: true,
|
||||
})
|
||||
}
|
||||
|
||||
return images, layout, nil
|
||||
}
|
||||
|
||||
single := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
|
||||
return []processedVisionImage{{
|
||||
data: imageproc.Normalize(single, p.imageMean, p.imageStd, true, true),
|
||||
size: single.Bounds().Size(),
|
||||
}}, layout, nil
|
||||
}
|
||||
|
||||
func (p ImageProcessor) shouldSplit(height, width int) bool {
|
||||
if !p.doImageSplitting || p.minTiles == 1 && p.maxTiles == 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
totalFactor := p.patchSize * p.downsampleFactor
|
||||
hBar := max(p.patchSize, roundByFactor(height, totalFactor))
|
||||
wBar := max(p.patchSize, roundByFactor(width, totalFactor))
|
||||
|
||||
limit := float64(p.maxImageTokens * p.patchSize * p.patchSize * p.downsampleFactor * p.downsampleFactor)
|
||||
limit *= p.maxPixelsTolerance
|
||||
|
||||
return float64(hBar*wBar) > limit
|
||||
}
|
||||
|
||||
func (p ImageProcessor) smartResize(height, width int) (int, int) {
|
||||
totalFactor := p.patchSize * p.downsampleFactor
|
||||
minPixels := p.minImageTokens * p.patchSize * p.patchSize * p.downsampleFactor * p.downsampleFactor
|
||||
maxPixels := p.maxImageTokens * p.patchSize * p.patchSize * p.downsampleFactor * p.downsampleFactor
|
||||
|
||||
hBar := max(totalFactor, roundByFactor(height, totalFactor))
|
||||
wBar := max(totalFactor, roundByFactor(width, totalFactor))
|
||||
|
||||
if hBar*wBar > maxPixels {
|
||||
beta := math.Sqrt(float64(height*width) / float64(maxPixels))
|
||||
hBar = max(totalFactor, int(math.Floor(float64(height)/beta/float64(totalFactor)))*totalFactor)
|
||||
wBar = max(totalFactor, int(math.Floor(float64(width)/beta/float64(totalFactor)))*totalFactor)
|
||||
} else if hBar*wBar < minPixels {
|
||||
beta := math.Sqrt(float64(minPixels) / float64(height*width))
|
||||
hBar = int(math.Ceil(float64(height)*beta/float64(totalFactor))) * totalFactor
|
||||
wBar = int(math.Ceil(float64(width)*beta/float64(totalFactor))) * totalFactor
|
||||
}
|
||||
|
||||
return wBar, hBar
|
||||
}
|
||||
|
||||
func (p ImageProcessor) gridLayout(height, width int) (gridWidth, gridHeight, targetWidth, targetHeight int) {
|
||||
aspectRatio := float64(width) / float64(height)
|
||||
targetRatios := p.targetRatios()
|
||||
bestRatio := clipImageSize{width: 1, height: 1}
|
||||
bestRatioDiff := math.MaxFloat64
|
||||
area := float64(width * height)
|
||||
|
||||
for _, ratio := range targetRatios {
|
||||
targetAspect := float64(ratio.width) / float64(ratio.height)
|
||||
ratioDiff := math.Abs(aspectRatio - targetAspect)
|
||||
|
||||
if ratioDiff < bestRatioDiff {
|
||||
bestRatioDiff = ratioDiff
|
||||
bestRatio = ratio
|
||||
continue
|
||||
}
|
||||
|
||||
if ratioDiff == bestRatioDiff {
|
||||
targetArea := float64(p.tileSize * p.tileSize * ratio.width * ratio.height)
|
||||
if area > 0.5*targetArea {
|
||||
bestRatio = ratio
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bestRatio.width, bestRatio.height, p.tileSize * bestRatio.width, p.tileSize * bestRatio.height
|
||||
}
|
||||
|
||||
type clipImageSize struct {
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
func (p ImageProcessor) targetRatios() []clipImageSize {
|
||||
targetRatios := make([]clipImageSize, 0, p.maxTiles*p.maxTiles)
|
||||
for n := p.minTiles; n <= p.maxTiles; n++ {
|
||||
for w := 1; w <= n; w++ {
|
||||
for h := 1; h <= n; h++ {
|
||||
if w*h < p.minTiles || w*h > p.maxTiles {
|
||||
continue
|
||||
}
|
||||
targetRatios = append(targetRatios, clipImageSize{width: w, height: h})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unique := targetRatios[:0]
|
||||
for _, ratio := range targetRatios {
|
||||
if slices.Contains(unique, ratio) {
|
||||
continue
|
||||
}
|
||||
unique = append(unique, ratio)
|
||||
}
|
||||
|
||||
slices.SortFunc(unique, func(a, b clipImageSize) int {
|
||||
return a.width*a.height - b.width*b.height
|
||||
})
|
||||
|
||||
return unique
|
||||
}
|
||||
|
||||
func roundByFactor(number, factor int) int {
|
||||
if factor <= 0 {
|
||||
return number
|
||||
}
|
||||
return int(math.RoundToEven(float64(number)/float64(factor))) * factor
|
||||
}
|
||||
|
||||
func cropImage(img image.Image, rect image.Rectangle) image.Image {
|
||||
dst := image.NewRGBA(image.Rect(0, 0, rect.Dx(), rect.Dy()))
|
||||
stdimage.Draw(dst, dst.Bounds(), img, rect.Min, stdimage.Src)
|
||||
return dst
|
||||
}
|
||||
105
model/models/lfm2/process_image_test.go
Normal file
105
model/models/lfm2/process_image_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package lfm2
|
||||
|
||||
import (
|
||||
"image"
|
||||
"image/color"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProcessImageSingleTile(t *testing.T) {
|
||||
p := ImageProcessor{
|
||||
patchSize: 16,
|
||||
downsampleFactor: 2,
|
||||
numChannels: 3,
|
||||
imageMean: [3]float32{0.5, 0.5, 0.5},
|
||||
imageStd: [3]float32{0.5, 0.5, 0.5},
|
||||
doImageSplitting: true,
|
||||
minTiles: 2,
|
||||
maxTiles: 10,
|
||||
useThumbnail: true,
|
||||
tileSize: 512,
|
||||
minImageTokens: 64,
|
||||
maxImageTokens: 256,
|
||||
maxPixelsTolerance: 2.0,
|
||||
}
|
||||
|
||||
img := image.NewRGBA(image.Rect(0, 0, 320, 320))
|
||||
out, layout, err := p.ProcessImage(img)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessImage returned error: %v", err)
|
||||
}
|
||||
|
||||
if layout.rows != 1 || layout.cols != 1 || layout.hasThumbnail {
|
||||
t.Fatalf("layout = %+v, want rows=1 cols=1 hasThumbnail=false", layout)
|
||||
}
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("len(out) = %d, want 1", len(out))
|
||||
}
|
||||
if out[0].size != (image.Point{X: 320, Y: 320}) {
|
||||
t.Fatalf("single image size = %+v, want 320x320", out[0].size)
|
||||
}
|
||||
if out[0].thumbnail {
|
||||
t.Fatalf("single image should not be marked as thumbnail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessImageDynamicTiling(t *testing.T) {
|
||||
p := ImageProcessor{
|
||||
patchSize: 16,
|
||||
downsampleFactor: 2,
|
||||
numChannels: 3,
|
||||
imageMean: [3]float32{0.5, 0.5, 0.5},
|
||||
imageStd: [3]float32{0.5, 0.5, 0.5},
|
||||
doImageSplitting: true,
|
||||
minTiles: 2,
|
||||
maxTiles: 10,
|
||||
useThumbnail: true,
|
||||
tileSize: 512,
|
||||
minImageTokens: 64,
|
||||
maxImageTokens: 256,
|
||||
maxPixelsTolerance: 2.0,
|
||||
}
|
||||
|
||||
// Wide image that should trigger multi-tile splitting.
|
||||
img := image.NewRGBA(image.Rect(0, 0, 3000, 1000))
|
||||
fill := color.RGBA{R: 120, G: 90, B: 60, A: 255}
|
||||
for y := range 1000 {
|
||||
for x := range 3000 {
|
||||
img.Set(x, y, fill)
|
||||
}
|
||||
}
|
||||
|
||||
out, layout, err := p.ProcessImage(img)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessImage returned error: %v", err)
|
||||
}
|
||||
|
||||
if layout.rows*layout.cols <= 1 {
|
||||
t.Fatalf("expected multi-tile layout, got %+v", layout)
|
||||
}
|
||||
if !layout.hasThumbnail {
|
||||
t.Fatalf("expected thumbnail for multi-tile layout")
|
||||
}
|
||||
|
||||
wantLen := layout.rows*layout.cols + 1
|
||||
if len(out) != wantLen {
|
||||
t.Fatalf("len(out) = %d, want %d", len(out), wantLen)
|
||||
}
|
||||
|
||||
for i := range layout.rows * layout.cols {
|
||||
if out[i].size != (image.Point{X: 512, Y: 512}) {
|
||||
t.Fatalf("tile[%d] size = %+v, want 512x512", i, out[i].size)
|
||||
}
|
||||
if out[i].thumbnail {
|
||||
t.Fatalf("tile[%d] should not be marked as thumbnail", i)
|
||||
}
|
||||
}
|
||||
|
||||
thumb := out[len(out)-1]
|
||||
if !thumb.thumbnail {
|
||||
t.Fatalf("last chunk should be thumbnail")
|
||||
}
|
||||
if thumb.size.X%32 != 0 || thumb.size.Y%32 != 0 {
|
||||
t.Fatalf("thumbnail size = %+v, want dimensions aligned to 32", thumb.size)
|
||||
}
|
||||
}
|
||||
@@ -32,6 +32,8 @@ type LFM2Parser struct {
|
||||
hasThinkingSupport bool
|
||||
needsThinkingLeadingTrim bool // trim leading whitespace after <think> tag
|
||||
needsContentLeadingTrim bool // trim leading whitespace after </think> tag
|
||||
toolNames map[string]struct{}
|
||||
hasTools bool
|
||||
}
|
||||
|
||||
func (p *LFM2Parser) HasToolSupport() bool {
|
||||
@@ -63,6 +65,13 @@ func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.T
|
||||
}
|
||||
|
||||
func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.toolNames = make(map[string]struct{}, len(tools))
|
||||
p.hasTools = len(tools) > 0
|
||||
for _, tool := range tools {
|
||||
if tool.Function.Name != "" {
|
||||
p.toolNames[tool.Function.Name] = struct{}{}
|
||||
}
|
||||
}
|
||||
p.setInitialState(lastMessage, thinkValue)
|
||||
return tools
|
||||
}
|
||||
@@ -105,9 +114,33 @@ func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string,
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback for models that emit bare tool calls without <|tool_call_*|> wrappers.
|
||||
if done && len(toolCalls) == 0 && p.hasTools {
|
||||
candidate := strings.TrimSpace(contentSb.String())
|
||||
if fallbackCalls, parseErr := p.parseToolCallsContent(candidate); parseErr == nil && p.toolCallsAllowed(fallbackCalls) {
|
||||
contentSb.Reset()
|
||||
toolCalls = append(toolCalls, fallbackCalls...)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *LFM2Parser) toolCallsAllowed(calls []api.ToolCall) bool {
|
||||
if len(calls) == 0 {
|
||||
return false
|
||||
}
|
||||
if len(p.toolNames) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, call := range calls {
|
||||
if _, ok := p.toolNames[call.Function.Name]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *LFM2Parser) parseEvents() []lfm2Event {
|
||||
var all []lfm2Event
|
||||
|
||||
@@ -269,36 +302,16 @@ func (p *LFM2Parser) eat() ([]lfm2Event, bool) {
|
||||
return events, false
|
||||
}
|
||||
|
||||
// parseToolCallsContent parses one or more tool calls from content
|
||||
// Supports JSON format and Python-style format including multiple calls: [func1(...),func2(...)]
|
||||
// parseToolCallsContent parses one or more Python-style tool calls.
|
||||
// Example: [func1(arg='v'), func2(x=1)]
|
||||
func (p *LFM2Parser) parseToolCallsContent(content string) ([]api.ToolCall, error) {
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
// Try JSON format first: {"name": "func", "arguments": {...}}
|
||||
var parsed struct {
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
}
|
||||
// Be tolerant of malformed outputs that include wrapper tags without proper pairing.
|
||||
content = strings.TrimSpace(strings.TrimPrefix(content, lfm2ToolCallStartTag))
|
||||
content = strings.TrimSpace(strings.TrimSuffix(content, lfm2ToolCallEndTag))
|
||||
|
||||
if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" {
|
||||
var args api.ToolCallFunctionArguments
|
||||
if len(parsed.Arguments) > 0 {
|
||||
if err := json.Unmarshal(parsed.Arguments, &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
args = api.NewToolCallFunctionArguments()
|
||||
}
|
||||
|
||||
return []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: parsed.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// Try Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
|
||||
// Parse Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
|
||||
return p.parsePythonStyleToolCalls(content)
|
||||
}
|
||||
|
||||
@@ -417,21 +430,16 @@ func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error)
|
||||
|
||||
// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2"
|
||||
func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error {
|
||||
// Simple state machine to parse key='value' pairs
|
||||
// Handles: command='ls', flag="-la", count=42, enabled=true
|
||||
var key string
|
||||
i := 0
|
||||
|
||||
for i < len(argsStr) {
|
||||
// Skip whitespace
|
||||
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
|
||||
// Skip separators and whitespace.
|
||||
for i < len(argsStr) && (argsStr[i] == ',' || unicode.IsSpace(rune(argsStr[i]))) {
|
||||
i++
|
||||
}
|
||||
if i >= len(argsStr) {
|
||||
break
|
||||
}
|
||||
|
||||
// Parse key
|
||||
keyStart := i
|
||||
for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' {
|
||||
i++
|
||||
@@ -439,60 +447,238 @@ func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error
|
||||
if i >= len(argsStr) || argsStr[i] != '=' {
|
||||
return errors.New("invalid argument: expected '='")
|
||||
}
|
||||
key = strings.TrimSpace(argsStr[keyStart:i])
|
||||
|
||||
key := strings.TrimSpace(argsStr[keyStart:i])
|
||||
if key == "" {
|
||||
return errors.New("invalid argument: empty key")
|
||||
}
|
||||
i++ // skip '='
|
||||
|
||||
// Skip whitespace after =
|
||||
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') {
|
||||
for i < len(argsStr) && unicode.IsSpace(rune(argsStr[i])) {
|
||||
i++
|
||||
}
|
||||
|
||||
// Parse value
|
||||
var value string
|
||||
if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') {
|
||||
// Quoted string
|
||||
quote := argsStr[i]
|
||||
i++
|
||||
valueStart := i
|
||||
for i < len(argsStr) && argsStr[i] != quote {
|
||||
if argsStr[i] == '\\' && i+1 < len(argsStr) {
|
||||
i += 2 // skip escaped char
|
||||
} else {
|
||||
i++
|
||||
}
|
||||
}
|
||||
value = argsStr[valueStart:i]
|
||||
if i < len(argsStr) {
|
||||
i++ // skip closing quote
|
||||
}
|
||||
args.Set(key, value)
|
||||
} else {
|
||||
// Unquoted value (number, bool, etc)
|
||||
valueStart := i
|
||||
for i < len(argsStr) && argsStr[i] != ',' {
|
||||
i++
|
||||
}
|
||||
value = strings.TrimSpace(argsStr[valueStart:i])
|
||||
|
||||
// Try to parse as number or bool
|
||||
if v, err := strconv.ParseInt(value, 10, 64); err == nil {
|
||||
args.Set(key, v)
|
||||
} else if v, err := strconv.ParseFloat(value, 64); err == nil {
|
||||
args.Set(key, v)
|
||||
} else if value == "true" {
|
||||
args.Set(key, true)
|
||||
} else if value == "false" {
|
||||
args.Set(key, false)
|
||||
} else {
|
||||
args.Set(key, value)
|
||||
}
|
||||
if i >= len(argsStr) {
|
||||
return errors.New("invalid argument: missing value")
|
||||
}
|
||||
|
||||
// Skip comma and whitespace
|
||||
for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
|
||||
value, next, err := parsePythonArgValue(argsStr, i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
args.Set(key, value)
|
||||
i = next
|
||||
|
||||
// Optional trailing comma before next key/value.
|
||||
if i < len(argsStr) && argsStr[i] == ',' {
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parsePythonArgValue(s string, i int) (any, int, error) {
|
||||
if i >= len(s) {
|
||||
return nil, i, errors.New("invalid argument: missing value")
|
||||
}
|
||||
|
||||
// Quoted string literal.
|
||||
if s[i] == '\'' || s[i] == '"' {
|
||||
quote := s[i]
|
||||
i++
|
||||
start := i
|
||||
for i < len(s) {
|
||||
if s[i] == '\\' && i+1 < len(s) {
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
if s[i] == quote {
|
||||
value := s[start:i]
|
||||
i++
|
||||
return value, i, nil
|
||||
}
|
||||
i++
|
||||
}
|
||||
return nil, i, errors.New("invalid argument: unterminated string")
|
||||
}
|
||||
|
||||
// Unquoted literal. Consume until top-level comma.
|
||||
start := i
|
||||
depthParen, depthSquare, depthCurly := 0, 0, 0
|
||||
inString := false
|
||||
var quote byte
|
||||
escaped := false
|
||||
|
||||
for i < len(s) {
|
||||
ch := s[i]
|
||||
if inString {
|
||||
if escaped {
|
||||
escaped = false
|
||||
} else if ch == '\\' {
|
||||
escaped = true
|
||||
} else if ch == quote {
|
||||
inString = false
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
switch ch {
|
||||
case '\'', '"':
|
||||
inString = true
|
||||
quote = ch
|
||||
case '(':
|
||||
depthParen++
|
||||
case ')':
|
||||
if depthParen > 0 {
|
||||
depthParen--
|
||||
}
|
||||
case '[':
|
||||
depthSquare++
|
||||
case ']':
|
||||
if depthSquare > 0 {
|
||||
depthSquare--
|
||||
}
|
||||
case '{':
|
||||
depthCurly++
|
||||
case '}':
|
||||
if depthCurly > 0 {
|
||||
depthCurly--
|
||||
}
|
||||
case ',':
|
||||
if depthParen == 0 && depthSquare == 0 && depthCurly == 0 {
|
||||
token := strings.TrimSpace(s[start:i])
|
||||
value, err := parsePythonLiteral(token)
|
||||
return value, i, err
|
||||
}
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(s[start:i])
|
||||
value, err := parsePythonLiteral(token)
|
||||
return value, i, err
|
||||
}
|
||||
|
||||
func parsePythonLiteral(token string) (any, error) {
|
||||
switch token {
|
||||
case "":
|
||||
return "", nil
|
||||
case "true", "True":
|
||||
return true, nil
|
||||
case "false", "False":
|
||||
return false, nil
|
||||
case "null", "None":
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if v, err := strconv.ParseInt(token, 10, 64); err == nil {
|
||||
return v, nil
|
||||
}
|
||||
if v, err := strconv.ParseFloat(token, 64); err == nil {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(token, "[") || strings.HasPrefix(token, "{") {
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(token), &parsed); err == nil {
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
if converted, err := pythonLiteralToJSON(token); err == nil {
|
||||
if err := json.Unmarshal([]byte(converted), &parsed); err == nil {
|
||||
return parsed, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func pythonLiteralToJSON(s string) (string, error) {
|
||||
var out strings.Builder
|
||||
out.Grow(len(s) + len(s)/8)
|
||||
|
||||
inString := false
|
||||
var quote byte
|
||||
escaped := false
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
ch := s[i]
|
||||
|
||||
if inString {
|
||||
if escaped {
|
||||
out.WriteByte(ch)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '\\' {
|
||||
out.WriteByte(ch)
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == quote {
|
||||
out.WriteByte('"')
|
||||
inString = false
|
||||
continue
|
||||
}
|
||||
|
||||
if quote == '\'' && ch == '"' {
|
||||
out.WriteString(`\"`)
|
||||
continue
|
||||
}
|
||||
|
||||
out.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '\'' || ch == '"' {
|
||||
inString = true
|
||||
quote = ch
|
||||
escaped = false
|
||||
out.WriteByte('"')
|
||||
continue
|
||||
}
|
||||
|
||||
// Replace Python identifiers with JSON equivalents when outside strings.
|
||||
if isIdentStart(ch) {
|
||||
j := i + 1
|
||||
for j < len(s) && isIdentPart(s[j]) {
|
||||
j++
|
||||
}
|
||||
|
||||
ident := s[i:j]
|
||||
switch ident {
|
||||
case "True":
|
||||
out.WriteString("true")
|
||||
case "False":
|
||||
out.WriteString("false")
|
||||
case "None":
|
||||
out.WriteString("null")
|
||||
default:
|
||||
out.WriteString(ident)
|
||||
}
|
||||
|
||||
i = j - 1
|
||||
continue
|
||||
}
|
||||
|
||||
out.WriteByte(ch)
|
||||
}
|
||||
|
||||
if inString {
|
||||
return "", errors.New("unterminated string")
|
||||
}
|
||||
|
||||
return out.String(), nil
|
||||
}
|
||||
|
||||
func isIdentStart(b byte) bool {
|
||||
return (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || b == '_'
|
||||
}
|
||||
|
||||
func isIdentPart(b byte) bool {
|
||||
return isIdentStart(b) || (b >= '0' && b <= '9')
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func TestLFM2Parser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "tool_call_simple",
|
||||
input: "I'll check the weather.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
|
||||
input: "I'll check the weather.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>",
|
||||
expectedContent: "I'll check the weather.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
@@ -55,7 +55,7 @@ func TestLFM2Parser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "multiple_tool_calls",
|
||||
input: "Getting weather for both cities.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|><|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"London\"}}<|tool_call_end|>",
|
||||
input: "Getting weather for both cities.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|><|tool_call_start|>[get_weather(location=\"London\")]<|tool_call_end|>",
|
||||
expectedContent: "Getting weather for both cities.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
@@ -79,7 +79,7 @@ func TestLFM2Parser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "complex_tool_arguments",
|
||||
input: "Processing data.<|tool_call_start|>{\"name\":\"process_data\",\"arguments\":{\"items\":[\"item1\",\"item2\"],\"config\":{\"enabled\":true,\"threshold\":0.95}}}<|tool_call_end|>",
|
||||
input: "Processing data.<|tool_call_start|>[process_data(items=['item1','item2'], config={'enabled': True, 'threshold': 0.95})]<|tool_call_end|>",
|
||||
expectedContent: "Processing data.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
@@ -96,7 +96,7 @@ func TestLFM2Parser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "thinking_with_tool_call",
|
||||
input: "Let me check the weather...</think>I'll get that for you.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
|
||||
input: "Let me check the weather...</think>I'll get that for you.<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>",
|
||||
expectedThinking: "Let me check the weather...",
|
||||
expectedContent: "I'll get that for you.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
@@ -144,16 +144,16 @@ func TestLFM2Parser(t *testing.T) {
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "tool_call_with_unicode_args",
|
||||
input: "Searching for information.<|tool_call_start|>{\"name\":\"search\",\"arguments\":{\"query\":\"北京天气\",\"language\":\"中文\"}}<|tool_call_end|>",
|
||||
name: "tool_call_with_text_args",
|
||||
input: "Searching for information.<|tool_call_start|>[search(query='beijing weather', language='zh')]<|tool_call_end|>",
|
||||
expectedContent: "Searching for information.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"query": "北京天气",
|
||||
"language": "中文",
|
||||
"query": "beijing weather",
|
||||
"language": "zh",
|
||||
}),
|
||||
},
|
||||
},
|
||||
@@ -169,7 +169,7 @@ func TestLFM2Parser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "empty_tool_call_args",
|
||||
input: "Pinging server.<|tool_call_start|>{\"name\":\"ping\",\"arguments\":{}}<|tool_call_end|>",
|
||||
input: "Pinging server.<|tool_call_start|>[ping()]<|tool_call_end|>",
|
||||
expectedContent: "Pinging server.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
@@ -353,7 +353,7 @@ func TestLFM2Parser_Streaming(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "streaming_tool_call",
|
||||
chunks: []string{"I'll check weather.", "<|tool_call_start|>", "{\"name\":\"get_weather\",", "\"arguments\":{\"location\":\"Paris\"}}", "<|tool_call_end|>"},
|
||||
chunks: []string{"I'll check weather.", "<|tool_call_start|>", "[get_weather(", "location=\"Paris\")]", "<|tool_call_end|>"},
|
||||
expectedContent: "I'll check weather.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
@@ -381,16 +381,16 @@ func TestLFM2Parser_Streaming(t *testing.T) {
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "streaming_tool_call_with_split_json",
|
||||
chunks: []string{"Processing.", "<|tool_call_start|>{\"name\":\"calc\",\"arguments\":{\"x\":", "42,\"y\":", "24}}<|tool_call_end|>"},
|
||||
name: "streaming_tool_call_with_split_python",
|
||||
chunks: []string{"Processing.", "<|tool_call_start|>", "[calc(", "x=42, ", "y=24)]", "<|tool_call_end|>"},
|
||||
expectedContent: "Processing.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calc",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"x": float64(42),
|
||||
"y": float64(24),
|
||||
"x": int64(42),
|
||||
"y": int64(24),
|
||||
}),
|
||||
},
|
||||
},
|
||||
@@ -516,8 +516,8 @@ func TestLFM2Parser_parseToolCallContent(t *testing.T) {
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid_tool_call",
|
||||
content: `{"name":"get_weather","arguments":{"location":"Paris"}}`,
|
||||
name: "python_style_single_call",
|
||||
content: `get_weather(location="Paris")`,
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
@@ -528,21 +528,33 @@ func TestLFM2Parser_parseToolCallContent(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complex_arguments",
|
||||
content: `{"name":"process_data","arguments":{"items":["a","b"],"config":{"enabled":true}}}`,
|
||||
name: "python_style_with_brackets",
|
||||
content: `[get_weather(location="Paris")]`,
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_data",
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"items": []interface{}{"a", "b"},
|
||||
"config": map[string]interface{}{"enabled": true},
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty_arguments",
|
||||
content: `{"name":"ping","arguments":{}}`,
|
||||
name: "python_style_complex_arguments",
|
||||
content: `process(items=['a', 'b'], config={'enabled': True})`,
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"items": []any{"a", "b"},
|
||||
"config": map[string]any{"enabled": true},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "python_style_empty_arguments",
|
||||
content: `ping()`,
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "ping",
|
||||
@@ -551,44 +563,13 @@ func TestLFM2Parser_parseToolCallContent(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unicode_in_tool_name",
|
||||
content: `{"name":"获取天气","arguments":{"城市":"北京"}}`,
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "获取天气",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"城市": "北京",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "numeric_arguments",
|
||||
content: `{"name":"calculate","arguments":{"x":3.14,"y":42,"enabled":true}}`,
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"x": 3.14,
|
||||
"y": float64(42),
|
||||
"enabled": true,
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid_json",
|
||||
content: `{invalid json}`,
|
||||
name: "missing_parenthesis",
|
||||
content: `get_weather location="Paris")`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "missing_name",
|
||||
content: `{"arguments":{"arg":"value"}}`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty_name",
|
||||
content: `{"name":"","arguments":{"arg":"value"}}`,
|
||||
name: "invalid_argument_format",
|
||||
content: `bash(command)`,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
@@ -645,6 +626,24 @@ func TestLFM2Parser_parseToolCallsContent(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "python_style_complex_literals",
|
||||
content: `[AskUserQuestion(question="What's up?", headers=['Hello!', 'How can I help you?'], options=['Debugging help', 'Code writing assistance'], multiSelect=False, metadata={'priority': 1, 'active': True})]`,
|
||||
expected: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "AskUserQuestion",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"question": "What's up?",
|
||||
"headers": []any{"Hello!", "How can I help you?"},
|
||||
"options": []any{"Debugging help", "Code writing assistance"},
|
||||
"multiSelect": false,
|
||||
"metadata": map[string]any{"priority": float64(1), "active": true},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single_python_style_call",
|
||||
content: `bash(command='ls -la')`,
|
||||
@@ -673,6 +672,34 @@ func TestLFM2Parser_parseToolCallsContent(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single_call_with_orphan_end_tag",
|
||||
content: `[bash(command='ls')]<|tool_call_end|>`,
|
||||
expected: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bash",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"command": "ls",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single_call_with_wrapper_tags",
|
||||
content: `<|tool_call_start|>[bash(command='pwd')]<|tool_call_end|>`,
|
||||
expected: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bash",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"command": "pwd",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple_different_functions",
|
||||
content: `[get_weather(location='Paris'),search(query='news')]`,
|
||||
@@ -1086,3 +1113,106 @@ func TestLFM2Parser_EdgeCases(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLFM2Parser_BareToolCallFallback(t *testing.T) {
|
||||
parser := &LFM2Parser{}
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
},
|
||||
},
|
||||
}
|
||||
parser.Init(tools, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
content, thinking, calls, err := parser.Add(`[get_weather(location="Paris")]`, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error = %v", err)
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected tool name get_weather, got %q", calls[0].Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLFM2Parser_BareUnknownToolCallDoesNotParse(t *testing.T) {
|
||||
parser := &LFM2Parser{}
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
},
|
||||
},
|
||||
}
|
||||
parser.Init(tools, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
input := `[unknown_tool(location="Paris")]`
|
||||
content, _, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error = %v", err)
|
||||
}
|
||||
|
||||
if content != input {
|
||||
t.Fatalf("expected content to be preserved, got %q", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLFM2Parser_ImagePlaceholdersPreserved(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{
|
||||
name: "indexed_img_placeholder",
|
||||
input: "[img-0]describe this image",
|
||||
},
|
||||
{
|
||||
name: "template_image_placeholder",
|
||||
input: "<image>describe this image",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := &LFM2Parser{}
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "bash",
|
||||
},
|
||||
},
|
||||
}
|
||||
parser.Init(tools, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
content, thinking, calls, err := parser.Add(tt.input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error = %v", err)
|
||||
}
|
||||
|
||||
if content != tt.input {
|
||||
t.Fatalf("expected content %q, got %q", tt.input, content)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,6 +57,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
|
||||
{"qwen3"},
|
||||
{"qwen3-thinking"},
|
||||
{"qwen3-coder"},
|
||||
{"lfm2"},
|
||||
{"lfm2-thinking"},
|
||||
{"harmony"},
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
@@ -9,18 +11,218 @@ import (
|
||||
|
||||
type LFM2Renderer struct {
|
||||
IsThinking bool
|
||||
useImgTags bool
|
||||
}
|
||||
|
||||
const lfm2BOSToken = "<|startoftext|>"
|
||||
|
||||
const (
|
||||
lfm2ToolListStartTag = "<|tool_list_start|>"
|
||||
lfm2ToolListEndTag = "<|tool_list_end|>"
|
||||
lfm2ToolCallStartTag = "<|tool_call_start|>"
|
||||
lfm2ToolCallEndTag = "<|tool_call_end|>"
|
||||
lfm2ToolResponseStartTag = "<|tool_response_start|>"
|
||||
lfm2ToolResponseEndTag = "<|tool_response_end|>"
|
||||
)
|
||||
|
||||
func lfm2RenderSystemContent(content any) string {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []any:
|
||||
var sb strings.Builder
|
||||
for _, item := range v {
|
||||
obj, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if itemType, _ := obj["type"].(string); itemType == "text" {
|
||||
if text, ok := obj["text"].(string); ok {
|
||||
sb.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func lfm2JSON(v any) string {
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.SetEscapeHTML(false)
|
||||
if err := enc.Encode(v); err != nil {
|
||||
fallback, _ := json.Marshal(v)
|
||||
return string(fallback)
|
||||
}
|
||||
|
||||
encoded := bytes.TrimSuffix(buf.Bytes(), []byte{'\n'})
|
||||
|
||||
// HF `tojson` defaults to `json.dumps(..., separators=None)`, which inserts
|
||||
// a space after commas and colons.
|
||||
var out strings.Builder
|
||||
out.Grow(len(encoded) + len(encoded)/8)
|
||||
|
||||
inString := false
|
||||
escaped := false
|
||||
for i, b := range encoded {
|
||||
out.WriteByte(b)
|
||||
|
||||
if inString {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if b == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if b == '"' {
|
||||
inString = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if b == '"' {
|
||||
inString = true
|
||||
continue
|
||||
}
|
||||
|
||||
if (b == ':' || b == ',') && i+1 < len(encoded) {
|
||||
next := encoded[i+1]
|
||||
if next != ' ' && next != '\n' && next != '\r' && next != '\t' {
|
||||
out.WriteByte(' ')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func lfm2ImagePlaceholder(useImgTags bool) string {
|
||||
if useImgTags {
|
||||
return "[img]"
|
||||
}
|
||||
|
||||
return "<image>"
|
||||
}
|
||||
|
||||
func lfm2RenderContent(content any, useImgTags bool) string {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []any:
|
||||
var sb strings.Builder
|
||||
for _, item := range v {
|
||||
obj, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
sb.WriteString(lfm2JSON(item))
|
||||
continue
|
||||
}
|
||||
|
||||
itemType, _ := obj["type"].(string)
|
||||
switch itemType {
|
||||
case "image":
|
||||
sb.WriteString(lfm2ImagePlaceholder(useImgTags))
|
||||
case "text":
|
||||
if text, ok := obj["text"].(string); ok {
|
||||
sb.WriteString(text)
|
||||
} else {
|
||||
sb.WriteString(lfm2JSON(item))
|
||||
}
|
||||
default:
|
||||
sb.WriteString(lfm2JSON(item))
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
default:
|
||||
return lfm2JSON(content)
|
||||
}
|
||||
}
|
||||
|
||||
func lfm2ToolSchema(tool api.Tool) any {
|
||||
if tool.Function.Name == "" {
|
||||
return tool
|
||||
}
|
||||
|
||||
// LFM2 templates are typically fed function-schema objects (name/description/parameters).
|
||||
return tool.Function
|
||||
}
|
||||
|
||||
func lfm2ToolCallArgument(v any) string {
|
||||
return lfm2JSON(v)
|
||||
}
|
||||
|
||||
func lfm2RenderToolCalls(calls []api.ToolCall) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(lfm2ToolCallStartTag)
|
||||
sb.WriteString("[")
|
||||
for i, tc := range calls {
|
||||
if i > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
|
||||
sb.WriteString(tc.Function.Name)
|
||||
sb.WriteString("(")
|
||||
|
||||
keys := make([]string, 0, tc.Function.Arguments.Len())
|
||||
for key := range tc.Function.Arguments.All() {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for j, key := range keys {
|
||||
if j > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
value, _ := tc.Function.Arguments.Get(key)
|
||||
sb.WriteString(key)
|
||||
sb.WriteString("=")
|
||||
sb.WriteString(lfm2ToolCallArgument(value))
|
||||
}
|
||||
|
||||
sb.WriteString(")")
|
||||
}
|
||||
sb.WriteString("]")
|
||||
sb.WriteString(lfm2ToolCallEndTag)
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *LFM2Renderer) renderMessageContent(message api.Message) string {
|
||||
content := lfm2RenderContent(message.Content, r.useImgTags)
|
||||
if len(message.Images) == 0 {
|
||||
return content
|
||||
}
|
||||
|
||||
// chatPrompt may already have inserted [img] / [img-n] placeholders.
|
||||
if strings.Contains(content, "[img]") || strings.Contains(content, "[img-") || strings.Contains(content, "<image>") {
|
||||
return content
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
placeholder := lfm2ImagePlaceholder(r.useImgTags)
|
||||
for range message.Images {
|
||||
sb.WriteString(placeholder)
|
||||
}
|
||||
sb.WriteString(content)
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
// Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer
|
||||
// Follow Liquid tool-use formatting for LFM2 tool wrappers.
|
||||
sb.WriteString(lfm2BOSToken)
|
||||
|
||||
// Extract first system message if present (to combine with tools)
|
||||
var firstSystemContent string
|
||||
startIdx := 0
|
||||
if len(messages) > 0 && messages[0].Role == "system" {
|
||||
firstSystemContent = messages[0].Content
|
||||
firstSystemContent = lfm2RenderSystemContent(messages[0].Content)
|
||||
startIdx = 1
|
||||
}
|
||||
|
||||
@@ -29,18 +231,17 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
|
||||
if firstSystemContent != "" {
|
||||
firstSystemContent += "\n"
|
||||
}
|
||||
firstSystemContent += "List of tools: ["
|
||||
firstSystemContent += "List of tools: "
|
||||
firstSystemContent += lfm2ToolListStartTag
|
||||
firstSystemContent += "["
|
||||
for i, tool := range tools {
|
||||
toolJSON, err := json.Marshal(tool)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
firstSystemContent += string(toolJSON)
|
||||
firstSystemContent += lfm2JSON(lfm2ToolSchema(tool))
|
||||
if i < len(tools)-1 {
|
||||
firstSystemContent += ", "
|
||||
}
|
||||
}
|
||||
firstSystemContent += "]"
|
||||
firstSystemContent += lfm2ToolListEndTag
|
||||
}
|
||||
|
||||
// Output first system block if it has content
|
||||
@@ -50,6 +251,8 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
|
||||
keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
|
||||
|
||||
// Find the index of the last assistant message for thinking stripping
|
||||
lastAssistantIndex := -1
|
||||
for i := len(messages) - 1; i >= startIdx; i-- {
|
||||
@@ -59,85 +262,47 @@ func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkVal
|
||||
}
|
||||
}
|
||||
|
||||
// Track whether we need to add generation prompt
|
||||
needsGenerationPrompt := len(messages) > 0
|
||||
|
||||
for i := startIdx; i < len(messages); i++ {
|
||||
message := messages[i]
|
||||
switch message.Role {
|
||||
case "system":
|
||||
// Additional system messages (after the first) are rendered normally
|
||||
sb.WriteString("<|im_start|>system\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
lastMessage := i == len(messages)-1
|
||||
prefill := lastMessage && message.Role == "assistant"
|
||||
|
||||
case "user":
|
||||
sb.WriteString("<|im_start|>user\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
needsGenerationPrompt = true
|
||||
sb.WriteString("<|im_start|>")
|
||||
sb.WriteString(message.Role)
|
||||
sb.WriteString("\n")
|
||||
|
||||
case "assistant":
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
|
||||
// Check if this is the last assistant message
|
||||
isLastAssistant := i == lastAssistantIndex
|
||||
|
||||
// Process content (may need thinking stripped)
|
||||
content := message.Content
|
||||
|
||||
// Handle thinking tags in assistant content
|
||||
keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
|
||||
if strings.Contains(content, "</think>") {
|
||||
parts := strings.SplitN(content, "</think>", 2)
|
||||
if len(parts) > 1 {
|
||||
if !isLastAssistant && !keepPastThinking {
|
||||
// Strip thinking entirely for past assistant messages
|
||||
content = strings.TrimSpace(parts[1])
|
||||
} else {
|
||||
// Preserve thinking but trim whitespace after </think>
|
||||
content = parts[0] + "</think>" + strings.TrimLeft(parts[1], " \t\n\r")
|
||||
}
|
||||
}
|
||||
content := r.renderMessageContent(message)
|
||||
if message.Role == "assistant" && !keepPastThinking && i != lastAssistantIndex {
|
||||
if idx := strings.LastIndex(content, "</think>"); idx >= 0 {
|
||||
content = strings.TrimSpace(content[idx+len("</think>"):])
|
||||
}
|
||||
|
||||
if len(message.ToolCalls) > 0 {
|
||||
// Assistant with tool calls - write content first (if any after stripping)
|
||||
if content != "" {
|
||||
sb.WriteString(content)
|
||||
}
|
||||
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("<|tool_call_start|>")
|
||||
toolCallJSON := map[string]any{
|
||||
"name": toolCall.Function.Name,
|
||||
"arguments": toolCall.Function.Arguments,
|
||||
}
|
||||
callJSON, _ := json.Marshal(toolCallJSON)
|
||||
sb.WriteString(string(callJSON))
|
||||
sb.WriteString("<|tool_call_end|>")
|
||||
}
|
||||
}
|
||||
if message.Role == "assistant" && len(message.ToolCalls) > 0 && !strings.Contains(content, lfm2ToolCallStartTag) {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = lfm2RenderToolCalls(message.ToolCalls) + content
|
||||
} else {
|
||||
sb.WriteString(content)
|
||||
content = lfm2RenderToolCalls(message.ToolCalls) + "\n" + content
|
||||
}
|
||||
}
|
||||
if message.Role == "tool" && !strings.Contains(content, lfm2ToolResponseStartTag) {
|
||||
content = lfm2ToolResponseStartTag + content + lfm2ToolResponseEndTag
|
||||
}
|
||||
|
||||
sb.WriteString(content)
|
||||
if !prefill {
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true
|
||||
|
||||
case "tool":
|
||||
// Tool responses are rendered as plain messages per the chat template
|
||||
sb.WriteString("<|im_start|>tool\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
needsGenerationPrompt = true
|
||||
}
|
||||
}
|
||||
|
||||
// Add generation prompt
|
||||
needsGenerationPrompt := true
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == "assistant" {
|
||||
needsGenerationPrompt = false
|
||||
}
|
||||
|
||||
if needsGenerationPrompt {
|
||||
// RenderWithRenderer uses add_generation_prompt=true for chat rendering,
|
||||
// unless we're prefilling a trailing assistant message.
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
// Note: Model is a "thinking-only" model - it will output <think> itself
|
||||
// We don't add <think> tag to the prompt
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
|
||||
@@ -8,73 +8,136 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestLFM2Renderer(t *testing.T) {
|
||||
func TestLFM2Renderer_ChatTemplateParity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
renderer *LFM2Renderer
|
||||
messages []api.Message
|
||||
tools []api.Tool
|
||||
thinkValue *api.ThinkValue
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic user message",
|
||||
name: "user_only",
|
||||
renderer: &LFM2Renderer{IsThinking: false},
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
|
||||
expected: "<|startoftext|><|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "basic with system message",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "multiple system messages rendered separately",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "First instruction."},
|
||||
{Role: "system", Content: "Second instruction."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "multi-turn conversation",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
{Role: "assistant", Content: "The answer is 4."},
|
||||
{Role: "user", Content: "Thanks!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "only system message",
|
||||
name: "system_and_user",
|
||||
renderer: &LFM2Renderer{IsThinking: false},
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n",
|
||||
expected: "<|startoftext|><|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false)
|
||||
name: "user-assistant-user: last assistant preserves thinking",
|
||||
name: "tools_without_system",
|
||||
renderer: &LFM2Renderer{IsThinking: false},
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Q1"},
|
||||
{Role: "assistant", Content: "<think>reasoning</think>A1"},
|
||||
{Role: "user", Content: "Q2"},
|
||||
{Role: "user", Content: "Use tools"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reasoning</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n",
|
||||
expected: "<|startoftext|><|im_start|>system\nList of tools: <|tool_list_start|>[{\"name\": \"get_weather\", \"parameters\": {\"type\": \"object\", \"properties\": null}}]<|tool_list_end|><|im_end|>\n" +
|
||||
"<|im_start|>user\nUse tools<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// With two assistants, first is stripped (not last), second preserved (is last)
|
||||
name: "multi-turn thinking: first stripped, second preserved",
|
||||
name: "first_system_combined_with_tools",
|
||||
renderer: &LFM2Renderer{IsThinking: false},
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "Follow instructions."},
|
||||
{Role: "user", Content: "Do work"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "tool_a",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "tool_b",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|startoftext|><|im_start|>system\nFollow instructions.\nList of tools: <|tool_list_start|>[{\"name\": \"tool_a\", \"parameters\": {\"type\": \"object\", \"properties\": null}}, {\"name\": \"tool_b\", \"parameters\": {\"type\": \"object\", \"properties\": null}}]<|tool_list_end|><|im_end|>\n" +
|
||||
"<|im_start|>user\nDo work<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "assistant_tool_calls_and_tool_responses_are_rendered",
|
||||
renderer: &LFM2Renderer{IsThinking: false},
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Call a tool"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "22C"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|startoftext|><|im_start|>user\nCall a tool<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|><|im_end|>\n<|im_start|>tool\n<|tool_response_start|>22C<|tool_response_end|><|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "assistant_tool_calls_with_content_preserves_both",
|
||||
renderer: &LFM2Renderer{IsThinking: false},
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Call a tool"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Checking now.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|startoftext|><|im_start|>user\nCall a tool<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>\nChecking now.",
|
||||
},
|
||||
{
|
||||
name: "thinking_strips_non_last_assistant_when_disabled",
|
||||
renderer: &LFM2Renderer{IsThinking: true},
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Q1"},
|
||||
{Role: "assistant", Content: "<think>reason1</think>A1"},
|
||||
@@ -82,11 +145,11 @@ func TestLFM2Renderer(t *testing.T) {
|
||||
{Role: "assistant", Content: "<think>reason2</think>A2"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
|
||||
expected: "<|startoftext|><|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2",
|
||||
},
|
||||
{
|
||||
// With thinking enabled (keep_past_thinking=true), both preserved
|
||||
name: "multi-turn thinking: both preserved when thinking enabled",
|
||||
name: "thinking_preserves_past_assistant_when_enabled",
|
||||
renderer: &LFM2Renderer{IsThinking: true},
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Q1"},
|
||||
{Role: "assistant", Content: "<think>reason1</think>A1"},
|
||||
@@ -94,334 +157,137 @@ func TestLFM2Renderer(t *testing.T) {
|
||||
{Role: "assistant", Content: "<think>reason2</think>A2"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
|
||||
expected: "<|startoftext|><|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2",
|
||||
},
|
||||
{
|
||||
name: "assistant with tool calls",
|
||||
name: "arbitrary_roles_are_rendered_verbatim",
|
||||
renderer: &LFM2Renderer{IsThinking: false},
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "developer", Content: "Do X"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
|
||||
expected: "<|startoftext|><|im_start|>developer\nDo X<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "assistant with content and tool calls",
|
||||
name: "empty_messages_still_add_generation_prompt",
|
||||
renderer: &LFM2Renderer{IsThinking: false},
|
||||
messages: nil,
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|startoftext|><|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "assistant_prefill_no_generation_prompt",
|
||||
renderer: &LFM2Renderer{IsThinking: false},
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "Hello"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
|
||||
},
|
||||
{
|
||||
name: "tool response",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{Role: "assistant", Content: "Let me check."},
|
||||
{Role: "tool", Content: "22C, Sunny"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Get weather for Paris and London"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "London",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
|
||||
},
|
||||
{
|
||||
name: "tools definitions with system message",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
|
||||
},
|
||||
{
|
||||
name: "tools definitions without system message",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
|
||||
},
|
||||
{
|
||||
name: "multiple tools without system message",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_time",
|
||||
Description: "Get time",
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>system\nList of tools: [{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather\",\"parameters\":{\"type\":\"\",\"properties\":null}}}, {\"type\":\"function\",\"function\":{\"name\":\"get_time\",\"description\":\"Get time\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "user-tool sequence",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Check weather"},
|
||||
{Role: "tool", Content: "22C"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "full tool call cycle",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Check weather"},
|
||||
{Role: "assistant", Content: "Let me check"},
|
||||
{Role: "tool", Content: "22C"},
|
||||
{Role: "assistant", Content: "It's 22C"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "unicode content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "你好世界! مرحبا 🌍"},
|
||||
{Role: "assistant", Content: "Hello! 👋"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "newlines in content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Line 1\nLine 2\n\nLine 4"},
|
||||
{Role: "assistant", Content: "Response with\nmultiple\nlines"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "empty assistant content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: ""},
|
||||
{Role: "user", Content: "OK"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Generation prompt does NOT include <think> - model outputs it
|
||||
name: "generation prompt has no think tag",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Think hard"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Interleaved: thinking before tool call - last assistant preserves thinking
|
||||
name: "thinking before tool call (last assistant)",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "<think>I need to check the weather</think>",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>I need to check the weather</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Two assistants with tool calls - first has thinking stripped
|
||||
name: "two assistants with tools: first thinking stripped",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "<think>checking</think>",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "22C"},
|
||||
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Two assistants with tools - both preserved when thinking enabled
|
||||
name: "two assistants with tools: both preserved when thinking enabled",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "<think>checking</think>",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "22C"},
|
||||
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>checking</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Content before thinking before tool call
|
||||
name: "content then thinking then tool call",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check.<think>Using weather API</think>",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<think>Using weather API</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
|
||||
expected: "<|startoftext|><|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello",
|
||||
},
|
||||
}
|
||||
|
||||
renderer := &LFM2Renderer{IsThinking: true}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
|
||||
rendered, err := tt.renderer.Render(tt.messages, tt.tools, tt.thinkValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Render() error = %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
|
||||
t.Errorf("Render() mismatch (-want +got):\n%s", diff)
|
||||
t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLFM2Renderer_Images(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
renderer *LFM2Renderer
|
||||
message api.Message
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single_image_default_placeholder",
|
||||
renderer: &LFM2Renderer{},
|
||||
message: api.Message{
|
||||
Role: "user",
|
||||
Content: "Describe this image.",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
expected: "<|startoftext|><|im_start|>user\n<image>Describe this image.<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "multiple_images_default_placeholder",
|
||||
renderer: &LFM2Renderer{},
|
||||
message: api.Message{
|
||||
Role: "user",
|
||||
Content: "Describe these images.",
|
||||
Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
|
||||
},
|
||||
expected: "<|startoftext|><|im_start|>user\n<image><image>Describe these images.<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "single_image_img_tag_placeholder",
|
||||
renderer: &LFM2Renderer{useImgTags: true},
|
||||
message: api.Message{
|
||||
Role: "user",
|
||||
Content: "Describe this image.",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
expected: "<|startoftext|><|im_start|>user\n[img]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "existing_indexed_img_placeholder_not_duplicated",
|
||||
renderer: &LFM2Renderer{useImgTags: true},
|
||||
message: api.Message{
|
||||
Role: "user",
|
||||
Content: "[img-0]Describe this image.",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
expected: "<|startoftext|><|im_start|>user\n[img-0]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "existing_template_image_placeholder_not_duplicated",
|
||||
renderer: &LFM2Renderer{},
|
||||
message: api.Message{
|
||||
Role: "user",
|
||||
Content: "<image>Describe this image.",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
expected: "<|startoftext|><|im_start|>user\n<image>Describe this image.<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.renderer.Render([]api.Message{tt.message}, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Render() error = %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tt.expected, got); diff != "" {
|
||||
t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLFM2Renderer_JSONFormatting(t *testing.T) {
|
||||
tool := api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "echo",
|
||||
Description: "<html>",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := lfm2JSON(tool)
|
||||
want := "{\"type\": \"function\", \"function\": {\"name\": \"echo\", \"description\": \"<html>\", \"parameters\": {\"type\": \"object\", \"properties\": null}}}"
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Fatalf("lfm2JSON mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,9 +85,9 @@ func rendererForName(name string) Renderer {
|
||||
case "glm-ocr":
|
||||
return &GlmOcrRenderer{}
|
||||
case "lfm2":
|
||||
return &LFM2Renderer{IsThinking: false}
|
||||
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
|
||||
case "lfm2-thinking":
|
||||
return &LFM2Renderer{IsThinking: true}
|
||||
return &LFM2Renderer{IsThinking: true, useImgTags: RenderImgTags}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user