Compare commits

...

5 Commits

Author SHA1 Message Date
Jesse Gross
9191dfaf05 llm: Enable flash attention for mistral3 by default 2025-12-04 15:19:06 -08:00
Jesse Gross
1108d8b34e ggml: Enable flash attention for vision encoders
Although the vision component of multimodal models typically already
calls the optimized nn.Attention, it is converted into non-fused
operations. That is because the backend-specific fused kernels may
have requirements, such as padding, which are handled by the cache,
and vision encoders don't use a cache.

This implements a fallback path in the backend, softening the
requirements into optimizations. In turn, this allows flash attention
to be used for vision encoders, saving a significant amount of VRAM
and improving performance.
2025-12-04 15:19:06 -08:00
Jesse Gross
7837a5bc7e ggml: Always set cache padding to 256
We currently use cache padding of 32 when not using flash attention
and 256 with flash attention, values based on the historical alignment
requirements of those kernels. The restrictions have since been
loosened, but there are still performance benefits, such as better
CUDA graph reuse.

Since the requirement is no longer kernel-specific, set the padding
uniformly to 256, matching llama.cpp.
2025-12-04 15:19:06 -08:00
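
For reference, cache padding only rounds the effective cache length up to a multiple of the padding value; the sketch below shows the arithmetic (roundUp is a local stand-in for illustration, not the backend's own pad helper):

    package main

    import "fmt"

    // roundUp rounds n up to the next multiple of pad, which is what cache
    // padding does to the KV cache length.
    func roundUp(n, pad int) int {
        return (n + pad - 1) / pad * pad
    }

    func main() {
        // Compare the old non-flash-attention padding (32) with the uniform 256.
        for _, n := range []int{1, 100, 257, 4096} {
            fmt.Printf("context %4d -> pad32 %4d, pad256 %4d\n", n, roundUp(n, 32), roundUp(n, 256))
        }
    }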
Patrick Devine
0a844f8e96 convert: add deepseek converter (#12980)
This change adds the ability for `ollama create` to convert models that use
the DeepSeek2 architecture (specifically DeepSeekV3 and DeepSeek-R1).
2025-12-04 13:49:30 -08:00
Eloi Torrents
a03223b86f cmd/bench: support writing benchmark output to file (#13263)
* cmd/bench: support writing benchmark output to file

This allows the bench command to write benchmark results to a
user-specified output file instead of stdout when the --output flag
is provided.

---------

Co-authored-by: Patrick Devine <patrick@infrahq.com>
2025-12-04 13:22:41 -08:00
7 changed files with 221 additions and 10 deletions

View File

@@ -48,8 +48,8 @@ func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool)
case "benchstat":
if verbose {
printHeader := func() {
fmt.Printf("sysname: %s\n", runtime.GOOS)
fmt.Printf("machine: %s\n", runtime.GOARCH)
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
}
once.Do(printHeader)
}
@@ -147,6 +147,17 @@ func BenchmarkChat(fOpt flagOptions) error {
return err
}
var out io.Writer = os.Stdout
if fOpt.outputFile != nil && *fOpt.outputFile != "" {
f, err := os.OpenFile(*fOpt.outputFile, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: cannot open output file %s: %v\n", *fOpt.outputFile, err)
return err
}
defer f.Close()
out = f
}
for _, model := range models {
for range *fOpt.epochs {
options := make(map[string]interface{})
@@ -241,13 +252,14 @@ func BenchmarkChat(fOpt flagOptions) error {
},
}
OutputMetrics(os.Stdout, *fOpt.format, metrics, *fOpt.verbose)
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
if *fOpt.keepAlive > 0 {
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
}
}
}
return nil
}
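
The hunk above uses the standard pattern for optionally redirecting output: default to stdout, open a file only if a path was given, and hand the resulting io.Writer to the formatting code. A self-contained sketch of the same idea (the flag name and file mode mirror the diff; everything else is illustrative):

    package main

    import (
        "flag"
        "fmt"
        "io"
        "os"
    )

    func main() {
        output := flag.String("output", "", "write results to this file instead of stdout")
        flag.Parse()

        var out io.Writer = os.Stdout
        if *output != "" {
            f, err := os.OpenFile(*output, os.O_CREATE|os.O_WRONLY, 0o644)
            if err != nil {
                fmt.Fprintf(os.Stderr, "ERROR: cannot open output file %s: %v\n", *output, err)
                os.Exit(1)
            }
            defer f.Close()
            out = f
        }

        // Everything downstream writes to the io.Writer, so the formatting code
        // doesn't care whether it is a file or stdout.
        fmt.Fprintln(out, "benchmark results would go here")
    }

This is also why the OutputMetrics fix in the first hunk matters: once results can go to a file, the header has to be written with fmt.Fprintf(w, ...) rather than fmt.Printf, or it would still land on stdout.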

View File

@@ -208,6 +208,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &gptossModel{}
case "DeepseekOCRForCausalLM":
conv = &deepseekocr{}
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}

View File

@@ -0,0 +1,173 @@
package convert
import (
"cmp"
"fmt"
"log/slog"
"regexp"
"strconv"
"github.com/ollama/ollama/fs/ggml"
)
type deepseek2Model struct {
ModelParameters // architectures, vocab_size
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
KVLoraRank uint32 `json:"kv_lora_rank"`
QLoraRank uint32 `json:"q_lora_rank"`
VHeadDim uint32 `json:"v_head_dim"`
ExpertCount uint32 `json:"n_routed_experts"`
ExpertSharedCount uint32 `json:"n_shared_experts"`
ExpertIntermediateSize uint32 `json:"moe_intermediate_size"`
ExpertUsedCount uint32 `json:"num_experts_per_tok"`
ExpertWeightsNorm bool `json:"norm_topk_prob"`
ExpertWeightsScale float32 `json:"routed_scaling_factor"`
ScoringFunc string `json:"scoring_func"`
LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
RopeScaling struct {
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
Type string `json:"type"`
MScaleAllDim float32 `json:"mscale_all_dim"`
} `json:"rope_scaling"`
Architecture string
}
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2"
kv["general.type"] = "model"
kv["deepseek2.block_count"] = p.HiddenLayers
numHeads := p.NumAttentionHeads
numKVHeads := p.NumKeyValueHeads
kv["deepseek2.attention.head_count"] = numHeads
kv["deepseek2.attention.head_count_kv"] = numKVHeads
kv["deepseek2.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
kv["deepseek2.attention.kv_lora_rank"] = p.KVLoraRank
kv["deepseek2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["deepseek2.attention.q_lora_rank"] = p.QLoraRank
kv["deepseek2.attention.value_length"] = p.VHeadDim
kv["deepseek2.context_length"] = p.MaxPositionEmbeddings
kv["deepseek2.embedding_length"] = p.HiddenSize
kv["deepseek2.expert_count"] = p.ExpertCount
kv["deepseek2.expert_feed_forward_length"] = p.ExpertIntermediateSize
kv["deepseek2.expert_shared_count"] = p.ExpertSharedCount
var scoringFunc uint32
switch p.ScoringFunc {
case "softmax":
// not currently supported in the model, but needed for Deepseek-OCR
scoringFunc = 1
case "sigmoid":
scoringFunc = 2
}
kv["deepseek2.expert_gating_func"] = scoringFunc
kv["deepseek2.expert_used_count"] = p.ExpertUsedCount
kv["deepseek2.expert_weights_norm"] = p.ExpertWeightsNorm
kv["deepseek2.expert_weights_scale"] = p.ExpertWeightsScale
kv["deepseek2.feed_forward_length"] = p.IntermediateSize
kv["deepseek2.leading_dense_block_count"] = p.LeadingDenseBlockCount
kv["deepseek2.rope.dimension_count"] = p.QKRopeHeadDim
kv["deepseek2.rope.freq_base"] = cmp.Or(p.RopeTheta, 10000.0)
kv["deepseek2.rope.scaling.factor"] = p.RopeScaling.Factor
kv["deepseek2.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
kv["deepseek2.rope.scaling.type"] = p.RopeScaling.Type
kv["deepseek2.rope.scaling.yarn_log_multiplier"] = 0.1 * p.RopeScaling.MScaleAllDim
kv["tokenizer.ggml.pre"] = "deepseek-v3"
return kv
}
func (p *deepseek2Model) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"language_model.", "",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
"self_attn.kv_a_layernorm", "attn_kv_a_norm",
"self_attn.kv_b_proj", "attn_kv_b",
"self_attn.q_a_proj", "attn_q_a",
"self_attn.q_a_layernorm", "attn_q_a_norm",
"self_attn.q_b_proj", "attn_q_b",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
"mlp.gate", "ffn_gate_inp",
}
}
func (p *deepseek2Model) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, p.HiddenLayers*3)
for i := range p.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
skipLayer := func(n string, minValue uint32) bool {
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(n)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return uint32(blkNum) >= minValue
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
// skip any additional layers (such as the Multi-Token Prediction layer)
if skipLayer(t.Name(), p.HiddenLayers) {
slog.Debug("skipping layer", "name", t.Name())
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
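
To make the layer-skipping rule concrete: a tensor is kept only if its block index is below num_hidden_layers, so trailing blocks such as DeepSeek's Multi-Token Prediction layer are dropped. A standalone sketch of the same check (the tensor names here are made up for illustration):

    package main

    import (
        "fmt"
        "regexp"
        "strconv"
    )

    // skipLayer mirrors the converter's check: drop any tensor whose block
    // index is at or beyond the model's num_hidden_layers.
    func skipLayer(name string, hiddenLayers uint32) bool {
        matches := regexp.MustCompile(`^blk\.(\d+)`).FindStringSubmatch(name)
        if matches == nil {
            return false
        }
        blkNum, err := strconv.Atoi(matches[1])
        if err != nil {
            return false
        }
        return uint32(blkNum) >= hiddenLayers
    }

    func main() {
        const hiddenLayers = 61 // DeepSeek-V3 uses 61 hidden layers
        for _, name := range []string{
            "token_embd.weight",
            "blk.0.attn_q_a.weight",
            "blk.60.ffn_down_exps.weight",
            "blk.61.attn_norm.weight", // Multi-Token Prediction block: skipped
        } {
            fmt.Printf("%-30s skip=%v\n", name, skipLayer(name, hiddenLayers))
        }
    }

With the DeepseekV3ForCausalLM case registered in ConvertModel, `ollama create` can import such a checkpoint directly from its safetensors directory, the same way as the other supported architectures.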

View File

@@ -831,6 +831,7 @@ func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"gemma3",
"gptoss", "gpt-oss",
"mistral3",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
}, f.KV().String("general.architecture"))

View File

@@ -233,8 +233,10 @@ type Tensor interface {
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
//
// cacheConfigApplied indicates whether the optimizations requested through CacheConfig have been performed
type ScaledDotProductAttention interface {
ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64, cacheConfigApplied bool) Tensor
}
type number interface {

View File

@@ -687,7 +687,7 @@ func (b *Backend) CacheConfig() ml.CacheConfig {
if b.flashAttention {
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
} else {
return ml.CacheConfig{CachePadding: 32, PermutedV: true}
return ml.CacheConfig{CachePadding: 256, PermutedV: true}
}
}
@@ -1645,7 +1645,29 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor {
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64, cacheConfigApplied bool) ml.Tensor {
// If the cache didn't help us with required transformations, do them here
if !cacheConfigApplied {
cacheConfig := t.b.CacheConfig()
// Padding key and value to CachePadding is a performance optimization, not a requirement, so we don't do it if it wasn't done by the caller
if cacheConfig.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
if mask != nil {
padSize := int(pad(C.size_t(mask.Dim(1)), C.size_t(cacheConfig.MaskBatchPadding))) - mask.Dim(1)
if padSize > 0 {
mask = mask.Pad(ctx, 0, padSize, 0, 0)
}
if mask.DType() != cacheConfig.MaskDType {
mask = mask.Cast(ctx, cacheConfig.MaskDType)
}
}
}
var kqMask *C.struct_ggml_tensor
if mask != nil {
kqMask = mask.(*Tensor).t

View File

@@ -57,10 +57,9 @@ func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla
key, value, mask = cache.Get(ctx)
}
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale)
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
cacheConfigApplied := cache != nil
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale, cacheConfigApplied)
} else {
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)