feat(model): add qwen3vl (#12665)

Michael Yang committed 2025-10-28 17:39:47 -07:00 (committed by GitHub)
parent 36d64fb531, commit 7d25b9e194
22 changed files with 1502 additions and 35 deletions

View File

@@ -198,6 +198,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &qwen2Model{}
case "Qwen2_5_VLForConditionalGeneration":
conv = &qwen25VLModel{}
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
conv = &qwen3VLModel{}
case "BertModel":
conv = &bertModel{}
case "CohereForCausalLM":

convert/convert_qwen3.go (new file, 157 lines added)
View File

@@ -0,0 +1,157 @@
package convert
import (
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
type qwen3Model struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
NumExperts uint32 `json:"num_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NormTopkProb bool `json:"norm_topk_prob"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
MropeSection []int32 `json:"mrope_section"`
} `json:"rope_scaling"`
RMSNormEPS float32 `json:"rms_norm_eps"`
}
// KV implements ModelConverter.
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
arch := "qwen3"
if q.NumExperts > 0 {
arch += "moe"
}
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = arch
kv["block_count"] = q.HiddenLayers
kv["context_length"] = q.MaxPositionEmbeddings
kv["embedding_length"] = q.HiddenSize
kv["feed_forward_length"] = q.IntermediateSize
kv["attention.head_count"] = q.NumAttentionHeads
kv["attention.head_count_kv"] = q.NumKeyValueHeads
kv["attention.key_length"] = q.HeadDim
kv["attention.value_length"] = q.HeadDim
if q.NumExperts > 0 {
kv["expert_count"] = q.NumExperts
kv["expert_used_count"] = q.NumExpertsPerToken
kv["norm_top_k_prob"] = q.NormTopkProb
}
kv["rope.freq_base"] = q.RopeTheta
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
switch q.RopeScaling.Type {
case "":
// no scaling
case "yarn":
kv["rope.scaling.type"] = q.RopeScaling.Type
kv["rope.scaling.factor"] = q.RopeScaling.Factor
case "mrope", "default":
kv["rope.mrope_section"] = q.RopeScaling.MropeSection
default:
panic("unknown rope scaling type")
}
return kv
}
// Tensors implements ModelConverter.
func (q *qwen3Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
// TODO: handle split experts
for _, t := range ts {
switch {
case strings.Contains(t.Name(), "ffn_gate_up_exps"):
afterFunc := func(t tensor.Tensor) (tensor.Tensor, error) { return tensor.Transpose(t, 0, 2, 1) }
for t := range splitDim(t, 2,
split{Replacer: strings.NewReplacer("gate_up", "gate"), afterFunc: afterFunc},
split{Replacer: strings.NewReplacer("gate_up", "up"), afterFunc: afterFunc},
) {
t.Shape[1], t.Shape[2] = t.Shape[2], t.Shape[1]
out = append(out, t)
}
case strings.Contains(t.Name(), "ffn_down_exps"):
shape := slices.Clone(t.Shape())
shape[1], shape[2] = shape[2], shape[1]
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
tt, err := tensor.Transpose(tt, 0, 2, 1)
if err != nil {
return nil, err
}
// flatten tensor so it can be written as a vector
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
})
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: shape,
WriterTo: t,
})
default:
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out
}
// Replacements implements ModelConverter.
func (q *qwen3Model) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.gate.weight", "ffn_gate_inp.weight",
"mlp.experts.down_proj", "ffn_down_exps.weight",
"mlp.experts.gate_up_proj", "ffn_gate_up_exps.weight",
"post_attention_layernorm", "ffn_norm",
"model.norm", "output_norm",
}
}
var _ ModelConverter = (*qwen3Model)(nil)
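For MoE variants, the fused ffn_gate_up_exps tensor stores each expert's gate and up projections back to back, so the converter splits it along that dimension (transposing each half via afterFunc) into the separate ffn_gate_exps and ffn_up_exps tensors the runtime expects. A minimal standalone sketch of the split on a flat float32 slice, with illustrative shapes and a hypothetical helper name, and without the transpose:

package main

import "fmt"

// splitGateUp splits a fused [experts][2*ff][hidden] tensor, flattened row-major
// into a single slice, into gate and up halves of shape [experts][ff][hidden].
func splitGateUp(fused []float32, experts, ff, hidden int) (gate, up []float32) {
    gate = make([]float32, experts*ff*hidden)
    up = make([]float32, experts*ff*hidden)
    for e := 0; e < experts; e++ {
        src := e * 2 * ff * hidden
        dst := e * ff * hidden
        copy(gate[dst:dst+ff*hidden], fused[src:src+ff*hidden])
        copy(up[dst:dst+ff*hidden], fused[src+ff*hidden:src+2*ff*hidden])
    }
    return gate, up
}

func main() {
    // 2 experts, ff=2, hidden=3: zeros mark gate rows, ones mark up rows.
    fused := []float32{
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, // expert 0: gate then up
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, // expert 1: gate then up
    }
    gate, up := splitGateUp(fused, 2, 2, 3)
    fmt.Println(gate) // all zeros
    fmt.Println(up)   // all ones
}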

convert/convert_qwen3vl.go (new file, 116 lines added)
View File

@@ -0,0 +1,116 @@
package convert
import (
"cmp"
"encoding/json"
"io/fs"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type qwen3VLModel struct {
qwen3Model `json:"text_config"`
VisionModel struct {
Depth uint32 `json:"depth"`
HiddenSize uint32 `json:"hidden_size"`
NumHeads uint32 `json:"num_heads"`
InChannels uint32 `json:"in_channels"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
WindowSize uint32 `json:"window_size"`
RMSNormEps float32 `json:"layer_norm_epsilon"`
RopeTheta float32 `json:"rope_theta"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
DeepstackVisualIndexes []int32 `json:"deepstack_visual_indexes"`
Size struct {
ShortestEdge uint32 `json:"shortest_edge"`
LongestEdge uint32 `json:"longest_edge"`
} `json:"size"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
} `json:"vision_config"`
}
func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "preprocessor_config.json")
if err != nil {
return err
}
return json.Unmarshal(bts, &m.VisionModel)
}
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
kv := m.qwen3Model.KV(t)
arch := "qwen3vl"
if m.NumExperts > 0 {
arch += "moe"
}
// override architecture
kv["general.architecture"] = arch
kv["vision.block_count"] = cmp.Or(m.VisionModel.Depth, 32)
kv["vision.embedding_length"] = m.VisionModel.HiddenSize
kv["vision.attention.head_count"] = cmp.Or(m.VisionModel.NumHeads, 16)
kv["vision.num_channels"] = m.VisionModel.InChannels
kv["vision.patch_size"] = cmp.Or(m.VisionModel.PatchSize, 14)
kv["vision.spatial_merge_size"] = cmp.Or(m.VisionModel.SpatialMergeSize, 2)
kv["vision.attention.layer_norm_epsilon"] = cmp.Or(m.VisionModel.RMSNormEps, 1e-6)
kv["vision.rope.freq_base"] = cmp.Or(m.VisionModel.RopeTheta, 1e4)
kv["vision.temporal_patch_size"] = cmp.Or(m.VisionModel.TemporalPatchSize, 2)
kv["vision.deepstack_visual_indexes"] = m.VisionModel.DeepstackVisualIndexes
kv["vision.shortest_edge"] = m.VisionModel.Size.ShortestEdge
kv["vision.longest_edge"] = m.VisionModel.Size.LongestEdge
kv["vision.image_mean"] = m.VisionModel.ImageMean
kv["vision.image_std"] = m.VisionModel.ImageStd
return kv
}
func (m *qwen3VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
var rest []Tensor
var out []*ggml.Tensor
for _, t := range ts {
switch {
case strings.Contains(t.Name(), "attn_qkv"):
out = append(out, slices.Collect(splitDim(t, 0,
split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")},
split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")},
split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")},
))...)
case strings.Contains(t.Name(), "patch_embed") && strings.HasSuffix(t.Name(), "weight"):
shape := t.Shape()
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: append([]uint64{shape[0] * shape[1]}, shape[2:]...),
WriterTo: t,
})
default:
rest = append(rest, t)
}
}
return append(m.qwen3Model.Tensors(rest), out...)
}
func (m *qwen3VLModel) Replacements() []string {
return append(
m.qwen3Model.Replacements(),
"model.language_", "",
"model.visual", "v",
"patch_embed.proj", "patch_embed",
"blocks", "blk",
"attn.qkv", "attn_qkv",
"attn.proj", "attn_out",
"deepstack_merger_list", "deepstack_merger",
)
}
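The vision tower ships a fused attn_qkv projection, which the converter splits along dimension 0 into separate attn_q, attn_k, and attn_v tensors; the patch_embed weight is likewise flattened by merging its two leading dimensions. A rough sketch of the row-wise QKV split on a plain slice (hypothetical helper and sizes, not the converter's splitDim API):

package main

import "fmt"

// splitQKV splits a fused projection of shape [3*dim][cols], flattened row-major,
// into Q, K and V matrices of shape [dim][cols] each.
func splitQKV(fused []float32, dim, cols int) (q, k, v []float32) {
    n := dim * cols
    q = append([]float32(nil), fused[:n]...)
    k = append([]float32(nil), fused[n:2*n]...)
    v = append([]float32(nil), fused[2*n:3*n]...)
    return q, k, v
}

func main() {
    dim, cols := 2, 3
    fused := make([]float32, 3*dim*cols)
    for i := range fused {
        fused[i] = float32(i)
    }
    q, k, v := splitQKV(fused, dim, cols)
    fmt.Println(q) // rows 0-1 of the fused matrix
    fmt.Println(k) // rows 2-3
    fmt.Println(v) // rows 4-5
}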

View File

@@ -19,8 +19,8 @@ type split struct {
dim int
slices []tensor.Slice
-// fn is an optional function to apply to the tensor after slicing
-fn func(tensor.Tensor) (tensor.Tensor, error)
+// afterFunc is an optional function to apply to the tensor after slicing
+afterFunc func(tensor.Tensor) (tensor.Tensor, error)
}
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
@@ -54,8 +54,8 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
tt = tensor.Materialize(tt)
-if split.fn != nil {
-tt, err = split.fn(tt)
+if split.afterFunc != nil {
+tt, err = split.afterFunc(tt)
if err != nil {
return nil, err
}

View File

@@ -432,7 +432,7 @@ func TestSplitDim(t *testing.T) {
t.Run("split with transpose", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 1,
split{Replacer: strings.NewReplacer("a", "x")},
split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
split{Replacer: strings.NewReplacer("b", "y"), afterFunc: func(tt tensor.Tensor) (tensor.Tensor, error) {
return tensor.Transpose(tt, 1, 0)
}},
))

View File

@@ -242,13 +242,13 @@ func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"gemma3",
"gemma3n",
"mistral3",
"qwen3",
"qwen3moe",
"gptoss", "gpt-oss",
"llama4",
"mistral3",
"mllama",
"qwen25vl",
"gptoss", "gpt-oss",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
}, kv.Architecture())
}

View File

@@ -26,6 +26,13 @@ func TestVisionModels(t *testing.T) {
{
model: "gemma3",
},
{
model: "qwen3-vl:8b",
},
{
// Qwen 3 VL mixture of experts
model: "qwen3-vl:30b",
},
}
for _, v := range testCases {

View File

@@ -161,6 +161,7 @@ type Tensor interface {
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

View File

@@ -1182,6 +1182,10 @@ func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
}
func (t *Tensor) Contiguous(ctx ml.Context, shape ...int) ml.Tensor {
if slices.Contains(shape, -1) {
inferShape(t, shape)
}
switch len(shape) {
case 0:
return &Tensor{
@@ -1324,7 +1328,43 @@ func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
// inferShape updates shape in place to automatically set a single -1 dimension
// based on the input tensor and the other dimensions
func inferShape(t *Tensor, shape []int) {
total := 1
for _, dim := range t.Shape() {
total *= dim
}
dim := -1
for i := range shape {
switch shape[i] {
case -1:
if dim != -1 {
panic("only one dimension can be inferred")
}
dim = i
case 0:
panic("dimension cannot be zero")
default:
if total%shape[i] != 0 {
panic("cannot infer dimension")
}
total /= shape[i]
}
}
if dim != -1 {
shape[dim] = total
}
}
func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
if slices.Contains(shape, -1) {
inferShape(t, shape)
}
switch len(shape) {
case 1:
return &Tensor{
@@ -1537,6 +1577,16 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
}
}
func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) ml.Tensor {
var tt ml.Tensor = &Tensor{
b: t.b,
t: C.ggml_conv_3d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int64_t(c), C.int(s0), C.int(s1), C.int(s2), C.int(p0), C.int(p1), C.int(p2), C.int(d0), C.int(d1), C.int(d2)),
}
tt = tt.Reshape(ctx, t.Dim(3)/c, t2.Dim(3)/c)
return tt
}
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
return &Tensor{
b: t.b,

View File

@@ -0,0 +1,126 @@
package ggml
import (
"errors"
"os"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
)
func setup(tb testing.TB) ml.Context {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "*.bin")
if err != nil {
tb.Fatal(err)
}
defer f.Close()
if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
tb.Fatal(err)
}
b, err := ml.NewBackend(f.Name(), ml.BackendParams{})
if err != nil {
tb.Fatal(err)
}
ctx := b.NewContext().Input()
tb.Cleanup(func() {
ctx.Close()
b.Close()
})
return ctx
}
func TestInferShape(t *testing.T) {
cases := []struct {
name string
input []int
want []int
err error
}{
{
name: "no inferred shape",
input: []int{2, 3, 4},
want: []int{2, 3, 4},
},
{
name: "infer begin",
input: []int{-1, 3, 4},
want: []int{2, 3, 4},
},
{
name: "infer mid",
input: []int{2, -1, 4},
want: []int{2, 3, 4},
},
{
name: "infer end",
input: []int{2, 3, -1},
want: []int{2, 3, 4},
},
{
name: "too many inferred dims",
input: []int{-1, 3, -1},
err: errors.New("only one dimension can be inferred"),
},
{
name: "infer gather",
input: []int{2, -1},
want: []int{2, 12},
},
{
name: "infer gather all",
input: []int{-1},
want: []int{24},
},
{
name: "infer split",
input: []int{2, -1, 3, 2},
want: []int{2, 2, 3, 2},
},
{
name: "indivisible infer",
input: []int{2, -1, 2, 4},
err: errors.New("cannot infer dimension"),
},
{
name: "infer zero dim",
input: []int{2, 0, 4},
err: errors.New("dimension cannot be zero"),
},
}
ctx := setup(t)
tensor, ok := ctx.Empty(ml.DTypeF32, 2, 3, 4).(*Tensor)
if !ok {
t.Fatal("expected *Tensor")
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
defer func() {
if r := recover(); r == nil && tt.err == nil {
// all good
} else if r != nil && tt.err == nil {
t.Errorf("unexpected panic: %v", r)
} else if r == nil && tt.err != nil {
t.Errorf("expected panic but did not get one: %v", tt.err)
} else if errStr, ok := r.(string); ok && errStr != tt.err.Error() {
t.Errorf("expected panic %q but got %q", tt.err.Error(), errStr)
}
}()
inferShape(tensor, tt.input)
if diff := cmp.Diff(tt.want, tt.input); diff != "" {
t.Errorf("%s: shape mismatch (-want +got):\n%s", tt.name, diff)
}
})
}
}
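The rule the test exercises is simple: the single -1 entry becomes the tensor's total element count divided by the product of the known dimensions, and the call panics if more than one dimension is unknown, a dimension is zero, or the division is not exact. A standalone sketch of the same arithmetic that returns errors instead of panicking (hypothetical helper, not the backend function):

package main

import "fmt"

// inferDim fills in a single -1 entry in shape so the product of the dimensions
// equals total, or reports why that is impossible.
func inferDim(total int, shape []int) ([]int, error) {
    known, unknown := 1, -1
    for i, d := range shape {
        switch {
        case d == -1 && unknown != -1:
            return nil, fmt.Errorf("only one dimension can be inferred")
        case d == -1:
            unknown = i
        case d <= 0:
            return nil, fmt.Errorf("dimension cannot be zero")
        default:
            known *= d
        }
    }
    if unknown >= 0 {
        if total%known != 0 {
            return nil, fmt.Errorf("cannot infer dimension")
        }
        shape[unknown] = total / known
    }
    return shape, nil
}

func main() {
    shape, err := inferDim(2*3*4, []int{2, -1, 4})
    fmt.Println(shape, err) // [2 3 4] <nil>
}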

View File

@@ -4,8 +4,26 @@ import "github.com/ollama/ollama/ml"
type Conv2D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
-return m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1)
+t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1)
+if m.Bias != nil {
+t = t.Add(ctx, m.Bias)
+}
+return t
}
type Conv3D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) ml.Tensor {
t = m.Weight.Conv3D(ctx, t, c, s0, s1, s2, p0, p1, p2, d0, d1, d2)
if m.Bias != nil {
t = t.Add(ctx, m.Bias)
}
return t
}
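Both layers treat the bias as optional: when the GGUF file carries no bias tensor the field is left nil and the add is skipped. A tiny sketch of that nil-guard pattern on scalars, purely illustrative:

package main

import "fmt"

// forward applies a weight and, only if present, a bias, mirroring how the
// convolution layers skip the Add when no bias tensor was loaded.
func forward(x, w float32, bias *float32) float32 {
    y := x * w
    if bias != nil {
        y += *bias
    }
    return y
}

func main() {
    b := float32(0.5)
    fmt.Println(forward(2, 3, nil)) // 6
    fmt.Println(forward(2, 3, &b))  // 6.5
}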

View File

@@ -14,4 +14,5 @@ import (
_ "github.com/ollama/ollama/model/models/qwen2"
_ "github.com/ollama/ollama/model/models/qwen25vl"
_ "github.com/ollama/ollama/model/models/qwen3"
_ "github.com/ollama/ollama/model/models/qwen3vl"
)

View File

@@ -3,6 +3,7 @@ package qwen3
import (
"cmp"
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -210,7 +211,7 @@ var _ model.Model = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
layers := make([]Layer, c.Uint("block_count"))
for i := range layers {
if c.String("general.architecture") == "qwen3moe" {
if strings.HasSuffix(c.String("general.architecture"), "moe") {
layers[i].MLP = &sparse{}
} else {
layers[i].MLP = &dense{}
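Switching from an exact "qwen3moe" comparison to a suffix check picks the sparse MLP for any architecture name ending in "moe"; the qwen3vl text model below uses the same test for qwen3vlmoe. A quick illustration of which names the check matches (architecture strings taken from this commit):

package main

import (
    "fmt"
    "strings"
)

func main() {
    // Architectures ending in "moe" get the sparse (mixture-of-experts) MLP.
    for _, arch := range []string{"qwen3", "qwen3moe", "qwen3vl", "qwen3vlmoe"} {
        fmt.Printf("%-11s sparse=%v\n", arch, strings.HasSuffix(arch, "moe"))
    }
}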

View File

@@ -0,0 +1,194 @@
package qwen3vl
import (
"fmt"
"image"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)
// ImageProcessor contains configuration for the Qwen 3 VL image processing
type ImageProcessor struct {
numChannels int
patchSize int
temporalPatchSize int
mergeSize int
shortestEdge int
longestEdge int
factor int
rescaleFactor float32
imageMean []float32
imageStd []float32
}
// newImageProcessor creates a new image processor with default values
func newImageProcessor(c fs.Config) ImageProcessor {
patchSize := int(c.Uint("vision.patch_size", 14))
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))
return ImageProcessor{
numChannels: int(c.Uint("vision.num_channels", 3)), // not set
patchSize: patchSize,
temporalPatchSize: 2,
mergeSize: mergeSize,
shortestEdge: int(c.Uint("vision.shortest_edge", 64<<10)),
// FIXME(mxyng): the model defined longest edge (16M) is too large for the default
// context length of 8K and will panic. Adjusting to 2M for now.
// longestEdge: int(c.Uint("vision.longest_edge", 16<<20)),
longestEdge: 2 << 20,
factor: patchSize * mergeSize,
rescaleFactor: 1.0 / 255.0,
imageMean: c.Floats("vision.image_mean", imageproc.ImageNetStandardMean[:]),
imageStd: c.Floats("vision.image_std", imageproc.ImageNetStandardSTD[:]),
}
}
// SmartResize implements the smart resize algorithm
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
factor := p.factor
if height < factor || width < factor {
panic(fmt.Sprintf("height:%d or width:%d must be larger than factor:%d", height, width, factor))
} else if aspectRatio := max(height, width) / min(height, width); aspectRatio > 200 {
panic(fmt.Sprintf("absolute aspect ratio must be smaller than 200, got %v", aspectRatio))
}
round := func(x float64) int { return int(math.RoundToEven(x)) }
hBar := round(float64(height)/float64(factor)) * factor
wBar := round(float64(width)/float64(factor)) * factor
if hBar*wBar > p.longestEdge {
beta := math.Sqrt(float64(height*width) / float64(p.longestEdge))
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
} else if hBar*wBar < p.shortestEdge {
beta := math.Sqrt(float64(p.shortestEdge) / float64(height*width))
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
}
return hBar, wBar
}
type Grid struct {
Height int
Width int
Temporal int
}
func (p *ImageProcessor) ProcessImage(ctx ml.Context, img image.Image) (ml.Tensor, *Grid, error) {
origWidth := img.Bounds().Dx()
origHeight := img.Bounds().Dy()
// Calculate smart resize dimensions
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
// Resize image using existing functions
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
normalizedPixels := imageproc.Normalize(
resizedImg,
[3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]},
[3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]},
true, // rescale
true, // channelFirst
)
// Calculate grid dimensions
grid := &Grid{
Height: resizedHeight / p.patchSize,
Width: resizedWidth / p.patchSize,
Temporal: 1, // For single images, temporal dimension is 1
}
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
if err != nil {
return nil, nil, fmt.Errorf("failed to create patches: %v", err)
}
patchDim := p.numChannels * p.temporalPatchSize *
p.patchSize * p.patchSize
numPatches := grid.Temporal * grid.Height * grid.Width
pixelValues := ctx.Input().FromFloats(patches, patchDim, numPatches)
// Return patches and grid dimensions
return pixelValues, grid, nil
}
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
channels := p.numChannels
patchSize := p.patchSize
mergeSize := p.mergeSize
temporalPatchSize := p.temporalPatchSize
// Calculate output dimensions
numPatches := grid.Temporal * grid.Height * grid.Width
patchDim := channels * temporalPatchSize * patchSize * patchSize
result := make([]float32, numPatches*patchDim)
patchIndex := 0
// Single temporal frame handling (copies to all frames)
for range grid.Temporal {
for h := 0; h < grid.Height; h += mergeSize {
for w := 0; w < grid.Width; w += mergeSize {
// Handle the 2x2 merged patches
for mh := range mergeSize {
for mw := range mergeSize {
baseOffset := patchIndex * patchDim
// Extract patch data for first temporal frame
for c := range channels {
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
for py := range patchSize {
for px := range patchSize {
// Calculate source pixel coordinates
y := (h+mh)*patchSize + py
x := (w+mw)*patchSize + px
// Source index in input tensor (CHW format)
srcIdx := c*height*width + y*width + x
// Destination index in first temporal frame
dstIdx := channelOffset + (py * patchSize) + px
if srcIdx < len(pixels) && dstIdx < len(result) {
result[dstIdx] = pixels[srcIdx]
}
}
}
}
// Copy first temporal frame to all other frames
if temporalPatchSize > 1 {
for c := range channels {
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
firstFrameOffset := channelOffset
frameSize := patchSize * patchSize
// Copy first frame to all other frames
for tp := 1; tp < temporalPatchSize; tp++ {
currentFrameOffset := channelOffset + (tp * frameSize)
copy(result[currentFrameOffset:currentFrameOffset+frameSize],
result[firstFrameOffset:firstFrameOffset+frameSize])
}
}
}
patchIndex++
}
}
}
}
}
return result, nil
}
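SmartResize snaps the image to multiples of factor (patch size times spatial merge size) and then rescales so the total pixel count stays between shortestEdge and longestEdge. A standalone sketch of the same rounding rule using the defaults above (14-pixel patches, 2x2 merge, the adjusted 2M longest edge); the helper name and sample size are illustrative:

package main

import (
    "fmt"
    "math"
)

// smartResize rounds height and width to multiples of factor while keeping
// height*width between shortest and longest, mirroring the processor's rule.
func smartResize(height, width, factor, shortest, longest int) (int, int) {
    round := func(x float64) int { return int(math.RoundToEven(x)) }
    h := round(float64(height)/float64(factor)) * factor
    w := round(float64(width)/float64(factor)) * factor
    if h*w > longest {
        beta := math.Sqrt(float64(height*width) / float64(longest))
        h = int(math.Floor(float64(height)/beta/float64(factor))) * factor
        w = int(math.Floor(float64(width)/beta/float64(factor))) * factor
    } else if h*w < shortest {
        beta := math.Sqrt(float64(shortest) / float64(height*width))
        h = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
        w = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
    }
    return h, w
}

func main() {
    factor := 14 * 2 // patchSize * mergeSize
    h, w := smartResize(1080, 1920, factor, 64<<10, 2<<20)
    fmt.Println(h, w, "->", h/14, "x", w/14, "patches")
}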

View File

@@ -0,0 +1,204 @@
package qwen3vl
import (
"bytes"
"image"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.TextProcessor
*TextModel
*VisionModel `gguf:"v"`
ImageProcessor
positionCache []int32
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
pixelValues, grid, err := m.ProcessImage(ctx, img)
if err != nil {
return nil, err
}
// Calculate tensor dimensions
visionOutputs, deepstackVisualEmbeds := m.VisionModel.Forward(ctx, pixelValues, grid)
mm := []input.Multimodal{{Tensor: visionOutputs, Data: grid}}
for i := range deepstackVisualEmbeds {
mm = append(mm, input.Multimodal{Tensor: deepstackVisualEmbeds[i]})
}
return mm, nil
}
var (
tokenVision int32 = 151655
tokenVisionStart int32 = 151652
tokenVisionEnd int32 = 151653
)
type modelInput struct {
*input.Input
position int32
}
// PostTokenize arranges Qwen 3 VL's inputs for the forward pass
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
m.positionCache = m.positionCache[:0]
return slices.Collect(func(yield func(*input.Input) bool) {
for i := range inputs {
s := []modelInput{{Input: inputs[i]}}
if mm := inputs[i].Multimodal; mm != nil {
t := mm[0].Tensor
s = slices.Repeat([]modelInput{
{
position: int32(i + 1),
Input: &input.Input{Token: tokenVision},
},
}, t.Dim(1)+1+1)
s[0] = modelInput{
Input: &input.Input{Token: tokenVisionStart},
position: int32(i),
}
s[len(s)-1] = modelInput{
Input: &input.Input{Token: tokenVisionEnd},
position: int32(i + mm[0].Data.(*Grid).Width/m.spatialMergeSize + 1),
}
s[1] = modelInput{
Input: &input.Input{
Token: tokenVision,
Multimodal: inputs[i].Multimodal,
MultimodalHash: inputs[i].MultimodalHash,
SameBatch: t.Dim(1),
},
position: int32(i + 1),
}
}
for _, e := range s {
position := e.position
if position == 0 && len(m.positionCache) > 0 {
position = m.positionCache[len(m.positionCache)-1] + 1
}
m.positionCache = append(m.positionCache, position)
if !yield(e.Input) {
return
}
}
}
}), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positionSlice := slices.Collect(makeSlice2D[int32](3, len(batch.Positions)))
for i, id := range batch.Positions {
if id < int32(len(m.positionCache)) {
id = m.positionCache[id]
} else if len(m.positionCache) > 0 {
id = id - int32(len(m.positionCache)) + m.positionCache[len(m.positionCache)-1] + 1
}
positionSlice[0][i] = id
positionSlice[1][i] = id
positionSlice[2][i] = id
}
hiddenStates := m.TextModel.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
var deepstackVisualEmbeds []ml.Tensor
for _, mi := range batch.Multimodal {
visionOutputs := mi.Multimodal[0].Tensor
ctx.Forward(visionOutputs.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
for i := range visionOutputs.Dim(1) {
w := grid.Width / m.spatialMergeSize
positionSlice[1][mi.Index+i] += int32(i / w)
positionSlice[2][mi.Index+i] += int32(i % w)
}
}
deepstackVisualEmbeds = make([]ml.Tensor, len(mi.Multimodal[1:]))
for i, mm := range mi.Multimodal[1:] {
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1))))
}
}
positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0]), len(positionSlice))
cos, sin := m.rotaryEmbedding(ctx, positions)
for i, layer := range m.TextModel.Layers {
if m.Cache != nil {
m.Cache.SetLayer(i)
}
var outputs ml.Tensor
if i == len(m.TextModel.Layers)-1 {
outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, outputs, m.Cache, m.Options)
if i < len(deepstackVisualEmbeds) {
hiddenStates = hiddenStates.Add(ctx, deepstackVisualEmbeds[i])
}
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, 1e-06)
return m.Output.Forward(ctx, hiddenStates), nil
}
func New(c fs.Config) (model.Model, error) {
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
TextModel: newTextModel(c),
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
}
m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, position ml.Tensor) (ml.Tensor, error) {
m.positionCache = nil
return nil, kvcache.ErrNotSupported
})
return &m, nil
}
func init() {
model.Register("qwen3vl", New)
model.Register("qwen3vlmoe", New)
}
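PostTokenize expands every image into a vision-start token, one placeholder vision token per merged patch (the embedding rides on the first of them via SameBatch), and a vision-end token, while positionCache records the positions later fed to M-RoPE. A minimal sketch of just that token expansion, reusing the token IDs defined above; the patch count in main is illustrative:

package main

import "fmt"

const (
    tokenVision      int32 = 151655
    tokenVisionStart int32 = 151652
    tokenVisionEnd   int32 = 151653
)

// expandImage returns the token sequence one image occupies after
// post-tokenization: start marker, one vision token per merged patch, end marker.
func expandImage(numPatches int) []int32 {
    out := make([]int32, 0, numPatches+2)
    out = append(out, tokenVisionStart)
    for range numPatches {
        out = append(out, tokenVision)
    }
    return append(out, tokenVisionEnd)
}

func main() {
    seq := expandImage(14 * 14) // e.g. a 28x28-patch image after 2x2 spatial merge
    fmt.Println(len(seq), seq[0], seq[1], seq[len(seq)-1])
}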

View File

@@ -0,0 +1,229 @@
package qwen3vl
import (
"cmp"
"math"
"slices"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
)
type TextOptions struct {
hiddenSize,
numHeads,
numKVHeads,
keyLength,
valueLength int
eps,
ropeBase,
ropeScale float32
mropeSections []int
numExperts, numExpertsUsed int
normTopKProb bool
inverseFrequenciesCache []float32
}
func (o TextOptions) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
type TextAttention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenStates.Dim(1)
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return sa.Output.Forward(ctx, attention)
}
type TextMLP interface {
Forward(ml.Context, ml.Tensor, *TextOptions) ml.Tensor
}
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
}
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
routerLogits := mlp.Router.Forward(ctx, hiddenStates)
routingWeights := routerLogits.Softmax(ctx)
selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, selectedExperts)
if opts.normTopKProb {
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts))
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
return nextStates
}
type dense struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
*TextAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
TextMLP
}
func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.TextAttention.Forward(ctx, hiddenStates, cos, sin, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = d.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.TextMLP.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Layers []TextLayer `gguf:"blk"`
Options *TextOptions
}
func (m *TextModel) rotaryEmbedding(ctx ml.Context, positions ml.Tensor) (_, _ ml.Tensor) {
positions = positions.Reshape(ctx, 1, positions.Dim(0), positions.Dim(1))
if len(m.Options.inverseFrequenciesCache) == 0 {
m.Options.inverseFrequenciesCache = make([]float32, m.Options.headDim()/2)
for i := range m.Options.inverseFrequenciesCache {
frequency := float32(math.Pow(float64(m.Options.ropeBase), float64(i*2)/float64(m.Options.headDim())))
m.Options.inverseFrequenciesCache[i] = 1 / frequency
}
}
inverseFrequencies := ctx.Input().FromFloats(m.Options.inverseFrequenciesCache, 1, len(m.Options.inverseFrequenciesCache))
positions = positions.Cast(ctx, ml.DTypeF32)
frequencies := inverseFrequencies.Mulmat(ctx, positions)
interleaved := frequencies.View(ctx,
0, frequencies.Dim(0),
frequencies.Stride(1), frequencies.Dim(1),
)
for _, i := range []int{1, 2} {
args := []int{
i * frequencies.Stride(0), 1,
3 * frequencies.Stride(0), m.Options.mropeSections[i],
frequencies.Stride(1), frequencies.Dim(1),
}
ctx.Forward(frequencies.View(ctx, i*frequencies.Stride(2)+args[0], args[1:]...).
Copy(ctx, interleaved.View(ctx, args[0], args[1:]...)))
}
interleaved = interleaved.Concat(ctx, interleaved, 0)
interleaved = interleaved.Reshape(ctx, interleaved.Dim(0), 1, interleaved.Dim(1), interleaved.Dim(2))
return interleaved.Cos(ctx), interleaved.Sin(ctx)
}
var _ model.Model = (*Model)(nil)
func newTextModel(c fs.Config) *TextModel {
layers := make([]TextLayer, c.Uint("block_count"))
for i := range layers {
if strings.HasSuffix(c.String("general.architecture"), "moe") {
layers[i].TextMLP = &sparse{}
} else {
layers[i].TextMLP = &dense{}
}
}
m := TextModel{
Layers: layers,
Options: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("norm_top_k_prob", true),
mropeSections: slices.Collect(func(yield func(int) bool) {
for _, section := range c.Ints("mrope_sections", []int32{24, 20, 20}) {
if !yield(int(section)) {
return
}
}
}),
},
}
return &m
}
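The sparse MLP softmaxes the router logits, keeps the top numExpertsUsed experts, optionally renormalizes their weights when norm_top_k_prob is set, and sums the weighted expert outputs. A scalar sketch of that routing for a single token with pretend per-expert outputs (plain Go, not the tensor ops above):

package main

import (
    "fmt"
    "math"
    "sort"
)

// route softmaxes the router logits, selects the topK experts, optionally
// renormalizes their weights, and returns the weighted sum of expert outputs.
func route(logits, expertOut []float64, topK int, normTopK bool) float64 {
    weights := make([]float64, len(logits))
    var sum float64
    for i, l := range logits {
        weights[i] = math.Exp(l)
        sum += weights[i]
    }
    for i := range weights {
        weights[i] /= sum
    }

    idx := make([]int, len(weights))
    for i := range idx {
        idx[i] = i
    }
    sort.Slice(idx, func(a, b int) bool { return weights[idx[a]] > weights[idx[b]] })
    idx = idx[:topK]

    if normTopK {
        var s float64
        for _, i := range idx {
            s += weights[i]
        }
        for _, i := range idx {
            weights[i] /= s
        }
    }

    var out float64
    for _, i := range idx {
        out += weights[i] * expertOut[i]
    }
    return out
}

func main() {
    logits := []float64{2.0, 0.1, 1.5, -1.0}
    outputs := []float64{10, 20, 30, 40} // pretend per-expert outputs for one token
    fmt.Println(route(logits, outputs, 2, true))
}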

View File

@@ -0,0 +1,268 @@
package qwen3vl
import (
"iter"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type VisionAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_out"`
}
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
}
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, query.Dim(1))
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
key := sa.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numHeads, key.Dim(1))
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
value := sa.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))
attention := nn.Attention(ctx, query, key, value, math.Pow(float64(opts.headDim()), -0.5), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"linear_fc1"`
FC2 *nn.Linear `gguf:"linear_fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts VisionOptions) ml.Tensor {
return mlp.FC2.Forward(ctx, mlp.FC1.Forward(ctx, hiddenStates).GELU(ctx))
}
type VisionEncoderLayer struct {
Norm1 *nn.LayerNorm `gguf:"norm1"`
Attention *VisionAttention
Norm2 *nn.LayerNorm `gguf:"norm2"`
MLP *VisionMLP `gguf:"mlp"`
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.Attention.Forward(ctx, hiddenStates, cos, sin, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = e.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type VisionOptions struct {
hiddenSize,
numHeads,
patchSize,
numChannels,
spatialMergeSize,
temporalPatchSize,
gridPerSide int
eps,
ropeTheta float32
deepstackVisualIndexes []int32
mropeSections []int
}
func (o VisionOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
type VisionPatchMerger struct {
Norm *nn.LayerNorm `gguf:"norm"`
FC1 *nn.Linear `gguf:"linear_fc1"`
FC2 *nn.Linear `gguf:"linear_fc2"`
}
func (m *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, postshuffleNorm bool, opts VisionOptions) ml.Tensor {
hiddenSize := opts.hiddenSize * opts.spatialMergeSize * opts.spatialMergeSize
if postshuffleNorm {
visionOutputs = visionOutputs.Reshape(ctx, hiddenSize, -1)
}
visionOutputs = m.Norm.Forward(ctx, visionOutputs, opts.eps)
visionOutputs = visionOutputs.Reshape(ctx, hiddenSize, -1)
return m.FC2.Forward(ctx, m.FC1.Forward(ctx, visionOutputs).GELU(ctx))
}
type VisionPositionEmbedding struct {
PositionEmbedding *nn.Embedding `gguf:"pos_embed"`
}
func makeSlice2D[T int32 | float32](n0, n1 int) iter.Seq[[]T] {
return func(yield func([]T) bool) {
for range n0 {
if !yield(make([]T, n1)) {
return
}
}
}
}
func (m *VisionPositionEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor, grid *Grid, opts VisionOptions) ml.Tensor {
indexSlice := slices.Collect(makeSlice2D[int32](4, grid.Height*grid.Width))
weightSlice := slices.Collect(makeSlice2D[float32](4, grid.Height*grid.Width))
stepHeight := float32(opts.gridPerSide-1) / float32(grid.Height-1)
stepWidth := float32(opts.gridPerSide-1) / float32(grid.Width-1)
var i int
for h := range grid.Height {
for w := range grid.Width {
y, x := float32(h)*stepHeight, float32(w)*stepWidth
floorY, floorX := int32(y), int32(x)
ceilY, ceilX := min(floorY+1, int32(opts.gridPerSide-1)), min(floorX+1, int32(opts.gridPerSide-1))
indexSlice[0][i] = floorY*int32(opts.gridPerSide) + floorX
indexSlice[1][i] = floorY*int32(opts.gridPerSide) + ceilX
indexSlice[2][i] = ceilY*int32(opts.gridPerSide) + floorX
indexSlice[3][i] = ceilY*int32(opts.gridPerSide) + ceilX
weightSlice[0][i] = (1 - (y - float32(floorY))) * (1 - (x - float32(floorX)))
weightSlice[1][i] = (1 - (y - float32(floorY))) * (x - float32(floorX))
weightSlice[2][i] = (y - float32(floorY)) * (1 - (x - float32(floorX)))
weightSlice[3][i] = (y - float32(floorY)) * (x - float32(floorX))
i++
}
}
indices := ctx.Input().FromInts(slices.Concat(indexSlice...), grid.Height*grid.Width*4)
weights := ctx.Input().FromFloats(slices.Concat(weightSlice...), 1, grid.Height*grid.Width*4)
n := hiddenStates.Dim(0)
positionEmbeds := m.PositionEmbedding.Forward(ctx, indices)
positionEmbeds = positionEmbeds.Mul(ctx, weights)
positionEmbeds = positionEmbeds.Reshape(ctx, n, -1, 4)
positionEmbeds = positionEmbeds.View(ctx, 0, n, positionEmbeds.Stride(1), grid.Height*grid.Width).
Add(ctx, positionEmbeds.View(ctx, 1*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
Add(ctx, positionEmbeds.View(ctx, 2*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
Add(ctx, positionEmbeds.View(ctx, 3*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width))
positionEmbeds = positionEmbeds.Reshape(ctx, -1, grid.Width/opts.spatialMergeSize, opts.spatialMergeSize, grid.Height/opts.spatialMergeSize)
positionEmbeds = positionEmbeds.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, n, -1)
return hiddenStates.Add(ctx, positionEmbeds)
}
type VisionModel struct {
PatchEmbedding *nn.Conv3D `gguf:"patch_embed"`
PositionEmbedding *VisionPositionEmbedding
Layers []VisionEncoderLayer `gguf:"blk"`
PatchMerger *VisionPatchMerger `gguf:"merger"`
DeepstackMerger []*VisionPatchMerger `gguf:"deepstack_merger"`
VisionOptions
}
func (m *VisionModel) positions(ctx ml.Context, grid *Grid) (_, _ ml.Tensor) {
indices := ctx.Input().FromInts(slices.Collect(func(yield func(int32) bool) {
for y := range grid.Height {
for x := range grid.Width {
if !yield(int32(y)) {
return
}
if !yield(int32(x)) {
return
}
}
}
}), grid.Width*grid.Height*2)
indices = indices.Reshape(ctx, -1, grid.Width/m.spatialMergeSize, m.spatialMergeSize, grid.Height/m.spatialMergeSize)
indices = indices.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
indices = indices.Reshape(ctx, -1)
halfDim := m.headDim() / 2
maxGrid := max(grid.Height, grid.Width)
frequencies := ctx.Input().FromFloats(slices.Collect(func(yield func(float32) bool) {
ropeTheta := float64(m.ropeTheta)
for i := range maxGrid {
for j := range halfDim / 2 {
if !yield(float32(i) / float32(math.Pow(ropeTheta, float64(j*2)/float64(halfDim)))) {
return
}
}
}
}), halfDim/2, maxGrid)
embeds := frequencies.Rows(ctx, indices)
embeds = embeds.Reshape(ctx, halfDim, 1, -1)
embeds = embeds.Concat(ctx, embeds, 0)
return embeds.Cos(ctx), embeds.Sin(ctx)
}
// Forward computes the vision model for an input tensor
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) (ml.Tensor, []ml.Tensor) {
pixelValues = pixelValues.Reshape(ctx, m.patchSize, m.patchSize, m.temporalPatchSize, -1)
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.numChannels, m.patchSize, m.patchSize, m.temporalPatchSize, 0, 0, 0, 1, 1, 1)
hiddenStates = m.PositionEmbedding.Forward(ctx, hiddenStates, grid, m.VisionOptions)
cos, sin := m.positions(ctx, grid)
deepstackStates := make([]ml.Tensor, len(m.deepstackVisualIndexes))
for i, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionOptions)
if i := slices.Index(m.deepstackVisualIndexes, int32(i)); i >= 0 {
deepstackStates[i] = m.DeepstackMerger[i].Forward(ctx, hiddenStates, true, m.VisionOptions)
}
}
hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, false, m.VisionOptions)
return hiddenStates, deepstackStates
}
// newVisionModel creates a new instance of the Qwen vision model
func newVisionModel(c fs.Config) *VisionModel {
deepstackVisualIndexes := c.Ints("vision.deepstack_visual_indexes")
model := &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
DeepstackMerger: make([]*VisionPatchMerger, len(deepstackVisualIndexes)),
VisionOptions: VisionOptions{
hiddenSize: int(c.Uint("vision.embedding_length", 1280)),
numHeads: int(c.Uint("vision.attention.head_count", 16)),
patchSize: int(c.Uint("vision.patch_size", 14)),
numChannels: int(c.Uint("vision.num_channels", 3)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
ropeTheta: c.Float("vision.rope.freq_base", 10000.0),
spatialMergeSize: int(c.Uint("vision.spatial_merge_size", 2)),
temporalPatchSize: int(c.Uint("vision.temporal_patch_size", 2)),
gridPerSide: int(math.Sqrt(float64(c.Uint("vision.num_positional_embeddings", 2304)))),
mropeSections: slices.Collect(func(yield func(int) bool) {
for _, section := range c.Ints("mrope_sections", []int32{24, 20, 20}) {
if !yield(int(section)) {
return
}
}
}),
deepstackVisualIndexes: deepstackVisualIndexes,
},
}
return model
}
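VisionPositionEmbedding interpolates a gridPerSide by gridPerSide table of learned position embeddings onto the actual patch grid: each patch blends its four nearest table entries with bilinear weights. A small sketch of computing those indices and weights for one fractional grid position (hypothetical helper, illustrative grid size):

package main

import "fmt"

// bilinear returns the four neighbouring indices on a side x side grid and the
// weights used to blend their embeddings for fractional position (y, x).
func bilinear(y, x float32, side int) (idx [4]int, w [4]float32) {
    fy, fx := int(y), int(x)
    cy, cx := min(fy+1, side-1), min(fx+1, side-1)
    dy, dx := y-float32(fy), x-float32(fx)

    idx = [4]int{fy*side + fx, fy*side + cx, cy*side + fx, cy*side + cx}
    w = [4]float32{(1 - dy) * (1 - dx), (1 - dy) * dx, dy * (1 - dx), dy * dx}
    return idx, w
}

func main() {
    idx, w := bilinear(1.25, 2.5, 48) // 48x48 table of learned positions
    fmt.Println(idx, w)               // the four weights always sum to 1
}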

View File

@@ -235,15 +235,28 @@ func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
return count
}
-// TODO(jessegross): If we need to reprocess the inputs we should ensure that
-// we don't split up a SameBatch
-func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
-targetFree := (c.numCtx - numKeep) / 2
-targetFree = max(targetFree, 1)
-currentFree := c.numCtx - inputLen
-return max(targetFree-currentFree, 0)
+// ShiftDiscard computes how many inputs can be discarded from the cache. Inputs in the same batch
+// are discarded together.
+func (c *InputCache) ShiftDiscard(inputs []*input.Input, numKeep int32) int32 {
+targetFree := max((c.numCtx-numKeep)/2, 1)
+currentFree := c.numCtx - int32(len(inputs))
+var discard, sameBatch int32
+for _, input := range inputs[numKeep:] {
+if sameBatch <= 0 && currentFree >= targetFree {
+break
+}
+sameBatch--
+currentFree++
+discard++
+if input.SameBatch > 0 {
+sameBatch = int32(input.SameBatch)
+}
+}
+return discard
}
type ErrReprocessInputs struct {
@@ -264,7 +277,7 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
}
inputLen := int32(len(slot.Inputs))
-discard := c.ShiftDiscard(inputLen, numKeep)
+discard := c.ShiftDiscard(slot.Inputs, numKeep)
if discard <= 0 {
return nil

View File

@@ -3,6 +3,7 @@ package ollamarunner
import (
"errors"
"fmt"
"slices"
"testing"
"time"
@@ -238,59 +239,137 @@ func TestShiftDiscard(t *testing.T) {
name string
numCtx int32
numKeep int32
-inputLen int32
+inputs []*input.Input
expected int32
}{
{
name: "Shift",
numCtx: 2048,
numKeep: 5,
-inputLen: 2048,
+inputs: slices.Repeat([]*input.Input{{}}, 2048),
expected: 1021,
},
{
name: "Max Keep",
numCtx: 2048,
numKeep: 2047,
-inputLen: 2048,
+inputs: slices.Repeat([]*input.Input{{}}, 2048),
expected: 1,
},
{
name: "No Keep",
numCtx: 2048,
numKeep: 0,
-inputLen: 2048,
+inputs: slices.Repeat([]*input.Input{{}}, 2048),
expected: 1024,
},
{
name: "Truncate",
numCtx: 2048,
numKeep: 5,
-inputLen: 5000,
+inputs: slices.Repeat([]*input.Input{{}}, 5000),
expected: 3973,
},
{
name: "Truncate Keep",
numCtx: 2048,
numKeep: 2047,
-inputLen: 5000,
+inputs: slices.Repeat([]*input.Input{{}}, 5000),
expected: 2953,
},
{
name: "No Op",
numCtx: 2048,
numKeep: 5,
-inputLen: 512,
+inputs: slices.Repeat([]*input.Input{{}}, 512),
expected: 0,
},
{
name: "Same Batch",
numCtx: 2048,
numKeep: 5,
inputs: slices.Collect(func(yield func(*input.Input) bool) {
for range 1024 {
if !yield(&input.Input{}) {
return
}
}
if !yield(&input.Input{SameBatch: 512 - 1}) {
return
}
for range 2048 - 1024 - 1 {
if !yield(&input.Input{}) {
return
}
}
}),
expected: 1531,
},
{
name: "Same Batch Near Start",
numCtx: 2048,
numKeep: 5,
inputs: slices.Collect(func(yield func(*input.Input) bool) {
for range 10 {
if !yield(&input.Input{}) {
return
}
}
if !yield(&input.Input{SameBatch: 512 - 1}) {
return
}
for range 2048 - 10 - 1 {
if !yield(&input.Input{}) {
return
}
}
}),
expected: 1021,
},
{
name: "Consecutive Same Batch",
numCtx: 32,
inputs: slices.Collect(func(yield func(*input.Input) bool) {
for i := range 32 {
input := input.Input{}
if i%10 == 0 {
input.SameBatch = 10 - 1
}
if !yield(&input) {
return
}
}
}),
expected: 20,
},
{
name: "Overlapping Same Batch",
numCtx: 32,
inputs: slices.Collect(func(yield func(*input.Input) bool) {
for i := range 32 {
input := input.Input{}
if slices.Contains([]int{4, 8, 14}, i) {
input.SameBatch = 10 - 1
}
if !yield(&input) {
return
}
}
}),
expected: 24,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := InputCache{numCtx: tt.numCtx}
-result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
+result := c.ShiftDiscard(tt.inputs, tt.numKeep)
if result != tt.expected {
t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
t.Errorf("shiftDiscard(ctx: %v, keep: %v inputs: %v): have %v; want %v", tt.numCtx, tt.numKeep, len(tt.inputs), result, tt.expected)
}
})
}

View File

@@ -214,7 +214,6 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
parts = []string{prompt}
}
-postTokenize := false
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
@@ -257,11 +256,10 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
mmStore.addMultimodal(imageEmbeddings)
inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
-postTokenize = true
}
}
-if visionModel && postTokenize {
+if visionModel {
var err error
inputs, err = multimodalProcessor.PostTokenize(inputs)
if err != nil {

View File

@@ -142,7 +142,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
// This model is much more capable with a larger context, so set that
// unless it would penalize performance too much
if !s.lowVRAM && slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
if !s.lowVRAM && slices.Contains([]string{
"gptoss", "gpt-oss",
"qwen3vl", "qwen3vlmoe",
}, model.Config.ModelFamily) {
opts.NumCtx = max(opts.NumCtx, 8192)
}

View File

@@ -390,11 +390,11 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
numParallel = 1
}
-// `mllama` is a snowflake and uses an encoder cache which cannot be used with num_parallel > 1
+// `mllama`, `qwen3vl`, and `qwen3vlmoe` are snowflakes and use an encoder cache which cannot be used with num_parallel > 1
// ref: https://github.com/ollama/ollama/issues/4165
-if slices.Contains(req.model.Config.ModelFamilies, "mllama") && numParallel != 1 {
+if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe"}, req.model.Config.ModelFamily) && numParallel != 1 {
numParallel = 1
-slog.Warn("mllama does not currently support parallel requests")
+slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
}
sessionDuration := envconfig.KeepAlive()