mirror of
https://github.com/ollama/ollama.git
synced 2025-12-05 19:16:53 -06:00
feat(model): add qwen3vl (#12665)
This commit is contained in:
@@ -198,6 +198,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &qwen2Model{}
|
||||
case "Qwen2_5_VLForConditionalGeneration":
|
||||
conv = &qwen25VLModel{}
|
||||
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
|
||||
conv = &qwen3VLModel{}
|
||||
case "BertModel":
|
||||
conv = &bertModel{}
|
||||
case "CohereForCausalLM":
|
||||
|
||||
157
convert/convert_qwen3.go
Normal file
157
convert/convert_qwen3.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
)
|
||||
|
||||
type qwen3Model struct {
|
||||
ModelParameters
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
NumExperts uint32 `json:"num_experts"`
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
NormTopkProb bool `json:"norm_topk_prob"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeScaling struct {
|
||||
Type string `json:"type"`
|
||||
Factor ropeFactor `json:"factor"`
|
||||
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
MropeSection []int32 `json:"mrope_section"`
|
||||
} `json:"rope_scaling"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
|
||||
arch := "qwen3"
|
||||
if q.NumExperts > 0 {
|
||||
arch += "moe"
|
||||
}
|
||||
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = arch
|
||||
kv["block_count"] = q.HiddenLayers
|
||||
kv["context_length"] = q.MaxPositionEmbeddings
|
||||
kv["embedding_length"] = q.HiddenSize
|
||||
kv["feed_forward_length"] = q.IntermediateSize
|
||||
kv["attention.head_count"] = q.NumAttentionHeads
|
||||
kv["attention.head_count_kv"] = q.NumKeyValueHeads
|
||||
kv["attention.key_length"] = q.HeadDim
|
||||
kv["attention.value_length"] = q.HeadDim
|
||||
|
||||
if q.NumExperts > 0 {
|
||||
kv["expert_count"] = q.NumExperts
|
||||
kv["expert_used_count"] = q.NumExpertsPerToken
|
||||
kv["norm_top_k_prob"] = q.NormTopkProb
|
||||
}
|
||||
|
||||
kv["rope.freq_base"] = q.RopeTheta
|
||||
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
|
||||
|
||||
switch q.RopeScaling.Type {
|
||||
case "":
|
||||
// no scaling
|
||||
case "yarn":
|
||||
kv["rope.scaling.type"] = q.RopeScaling.Type
|
||||
kv["rope.scaling.factor"] = q.RopeScaling.Factor
|
||||
case "mrope", "default":
|
||||
kv["rope.mrope_section"] = q.RopeScaling.MropeSection
|
||||
default:
|
||||
panic("unknown rope scaling type")
|
||||
}
|
||||
return kv
|
||||
}
|
||||
|
||||
// Tensors implements ModelConverter.
|
||||
func (q *qwen3Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
// TODO: handle split experts
|
||||
|
||||
for _, t := range ts {
|
||||
switch {
|
||||
case strings.Contains(t.Name(), "ffn_gate_up_exps"):
|
||||
afterFunc := func(t tensor.Tensor) (tensor.Tensor, error) { return tensor.Transpose(t, 0, 2, 1) }
|
||||
for t := range splitDim(t, 2,
|
||||
split{Replacer: strings.NewReplacer("gate_up", "gate"), afterFunc: afterFunc},
|
||||
split{Replacer: strings.NewReplacer("gate_up", "up"), afterFunc: afterFunc},
|
||||
) {
|
||||
t.Shape[1], t.Shape[2] = t.Shape[2], t.Shape[1]
|
||||
out = append(out, t)
|
||||
}
|
||||
case strings.Contains(t.Name(), "ffn_down_exps"):
|
||||
shape := slices.Clone(t.Shape())
|
||||
shape[1], shape[2] = shape[2], shape[1]
|
||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
tt, err := tensor.Transpose(tt, 0, 2, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// flatten tensor so it can be written as a vector
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: shape,
|
||||
WriterTo: t,
|
||||
})
|
||||
default:
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// Replacements implements ModelConverter.
|
||||
func (q *qwen3Model) Replacements() []string {
|
||||
return []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.layers", "blk",
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.gate.weight", "ffn_gate_inp.weight",
|
||||
"mlp.experts.down_proj", "ffn_down_exps.weight",
|
||||
"mlp.experts.gate_up_proj", "ffn_gate_up_exps.weight",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"model.norm", "output_norm",
|
||||
}
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*qwen3Model)(nil)
|
||||
116
convert/convert_qwen3vl.go
Normal file
116
convert/convert_qwen3vl.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"io/fs"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type qwen3VLModel struct {
|
||||
qwen3Model `json:"text_config"`
|
||||
|
||||
VisionModel struct {
|
||||
Depth uint32 `json:"depth"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHeads uint32 `json:"num_heads"`
|
||||
InChannels uint32 `json:"in_channels"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||
WindowSize uint32 `json:"window_size"`
|
||||
RMSNormEps float32 `json:"layer_norm_epsilon"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||
DeepstackVisualIndexes []int32 `json:"deepstack_visual_indexes"`
|
||||
|
||||
Size struct {
|
||||
ShortestEdge uint32 `json:"shortest_edge"`
|
||||
LongestEdge uint32 `json:"longest_edge"`
|
||||
} `json:"size"`
|
||||
|
||||
ImageMean []float32 `json:"image_mean"`
|
||||
ImageStd []float32 `json:"image_std"`
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
|
||||
bts, err := fs.ReadFile(fsys, "preprocessor_config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.Unmarshal(bts, &m.VisionModel)
|
||||
}
|
||||
|
||||
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.qwen3Model.KV(t)
|
||||
|
||||
arch := "qwen3vl"
|
||||
if m.NumExperts > 0 {
|
||||
arch += "moe"
|
||||
}
|
||||
// override architecture
|
||||
kv["general.architecture"] = arch
|
||||
|
||||
kv["vision.block_count"] = cmp.Or(m.VisionModel.Depth, 32)
|
||||
kv["vision.embedding_length"] = m.VisionModel.HiddenSize
|
||||
kv["vision.attention.head_count"] = cmp.Or(m.VisionModel.NumHeads, 16)
|
||||
kv["vision.num_channels"] = m.VisionModel.InChannels
|
||||
kv["vision.patch_size"] = cmp.Or(m.VisionModel.PatchSize, 14)
|
||||
kv["vision.spatial_merge_size"] = cmp.Or(m.VisionModel.SpatialMergeSize, 2)
|
||||
kv["vision.attention.layer_norm_epsilon"] = cmp.Or(m.VisionModel.RMSNormEps, 1e-6)
|
||||
kv["vision.rope.freq_base"] = cmp.Or(m.VisionModel.RopeTheta, 1e4)
|
||||
kv["vision.temporal_patch_size"] = cmp.Or(m.VisionModel.TemporalPatchSize, 2)
|
||||
kv["vision.deepstack_visual_indexes"] = m.VisionModel.DeepstackVisualIndexes
|
||||
|
||||
kv["vision.shortest_edge"] = m.VisionModel.Size.ShortestEdge
|
||||
kv["vision.longest_edge"] = m.VisionModel.Size.LongestEdge
|
||||
|
||||
kv["vision.image_mean"] = m.VisionModel.ImageMean
|
||||
kv["vision.image_std"] = m.VisionModel.ImageStd
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *qwen3VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var rest []Tensor
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
switch {
|
||||
case strings.Contains(t.Name(), "attn_qkv"):
|
||||
out = append(out, slices.Collect(splitDim(t, 0,
|
||||
split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")},
|
||||
split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")},
|
||||
split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")},
|
||||
))...)
|
||||
case strings.Contains(t.Name(), "patch_embed") && strings.HasSuffix(t.Name(), "weight"):
|
||||
shape := t.Shape()
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: append([]uint64{shape[0] * shape[1]}, shape[2:]...),
|
||||
WriterTo: t,
|
||||
})
|
||||
default:
|
||||
rest = append(rest, t)
|
||||
}
|
||||
}
|
||||
|
||||
return append(m.qwen3Model.Tensors(rest), out...)
|
||||
}
|
||||
|
||||
func (m *qwen3VLModel) Replacements() []string {
|
||||
return append(
|
||||
m.qwen3Model.Replacements(),
|
||||
"model.language_", "",
|
||||
"model.visual", "v",
|
||||
"patch_embed.proj", "patch_embed",
|
||||
"blocks", "blk",
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.proj", "attn_out",
|
||||
"deepstack_merger_list", "deepstack_merger",
|
||||
)
|
||||
}
|
||||
@@ -19,8 +19,8 @@ type split struct {
|
||||
dim int
|
||||
slices []tensor.Slice
|
||||
|
||||
// fn is an optional function to apply to the tensor after slicing
|
||||
fn func(tensor.Tensor) (tensor.Tensor, error)
|
||||
// afterFunc is an optional function to apply to the tensor after slicing
|
||||
afterFunc func(tensor.Tensor) (tensor.Tensor, error)
|
||||
}
|
||||
|
||||
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
|
||||
@@ -54,8 +54,8 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
|
||||
|
||||
tt = tensor.Materialize(tt)
|
||||
|
||||
if split.fn != nil {
|
||||
tt, err = split.fn(tt)
|
||||
if split.afterFunc != nil {
|
||||
tt, err = split.afterFunc(tt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -432,7 +432,7 @@ func TestSplitDim(t *testing.T) {
|
||||
t.Run("split with transpose", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 1,
|
||||
split{Replacer: strings.NewReplacer("a", "x")},
|
||||
split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
|
||||
split{Replacer: strings.NewReplacer("b", "y"), afterFunc: func(tt tensor.Tensor) (tensor.Tensor, error) {
|
||||
return tensor.Transpose(tt, 1, 0)
|
||||
}},
|
||||
))
|
||||
|
||||
@@ -242,13 +242,13 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"mistral3",
|
||||
"qwen3",
|
||||
"qwen3moe",
|
||||
"gptoss", "gpt-oss",
|
||||
"llama4",
|
||||
"mistral3",
|
||||
"mllama",
|
||||
"qwen25vl",
|
||||
"gptoss", "gpt-oss",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,13 @@ func TestVisionModels(t *testing.T) {
|
||||
{
|
||||
model: "gemma3",
|
||||
},
|
||||
{
|
||||
model: "qwen3-vl:8b",
|
||||
},
|
||||
{
|
||||
// Qwen 3 VL mixture of experts
|
||||
model: "qwen3-vl:30b",
|
||||
},
|
||||
}
|
||||
|
||||
for _, v := range testCases {
|
||||
|
||||
@@ -161,6 +161,7 @@ type Tensor interface {
|
||||
|
||||
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
|
||||
|
||||
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
|
||||
@@ -1182,6 +1182,10 @@ func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
|
||||
}
|
||||
|
||||
func (t *Tensor) Contiguous(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
if slices.Contains(shape, -1) {
|
||||
inferShape(t, shape)
|
||||
}
|
||||
|
||||
switch len(shape) {
|
||||
case 0:
|
||||
return &Tensor{
|
||||
@@ -1324,7 +1328,43 @@ func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
// inferShape updates shape in place to automatically set a single -1 dimesion
|
||||
// based on the input tensor and the other dimensions
|
||||
func inferShape(t *Tensor, shape []int) {
|
||||
total := 1
|
||||
for _, dim := range t.Shape() {
|
||||
total *= dim
|
||||
}
|
||||
|
||||
dim := -1
|
||||
for i := range shape {
|
||||
switch shape[i] {
|
||||
case -1:
|
||||
if dim != -1 {
|
||||
panic("only one dimension can be inferred")
|
||||
}
|
||||
dim = i
|
||||
case 0:
|
||||
panic("dimension cannot be zero")
|
||||
default:
|
||||
if total%shape[i] != 0 {
|
||||
panic("cannot infer dimension")
|
||||
}
|
||||
|
||||
total /= shape[i]
|
||||
}
|
||||
}
|
||||
|
||||
if dim != -1 {
|
||||
shape[dim] = total
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
if slices.Contains(shape, -1) {
|
||||
inferShape(t, shape)
|
||||
}
|
||||
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
return &Tensor{
|
||||
@@ -1537,6 +1577,16 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) ml.Tensor {
|
||||
var tt ml.Tensor = &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_conv_3d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int64_t(c), C.int(s0), C.int(s1), C.int(s2), C.int(p0), C.int(p1), C.int(p2), C.int(d0), C.int(d1), C.int(d2)),
|
||||
}
|
||||
|
||||
tt = tt.Reshape(ctx, t.Dim(3)/c, t2.Dim(3)/c)
|
||||
return tt
|
||||
}
|
||||
|
||||
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
|
||||
126
ml/backend/ggml/ggml_test.go
Normal file
126
ml/backend/ggml/ggml_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package ggml
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
func setup(tb testing.TB) ml.Context {
|
||||
tb.Helper()
|
||||
|
||||
f, err := os.CreateTemp(tb.TempDir(), "*.bin")
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
b, err := ml.NewBackend(f.Name(), ml.BackendParams{})
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := b.NewContext().Input()
|
||||
|
||||
tb.Cleanup(func() {
|
||||
ctx.Close()
|
||||
b.Close()
|
||||
})
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func TestInferShape(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
input []int
|
||||
want []int
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "no inferred shape",
|
||||
input: []int{2, 3, 4},
|
||||
want: []int{2, 3, 4},
|
||||
},
|
||||
{
|
||||
name: "infer begin",
|
||||
input: []int{-1, 3, 4},
|
||||
want: []int{2, 3, 4},
|
||||
},
|
||||
{
|
||||
name: "infer mid",
|
||||
input: []int{2, -1, 4},
|
||||
want: []int{2, 3, 4},
|
||||
},
|
||||
{
|
||||
name: "infer end",
|
||||
input: []int{2, 3, -1},
|
||||
want: []int{2, 3, 4},
|
||||
},
|
||||
{
|
||||
name: "too many inferred dims",
|
||||
input: []int{-1, 3, -1},
|
||||
err: errors.New("only one dimension can be inferred"),
|
||||
},
|
||||
{
|
||||
name: "infer gather",
|
||||
input: []int{2, -1},
|
||||
want: []int{2, 12},
|
||||
},
|
||||
{
|
||||
name: "infer gather all",
|
||||
input: []int{-1},
|
||||
want: []int{24},
|
||||
},
|
||||
{
|
||||
name: "infer split",
|
||||
input: []int{2, -1, 3, 2},
|
||||
want: []int{2, 2, 3, 2},
|
||||
},
|
||||
{
|
||||
name: "indivisible infer",
|
||||
input: []int{2, -1, 2, 4},
|
||||
err: errors.New("cannot infer dimension"),
|
||||
},
|
||||
{
|
||||
name: "infer zero dim",
|
||||
input: []int{2, 0, 4},
|
||||
err: errors.New("dimension cannot be zero"),
|
||||
},
|
||||
}
|
||||
|
||||
ctx := setup(t)
|
||||
tensor, ok := ctx.Empty(ml.DTypeF32, 2, 3, 4).(*Tensor)
|
||||
if !ok {
|
||||
t.Fatal("expected *Tensor")
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil && tt.err == nil {
|
||||
// all good
|
||||
} else if r != nil && tt.err == nil {
|
||||
t.Errorf("unexpected panic: %v", r)
|
||||
} else if r == nil && tt.err != nil {
|
||||
t.Errorf("expected panic but did not get one: %v", tt.err)
|
||||
} else if errStr, ok := r.(string); ok && errStr != tt.err.Error() {
|
||||
t.Errorf("expected panic %q but got %q", tt.err.Error(), errStr)
|
||||
}
|
||||
}()
|
||||
|
||||
inferShape(tensor, tt.input)
|
||||
if diff := cmp.Diff(tt.want, tt.input); diff != "" {
|
||||
t.Errorf("%s: shape mismatch (-want +got):\n%s", tt.name, diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,8 +4,26 @@ import "github.com/ollama/ollama/ml"
|
||||
|
||||
type Conv2D struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
return m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1)
|
||||
t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1)
|
||||
if m.Bias != nil {
|
||||
t = t.Add(ctx, m.Bias)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
type Conv3D struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) ml.Tensor {
|
||||
t = m.Weight.Conv3D(ctx, t, c, s0, s1, s2, p0, p1, p2, d0, d1, d2)
|
||||
if m.Bias != nil {
|
||||
t = t.Add(ctx, m.Bias)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -14,4 +14,5 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/qwen2"
|
||||
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3vl"
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ package qwen3
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
@@ -210,7 +211,7 @@ var _ model.Model = (*Model)(nil)
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
layers := make([]Layer, c.Uint("block_count"))
|
||||
for i := range layers {
|
||||
if c.String("general.architecture") == "qwen3moe" {
|
||||
if strings.HasSuffix(c.String("general.architecture"), "moe") {
|
||||
layers[i].MLP = &sparse{}
|
||||
} else {
|
||||
layers[i].MLP = &dense{}
|
||||
|
||||
194
model/models/qwen3vl/imageprocessor.go
Normal file
194
model/models/qwen3vl/imageprocessor.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package qwen3vl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
// ImageProcessor contains configuration for the Qwen 3 VL image processing
|
||||
type ImageProcessor struct {
|
||||
numChannels int
|
||||
patchSize int
|
||||
temporalPatchSize int
|
||||
mergeSize int
|
||||
shortestEdge int
|
||||
longestEdge int
|
||||
factor int
|
||||
rescaleFactor float32
|
||||
imageMean []float32
|
||||
imageStd []float32
|
||||
}
|
||||
|
||||
// newImageProcessor creates a new image processor with default values
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))
|
||||
|
||||
return ImageProcessor{
|
||||
numChannels: int(c.Uint("vision.num_channels", 3)), // not set
|
||||
patchSize: patchSize,
|
||||
temporalPatchSize: 2,
|
||||
mergeSize: mergeSize,
|
||||
shortestEdge: int(c.Uint("vision.shortest_edge", 64<<10)),
|
||||
// FIXME(mxyng): the model defined longest edge (16M) is too large for the default
|
||||
// context length of 8K and will panic. Adjusting to 2M for now.
|
||||
// longestEdge: int(c.Uint("vision.longest_edge", 16<<20)),
|
||||
longestEdge: 2 << 20,
|
||||
factor: patchSize * mergeSize,
|
||||
rescaleFactor: 1.0 / 255.0,
|
||||
imageMean: c.Floats("vision.image_mean", imageproc.ImageNetStandardMean[:]),
|
||||
imageStd: c.Floats("vision.image_std", imageproc.ImageNetStandardSTD[:]),
|
||||
}
|
||||
}
|
||||
|
||||
// SmartResize implements the smart resize algorithm
|
||||
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
|
||||
factor := p.factor
|
||||
|
||||
if height < factor || width < factor {
|
||||
panic(fmt.Sprintf("height:%d or width:%d must be larger than factor:%d", height, width, factor))
|
||||
} else if aspectRatio := max(height, width) / min(height, width); aspectRatio > 200 {
|
||||
panic(fmt.Sprintf("absolute aspect ratio must be smaller than 200, got %v", aspectRatio))
|
||||
}
|
||||
|
||||
round := func(x float64) int { return int(math.RoundToEven(x)) }
|
||||
|
||||
hBar := round(float64(height)/float64(factor)) * factor
|
||||
wBar := round(float64(width)/float64(factor)) * factor
|
||||
|
||||
if hBar*wBar > p.longestEdge {
|
||||
beta := math.Sqrt(float64(height*width) / float64(p.longestEdge))
|
||||
|
||||
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
|
||||
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
|
||||
} else if hBar*wBar < p.shortestEdge {
|
||||
beta := math.Sqrt(float64(p.shortestEdge) / float64(height*width))
|
||||
|
||||
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||
}
|
||||
|
||||
return hBar, wBar
|
||||
}
|
||||
|
||||
type Grid struct {
|
||||
Height int
|
||||
Width int
|
||||
Temporal int
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) ProcessImage(ctx ml.Context, img image.Image) (ml.Tensor, *Grid, error) {
|
||||
origWidth := img.Bounds().Dx()
|
||||
origHeight := img.Bounds().Dy()
|
||||
|
||||
// Calculate smart resize dimensions
|
||||
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
|
||||
|
||||
// Resize image using existing functions
|
||||
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
|
||||
|
||||
normalizedPixels := imageproc.Normalize(
|
||||
resizedImg,
|
||||
[3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]},
|
||||
[3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]},
|
||||
true, // rescale
|
||||
true, // channelFirst
|
||||
)
|
||||
|
||||
// Calculate grid dimensions
|
||||
grid := &Grid{
|
||||
Height: resizedHeight / p.patchSize,
|
||||
Width: resizedWidth / p.patchSize,
|
||||
Temporal: 1, // For single images, temporal dimension is 1
|
||||
}
|
||||
|
||||
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create patches: %v", err)
|
||||
}
|
||||
|
||||
patchDim := p.numChannels * p.temporalPatchSize *
|
||||
p.patchSize * p.patchSize
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
|
||||
pixelValues := ctx.Input().FromFloats(patches, patchDim, numPatches)
|
||||
|
||||
// Return patches and grid dimensions
|
||||
return pixelValues, grid, nil
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
|
||||
channels := p.numChannels
|
||||
patchSize := p.patchSize
|
||||
mergeSize := p.mergeSize
|
||||
temporalPatchSize := p.temporalPatchSize
|
||||
|
||||
// Calculate output dimensions
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
||||
|
||||
result := make([]float32, numPatches*patchDim)
|
||||
patchIndex := 0
|
||||
|
||||
// Single temporal frame handling (copies to all frames)
|
||||
for range grid.Temporal {
|
||||
for h := 0; h < grid.Height; h += mergeSize {
|
||||
for w := 0; w < grid.Width; w += mergeSize {
|
||||
// Handle the 2x2 merged patches
|
||||
for mh := range mergeSize {
|
||||
for mw := range mergeSize {
|
||||
baseOffset := patchIndex * patchDim
|
||||
|
||||
// Extract patch data for first temporal frame
|
||||
for c := range channels {
|
||||
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
|
||||
|
||||
for py := range patchSize {
|
||||
for px := range patchSize {
|
||||
// Calculate source pixel coordinates
|
||||
y := (h+mh)*patchSize + py
|
||||
x := (w+mw)*patchSize + px
|
||||
|
||||
// Source index in input tensor (CHW format)
|
||||
srcIdx := c*height*width + y*width + x
|
||||
|
||||
// Destination index in first temporal frame
|
||||
dstIdx := channelOffset + (py * patchSize) + px
|
||||
|
||||
if srcIdx < len(pixels) && dstIdx < len(result) {
|
||||
result[dstIdx] = pixels[srcIdx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy first temporal frame to all other frames
|
||||
if temporalPatchSize > 1 {
|
||||
for c := range channels {
|
||||
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
|
||||
firstFrameOffset := channelOffset
|
||||
frameSize := patchSize * patchSize
|
||||
|
||||
// Copy first frame to all other frames
|
||||
for tp := 1; tp < temporalPatchSize; tp++ {
|
||||
currentFrameOffset := channelOffset + (tp * frameSize)
|
||||
copy(result[currentFrameOffset:currentFrameOffset+frameSize],
|
||||
result[firstFrameOffset:firstFrameOffset+frameSize])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
patchIndex++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
204
model/models/qwen3vl/model.go
Normal file
204
model/models/qwen3vl/model.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package qwen3vl
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
|
||||
ImageProcessor
|
||||
|
||||
positionCache []int32
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues, grid, err := m.ProcessImage(ctx, img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Calculate tensor dimensions
|
||||
visionOutputs, deepstackVisualEmbeds := m.VisionModel.Forward(ctx, pixelValues, grid)
|
||||
mm := []input.Multimodal{{Tensor: visionOutputs, Data: grid}}
|
||||
for i := range deepstackVisualEmbeds {
|
||||
mm = append(mm, input.Multimodal{Tensor: deepstackVisualEmbeds[i]})
|
||||
}
|
||||
|
||||
return mm, nil
|
||||
}
|
||||
|
||||
var (
|
||||
tokenVision int32 = 151655
|
||||
tokenVisionStart int32 = 151652
|
||||
tokenVisionEnd int32 = 151653
|
||||
)
|
||||
|
||||
type modelInput struct {
|
||||
*input.Input
|
||||
position int32
|
||||
}
|
||||
|
||||
// PostTokenize arranges Qwen 3 VL's inputs for the forward pass
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
m.positionCache = m.positionCache[:0]
|
||||
return slices.Collect(func(yield func(*input.Input) bool) {
|
||||
for i := range inputs {
|
||||
s := []modelInput{{Input: inputs[i]}}
|
||||
if mm := inputs[i].Multimodal; mm != nil {
|
||||
t := mm[0].Tensor
|
||||
s = slices.Repeat([]modelInput{
|
||||
{
|
||||
position: int32(i + 1),
|
||||
Input: &input.Input{Token: tokenVision},
|
||||
},
|
||||
}, t.Dim(1)+1+1)
|
||||
|
||||
s[0] = modelInput{
|
||||
Input: &input.Input{Token: tokenVisionStart},
|
||||
position: int32(i),
|
||||
}
|
||||
|
||||
s[len(s)-1] = modelInput{
|
||||
Input: &input.Input{Token: tokenVisionEnd},
|
||||
position: int32(i + mm[0].Data.(*Grid).Width/m.spatialMergeSize + 1),
|
||||
}
|
||||
|
||||
s[1] = modelInput{
|
||||
Input: &input.Input{
|
||||
Token: tokenVision,
|
||||
Multimodal: inputs[i].Multimodal,
|
||||
MultimodalHash: inputs[i].MultimodalHash,
|
||||
SameBatch: t.Dim(1),
|
||||
},
|
||||
position: int32(i + 1),
|
||||
}
|
||||
}
|
||||
|
||||
for _, e := range s {
|
||||
position := e.position
|
||||
if position == 0 && len(m.positionCache) > 0 {
|
||||
position = m.positionCache[len(m.positionCache)-1] + 1
|
||||
}
|
||||
|
||||
m.positionCache = append(m.positionCache, position)
|
||||
if !yield(e.Input) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positionSlice := slices.Collect(makeSlice2D[int32](3, len(batch.Positions)))
|
||||
for i, id := range batch.Positions {
|
||||
if id < int32(len(m.positionCache)) {
|
||||
id = m.positionCache[id]
|
||||
} else if len(m.positionCache) > 0 {
|
||||
id = id - int32(len(m.positionCache)) + m.positionCache[len(m.positionCache)-1] + 1
|
||||
}
|
||||
|
||||
positionSlice[0][i] = id
|
||||
positionSlice[1][i] = id
|
||||
positionSlice[2][i] = id
|
||||
}
|
||||
|
||||
hiddenStates := m.TextModel.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
|
||||
|
||||
var deepstackVisualEmbeds []ml.Tensor
|
||||
for _, mi := range batch.Multimodal {
|
||||
visionOutputs := mi.Multimodal[0].Tensor
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
|
||||
if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
|
||||
for i := range visionOutputs.Dim(1) {
|
||||
w := grid.Width / m.spatialMergeSize
|
||||
positionSlice[1][mi.Index+i] += int32(i / w)
|
||||
positionSlice[2][mi.Index+i] += int32(i % w)
|
||||
}
|
||||
}
|
||||
|
||||
deepstackVisualEmbeds = make([]ml.Tensor, len(mi.Multimodal[1:]))
|
||||
for i, mm := range mi.Multimodal[1:] {
|
||||
deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
|
||||
ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1))))
|
||||
}
|
||||
}
|
||||
|
||||
positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0]), len(positionSlice))
|
||||
cos, sin := m.rotaryEmbedding(ctx, positions)
|
||||
for i, layer := range m.TextModel.Layers {
|
||||
if m.Cache != nil {
|
||||
m.Cache.SetLayer(i)
|
||||
}
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.TextModel.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, outputs, m.Cache, m.Options)
|
||||
if i < len(deepstackVisualEmbeds) {
|
||||
hiddenStates = hiddenStates.Add(ctx, deepstackVisualEmbeds[i])
|
||||
}
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, 1e-06)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
TextModel: newTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, position ml.Tensor) (ml.Tensor, error) {
|
||||
m.positionCache = nil
|
||||
return nil, kvcache.ErrNotSupported
|
||||
})
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("qwen3vl", New)
|
||||
model.Register("qwen3vlmoe", New)
|
||||
}
|
||||
229
model/models/qwen3vl/model_text.go
Normal file
229
model/models/qwen3vl/model_text.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package qwen3vl
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type TextOptions struct {
|
||||
hiddenSize,
|
||||
numHeads,
|
||||
numKVHeads,
|
||||
keyLength,
|
||||
valueLength int
|
||||
|
||||
eps,
|
||||
ropeBase,
|
||||
ropeScale float32
|
||||
mropeSections []int
|
||||
|
||||
numExperts, numExpertsUsed int
|
||||
normTopKProb bool
|
||||
|
||||
inverseFrequenciesCache []float32
|
||||
}
|
||||
|
||||
func (o TextOptions) headDim() int {
|
||||
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
|
||||
}
|
||||
|
||||
type TextAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenStates)
|
||||
key := sa.Key.Forward(ctx, hiddenStates)
|
||||
value := sa.Value.Forward(ctx, hiddenStates)
|
||||
|
||||
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
|
||||
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
|
||||
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
|
||||
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
|
||||
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type TextMLP interface {
|
||||
Forward(ml.Context, ml.Tensor, *TextOptions) ml.Tensor
|
||||
}
|
||||
|
||||
type sparse struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
}
|
||||
|
||||
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
|
||||
routerLogits := mlp.Router.Forward(ctx, hiddenStates)
|
||||
|
||||
routingWeights := routerLogits.Softmax(ctx)
|
||||
selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, selectedExperts)
|
||||
if opts.normTopKProb {
|
||||
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
|
||||
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
|
||||
}
|
||||
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
|
||||
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts))
|
||||
|
||||
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
|
||||
return nextStates
|
||||
}
|
||||
|
||||
type dense struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *TextOptions) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
type TextLayer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
*TextAttention
|
||||
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
TextMLP
|
||||
}
|
||||
|
||||
func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = d.TextAttention.Forward(ctx, hiddenStates, cos, sin, cache, opts)
|
||||
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
residual = hiddenStates
|
||||
hiddenStates = d.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = d.TextMLP.Forward(ctx, hiddenStates, opts)
|
||||
return hiddenStates.Add(ctx, residual)
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
Layers []TextLayer `gguf:"blk"`
|
||||
|
||||
Options *TextOptions
|
||||
}
|
||||
|
||||
func (m *TextModel) rotaryEmbedding(ctx ml.Context, positions ml.Tensor) (_, _ ml.Tensor) {
|
||||
positions = positions.Reshape(ctx, 1, positions.Dim(0), positions.Dim(1))
|
||||
if len(m.Options.inverseFrequenciesCache) == 0 {
|
||||
m.Options.inverseFrequenciesCache = make([]float32, m.Options.headDim()/2)
|
||||
for i := range m.Options.inverseFrequenciesCache {
|
||||
frequency := float32(math.Pow(float64(m.Options.ropeBase), float64(i*2)/float64(m.Options.headDim())))
|
||||
m.Options.inverseFrequenciesCache[i] = 1 / frequency
|
||||
}
|
||||
}
|
||||
|
||||
inverseFrequencies := ctx.Input().FromFloats(m.Options.inverseFrequenciesCache, 1, len(m.Options.inverseFrequenciesCache))
|
||||
|
||||
positions = positions.Cast(ctx, ml.DTypeF32)
|
||||
frequencies := inverseFrequencies.Mulmat(ctx, positions)
|
||||
|
||||
interleaved := frequencies.View(ctx,
|
||||
0, frequencies.Dim(0),
|
||||
frequencies.Stride(1), frequencies.Dim(1),
|
||||
)
|
||||
|
||||
for _, i := range []int{1, 2} {
|
||||
args := []int{
|
||||
i * frequencies.Stride(0), 1,
|
||||
3 * frequencies.Stride(0), m.Options.mropeSections[i],
|
||||
frequencies.Stride(1), frequencies.Dim(1),
|
||||
}
|
||||
|
||||
ctx.Forward(frequencies.View(ctx, i*frequencies.Stride(2)+args[0], args[1:]...).
|
||||
Copy(ctx, interleaved.View(ctx, args[0], args[1:]...)))
|
||||
}
|
||||
|
||||
interleaved = interleaved.Concat(ctx, interleaved, 0)
|
||||
interleaved = interleaved.Reshape(ctx, interleaved.Dim(0), 1, interleaved.Dim(1), interleaved.Dim(2))
|
||||
return interleaved.Cos(ctx), interleaved.Sin(ctx)
|
||||
}
|
||||
|
||||
var _ model.Model = (*Model)(nil)
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
layers := make([]TextLayer, c.Uint("block_count"))
|
||||
for i := range layers {
|
||||
if strings.HasSuffix(c.String("general.architecture"), "moe") {
|
||||
layers[i].TextMLP = &sparse{}
|
||||
} else {
|
||||
layers[i].TextMLP = &dense{}
|
||||
}
|
||||
}
|
||||
|
||||
m := TextModel{
|
||||
Layers: layers,
|
||||
Options: &TextOptions{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
keyLength: int(c.Uint("attention.key_length")),
|
||||
valueLength: int(c.Uint("attention.value_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
numExperts: int(c.Uint("expert_count")),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
normTopKProb: c.Bool("norm_top_k_prob", true),
|
||||
mropeSections: slices.Collect(func(yield func(int) bool) {
|
||||
for _, section := range c.Ints("mrope_sections", []int32{24, 20, 20}) {
|
||||
if !yield(int(section)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
return &m
|
||||
}
|
||||
268
model/models/qwen3vl/model_vision.go
Normal file
268
model/models/qwen3vl/model_vision.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package qwen3vl
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
type VisionAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_out"`
|
||||
}
|
||||
|
||||
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
||||
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
|
||||
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
|
||||
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
|
||||
}
|
||||
|
||||
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
|
||||
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
|
||||
}
|
||||
|
||||
func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor {
|
||||
query := sa.Query.Forward(ctx, hiddenStates)
|
||||
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, query.Dim(1))
|
||||
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenStates)
|
||||
key = key.Reshape(ctx, opts.headDim(), opts.numHeads, key.Dim(1))
|
||||
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, math.Pow(float64(opts.headDim()), -0.5), nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
FC1 *nn.Linear `gguf:"linear_fc1"`
|
||||
FC2 *nn.Linear `gguf:"linear_fc2"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts VisionOptions) ml.Tensor {
|
||||
return mlp.FC2.Forward(ctx, mlp.FC1.Forward(ctx, hiddenStates).GELU(ctx))
|
||||
}
|
||||
|
||||
type VisionEncoderLayer struct {
|
||||
Norm1 *nn.LayerNorm `gguf:"norm1"`
|
||||
Attention *VisionAttention
|
||||
Norm2 *nn.LayerNorm `gguf:"norm2"`
|
||||
MLP *VisionMLP `gguf:"mlp"`
|
||||
}
|
||||
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = e.Attention.Forward(ctx, hiddenStates, cos, sin, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
residual = hiddenStates
|
||||
hiddenStates = e.Norm2.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
|
||||
return hiddenStates.Add(ctx, residual)
|
||||
}
|
||||
|
||||
type VisionOptions struct {
|
||||
hiddenSize,
|
||||
numHeads,
|
||||
patchSize,
|
||||
numChannels,
|
||||
spatialMergeSize,
|
||||
temporalPatchSize,
|
||||
gridPerSide int
|
||||
|
||||
eps,
|
||||
ropeTheta float32
|
||||
|
||||
deepstackVisualIndexes []int32
|
||||
mropeSections []int
|
||||
}
|
||||
|
||||
func (o VisionOptions) headDim() int {
|
||||
return o.hiddenSize / o.numHeads
|
||||
}
|
||||
|
||||
type VisionPatchMerger struct {
|
||||
Norm *nn.LayerNorm `gguf:"norm"`
|
||||
FC1 *nn.Linear `gguf:"linear_fc1"`
|
||||
FC2 *nn.Linear `gguf:"linear_fc2"`
|
||||
}
|
||||
|
||||
func (m *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, postshuffleNorm bool, opts VisionOptions) ml.Tensor {
|
||||
hiddenSize := opts.hiddenSize * opts.spatialMergeSize * opts.spatialMergeSize
|
||||
if postshuffleNorm {
|
||||
visionOutputs = visionOutputs.Reshape(ctx, hiddenSize, -1)
|
||||
}
|
||||
|
||||
visionOutputs = m.Norm.Forward(ctx, visionOutputs, opts.eps)
|
||||
visionOutputs = visionOutputs.Reshape(ctx, hiddenSize, -1)
|
||||
return m.FC2.Forward(ctx, m.FC1.Forward(ctx, visionOutputs).GELU(ctx))
|
||||
}
|
||||
|
||||
type VisionPositionEmbedding struct {
|
||||
PositionEmbedding *nn.Embedding `gguf:"pos_embed"`
|
||||
}
|
||||
|
||||
func makeSlice2D[T int32 | float32](n0, n1 int) iter.Seq[[]T] {
|
||||
return func(yield func([]T) bool) {
|
||||
for range n0 {
|
||||
if !yield(make([]T, n1)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *VisionPositionEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor, grid *Grid, opts VisionOptions) ml.Tensor {
|
||||
indexSlice := slices.Collect(makeSlice2D[int32](4, grid.Height*grid.Width))
|
||||
weightSlice := slices.Collect(makeSlice2D[float32](4, grid.Height*grid.Width))
|
||||
|
||||
stepHeight := float32(opts.gridPerSide-1) / float32(grid.Height-1)
|
||||
stepWidth := float32(opts.gridPerSide-1) / float32(grid.Width-1)
|
||||
|
||||
var i int
|
||||
for h := range grid.Height {
|
||||
for w := range grid.Width {
|
||||
y, x := float32(h)*stepHeight, float32(w)*stepWidth
|
||||
|
||||
floorY, floorX := int32(y), int32(x)
|
||||
ceilY, ceilX := min(floorY+1, int32(opts.gridPerSide-1)), min(floorX+1, int32(opts.gridPerSide-1))
|
||||
|
||||
indexSlice[0][i] = floorY*int32(opts.gridPerSide) + floorX
|
||||
indexSlice[1][i] = floorY*int32(opts.gridPerSide) + ceilX
|
||||
indexSlice[2][i] = ceilY*int32(opts.gridPerSide) + floorX
|
||||
indexSlice[3][i] = ceilY*int32(opts.gridPerSide) + ceilX
|
||||
|
||||
weightSlice[0][i] = (1 - (y - float32(floorY))) * (1 - (x - float32(floorX)))
|
||||
weightSlice[1][i] = (1 - (y - float32(floorY))) * (x - float32(floorX))
|
||||
weightSlice[2][i] = (y - float32(floorY)) * (1 - (x - float32(floorX)))
|
||||
weightSlice[3][i] = (y - float32(floorY)) * (x - float32(floorX))
|
||||
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
indices := ctx.Input().FromInts(slices.Concat(indexSlice...), grid.Height*grid.Width*4)
|
||||
weights := ctx.Input().FromFloats(slices.Concat(weightSlice...), 1, grid.Height*grid.Width*4)
|
||||
|
||||
n := hiddenStates.Dim(0)
|
||||
positionEmbeds := m.PositionEmbedding.Forward(ctx, indices)
|
||||
positionEmbeds = positionEmbeds.Mul(ctx, weights)
|
||||
positionEmbeds = positionEmbeds.Reshape(ctx, n, -1, 4)
|
||||
|
||||
positionEmbeds = positionEmbeds.View(ctx, 0, n, positionEmbeds.Stride(1), grid.Height*grid.Width).
|
||||
Add(ctx, positionEmbeds.View(ctx, 1*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
|
||||
Add(ctx, positionEmbeds.View(ctx, 2*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
|
||||
Add(ctx, positionEmbeds.View(ctx, 3*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width))
|
||||
|
||||
positionEmbeds = positionEmbeds.Reshape(ctx, -1, grid.Width/opts.spatialMergeSize, opts.spatialMergeSize, grid.Height/opts.spatialMergeSize)
|
||||
positionEmbeds = positionEmbeds.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, n, -1)
|
||||
return hiddenStates.Add(ctx, positionEmbeds)
|
||||
}
|
||||
|
||||
type VisionModel struct {
|
||||
PatchEmbedding *nn.Conv3D `gguf:"patch_embed"`
|
||||
PositionEmbedding *VisionPositionEmbedding
|
||||
Layers []VisionEncoderLayer `gguf:"blk"`
|
||||
PatchMerger *VisionPatchMerger `gguf:"merger"`
|
||||
DeepstackMerger []*VisionPatchMerger `gguf:"deepstack_merger"`
|
||||
|
||||
VisionOptions
|
||||
}
|
||||
|
||||
func (m *VisionModel) positions(ctx ml.Context, grid *Grid) (_, _ ml.Tensor) {
|
||||
indices := ctx.Input().FromInts(slices.Collect(func(yield func(int32) bool) {
|
||||
for y := range grid.Height {
|
||||
for x := range grid.Width {
|
||||
if !yield(int32(y)) {
|
||||
return
|
||||
}
|
||||
if !yield(int32(x)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}), grid.Width*grid.Height*2)
|
||||
|
||||
indices = indices.Reshape(ctx, -1, grid.Width/m.spatialMergeSize, m.spatialMergeSize, grid.Height/m.spatialMergeSize)
|
||||
indices = indices.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
indices = indices.Reshape(ctx, -1)
|
||||
|
||||
halfDim := m.headDim() / 2
|
||||
maxGrid := max(grid.Height, grid.Width)
|
||||
frequencies := ctx.Input().FromFloats(slices.Collect(func(yield func(float32) bool) {
|
||||
ropeTheta := float64(m.ropeTheta)
|
||||
for i := range maxGrid {
|
||||
for j := range halfDim / 2 {
|
||||
if !yield(float32(i) / float32(math.Pow(ropeTheta, float64(j*2)/float64(halfDim)))) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}), halfDim/2, maxGrid)
|
||||
|
||||
embeds := frequencies.Rows(ctx, indices)
|
||||
embeds = embeds.Reshape(ctx, halfDim, 1, -1)
|
||||
embeds = embeds.Concat(ctx, embeds, 0)
|
||||
return embeds.Cos(ctx), embeds.Sin(ctx)
|
||||
}
|
||||
|
||||
// Forward computes the vision model for an input tensor
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) (ml.Tensor, []ml.Tensor) {
|
||||
pixelValues = pixelValues.Reshape(ctx, m.patchSize, m.patchSize, m.temporalPatchSize, -1)
|
||||
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.numChannels, m.patchSize, m.patchSize, m.temporalPatchSize, 0, 0, 0, 1, 1, 1)
|
||||
hiddenStates = m.PositionEmbedding.Forward(ctx, hiddenStates, grid, m.VisionOptions)
|
||||
|
||||
cos, sin := m.positions(ctx, grid)
|
||||
|
||||
deepstackStates := make([]ml.Tensor, len(m.deepstackVisualIndexes))
|
||||
for i, layer := range m.Layers {
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionOptions)
|
||||
if i := slices.Index(m.deepstackVisualIndexes, int32(i)); i >= 0 {
|
||||
deepstackStates[i] = m.DeepstackMerger[i].Forward(ctx, hiddenStates, true, m.VisionOptions)
|
||||
}
|
||||
}
|
||||
|
||||
hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, false, m.VisionOptions)
|
||||
return hiddenStates, deepstackStates
|
||||
}
|
||||
|
||||
// newVisionModel creates a new instance of the Qwen vision model
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
deepstackVisualIndexes := c.Ints("vision.deepstack_visual_indexes")
|
||||
model := &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
|
||||
DeepstackMerger: make([]*VisionPatchMerger, len(deepstackVisualIndexes)),
|
||||
VisionOptions: VisionOptions{
|
||||
hiddenSize: int(c.Uint("vision.embedding_length", 1280)),
|
||||
numHeads: int(c.Uint("vision.attention.head_count", 16)),
|
||||
patchSize: int(c.Uint("vision.patch_size", 14)),
|
||||
numChannels: int(c.Uint("vision.num_channels", 3)),
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
|
||||
ropeTheta: c.Float("vision.rope.freq_base", 10000.0),
|
||||
spatialMergeSize: int(c.Uint("vision.spatial_merge_size", 2)),
|
||||
temporalPatchSize: int(c.Uint("vision.temporal_patch_size", 2)),
|
||||
gridPerSide: int(math.Sqrt(float64(c.Uint("vision.num_positional_embeddings", 2304)))),
|
||||
mropeSections: slices.Collect(func(yield func(int) bool) {
|
||||
for _, section := range c.Ints("mrope_sections", []int32{24, 20, 20}) {
|
||||
if !yield(int(section)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
deepstackVisualIndexes: deepstackVisualIndexes,
|
||||
},
|
||||
}
|
||||
|
||||
return model
|
||||
}
|
||||
@@ -235,15 +235,28 @@ func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
|
||||
return count
|
||||
}
|
||||
|
||||
// TODO(jessegross): If we need to reprocess the inputs we should ensure that
|
||||
// we don't split up a SameBatch
|
||||
func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
||||
targetFree := (c.numCtx - numKeep) / 2
|
||||
targetFree = max(targetFree, 1)
|
||||
// ShiftDiscard computes how many inputs can be discarded from the cache. Inputs in the same batch
|
||||
// are discarded together.
|
||||
func (c *InputCache) ShiftDiscard(inputs []*input.Input, numKeep int32) int32 {
|
||||
targetFree := max((c.numCtx-numKeep)/2, 1)
|
||||
currentFree := c.numCtx - int32(len(inputs))
|
||||
|
||||
currentFree := c.numCtx - inputLen
|
||||
var discard, sameBatch int32
|
||||
for _, input := range inputs[numKeep:] {
|
||||
if sameBatch <= 0 && currentFree >= targetFree {
|
||||
break
|
||||
}
|
||||
|
||||
return max(targetFree-currentFree, 0)
|
||||
sameBatch--
|
||||
currentFree++
|
||||
discard++
|
||||
|
||||
if input.SameBatch > 0 {
|
||||
sameBatch = int32(input.SameBatch)
|
||||
}
|
||||
}
|
||||
|
||||
return discard
|
||||
}
|
||||
|
||||
type ErrReprocessInputs struct {
|
||||
@@ -264,7 +277,7 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
|
||||
}
|
||||
|
||||
inputLen := int32(len(slot.Inputs))
|
||||
discard := c.ShiftDiscard(inputLen, numKeep)
|
||||
discard := c.ShiftDiscard(slot.Inputs, numKeep)
|
||||
|
||||
if discard <= 0 {
|
||||
return nil
|
||||
|
||||
@@ -3,6 +3,7 @@ package ollamarunner
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -238,59 +239,137 @@ func TestShiftDiscard(t *testing.T) {
|
||||
name string
|
||||
numCtx int32
|
||||
numKeep int32
|
||||
inputLen int32
|
||||
inputs []*input.Input
|
||||
expected int32
|
||||
}{
|
||||
{
|
||||
name: "Shift",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputLen: 2048,
|
||||
inputs: slices.Repeat([]*input.Input{{}}, 2048),
|
||||
expected: 1021,
|
||||
},
|
||||
{
|
||||
name: "Max Keep",
|
||||
numCtx: 2048,
|
||||
numKeep: 2047,
|
||||
inputLen: 2048,
|
||||
inputs: slices.Repeat([]*input.Input{{}}, 2048),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "No Keep",
|
||||
numCtx: 2048,
|
||||
numKeep: 0,
|
||||
inputLen: 2048,
|
||||
inputs: slices.Repeat([]*input.Input{{}}, 2048),
|
||||
expected: 1024,
|
||||
},
|
||||
{
|
||||
name: "Truncate",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputLen: 5000,
|
||||
inputs: slices.Repeat([]*input.Input{{}}, 5000),
|
||||
expected: 3973,
|
||||
},
|
||||
{
|
||||
name: "Truncate Keep",
|
||||
numCtx: 2048,
|
||||
numKeep: 2047,
|
||||
inputLen: 5000,
|
||||
inputs: slices.Repeat([]*input.Input{{}}, 5000),
|
||||
expected: 2953,
|
||||
},
|
||||
{
|
||||
name: "No Op",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputLen: 512,
|
||||
inputs: slices.Repeat([]*input.Input{{}}, 512),
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "Same Batch",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputs: slices.Collect(func(yield func(*input.Input) bool) {
|
||||
for range 1024 {
|
||||
if !yield(&input.Input{}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !yield(&input.Input{SameBatch: 512 - 1}) {
|
||||
return
|
||||
}
|
||||
|
||||
for range 2048 - 1024 - 1 {
|
||||
if !yield(&input.Input{}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
expected: 1531,
|
||||
},
|
||||
{
|
||||
name: "Same Batch Near Start",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputs: slices.Collect(func(yield func(*input.Input) bool) {
|
||||
for range 10 {
|
||||
if !yield(&input.Input{}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !yield(&input.Input{SameBatch: 512 - 1}) {
|
||||
return
|
||||
}
|
||||
|
||||
for range 2048 - 10 - 1 {
|
||||
if !yield(&input.Input{}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
expected: 1021,
|
||||
},
|
||||
{
|
||||
name: "Consecutive Same Batch",
|
||||
numCtx: 32,
|
||||
inputs: slices.Collect(func(yield func(*input.Input) bool) {
|
||||
for i := range 32 {
|
||||
input := input.Input{}
|
||||
if i%10 == 0 {
|
||||
input.SameBatch = 10 - 1
|
||||
}
|
||||
if !yield(&input) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
expected: 20,
|
||||
},
|
||||
{
|
||||
name: "Overlapping Same Batch",
|
||||
numCtx: 32,
|
||||
inputs: slices.Collect(func(yield func(*input.Input) bool) {
|
||||
for i := range 32 {
|
||||
input := input.Input{}
|
||||
if slices.Contains([]int{4, 8, 14}, i) {
|
||||
input.SameBatch = 10 - 1
|
||||
}
|
||||
if !yield(&input) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
expected: 24,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := InputCache{numCtx: tt.numCtx}
|
||||
result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
|
||||
result := c.ShiftDiscard(tt.inputs, tt.numKeep)
|
||||
if result != tt.expected {
|
||||
t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
|
||||
t.Errorf("shiftDiscard(ctx: %v, keep: %v inputs: %v): have %v; want %v", tt.numCtx, tt.numKeep, len(tt.inputs), result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -214,7 +214,6 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
|
||||
parts = []string{prompt}
|
||||
}
|
||||
|
||||
postTokenize := false
|
||||
for i, part := range parts {
|
||||
// text - tokenize
|
||||
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
||||
@@ -257,11 +256,10 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
|
||||
mmStore.addMultimodal(imageEmbeddings)
|
||||
|
||||
inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
||||
postTokenize = true
|
||||
}
|
||||
}
|
||||
|
||||
if visionModel && postTokenize {
|
||||
if visionModel {
|
||||
var err error
|
||||
inputs, err = multimodalProcessor.PostTokenize(inputs)
|
||||
if err != nil {
|
||||
|
||||
@@ -142,7 +142,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
|
||||
// This model is much more capable with a larger context, so set that
|
||||
// unless it would penalize performance too much
|
||||
if !s.lowVRAM && slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||
if !s.lowVRAM && slices.Contains([]string{
|
||||
"gptoss", "gpt-oss",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
}, model.Config.ModelFamily) {
|
||||
opts.NumCtx = max(opts.NumCtx, 8192)
|
||||
}
|
||||
|
||||
|
||||
@@ -390,11 +390,11 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
|
||||
numParallel = 1
|
||||
}
|
||||
|
||||
// `mllama` is a snowflake and uses an encoder cache which cannot be used with num_parallel > 1
|
||||
// `mllama`, `qwen3vl`, and `qwen3vlmoe` are snowflakes and uses an encoder cache which cannot be used with num_parallel > 1
|
||||
// ref: https://github.com/ollama/ollama/issues/4165
|
||||
if slices.Contains(req.model.Config.ModelFamilies, "mllama") && numParallel != 1 {
|
||||
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe"}, req.model.Config.ModelFamily) && numParallel != 1 {
|
||||
numParallel = 1
|
||||
slog.Warn("mllama does not currently support parallel requests")
|
||||
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
|
||||
}
|
||||
|
||||
sessionDuration := envconfig.KeepAlive()
|
||||
|
||||
Reference in New Issue
Block a user