Add deepseek v3.1 (#13063)
* Add mla for flash attention
* Revert to using chunks
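This change threads an optional vmla tensor through attention so that DeepSeek-style multi-head latent attention (MLA) can attend directly over the cached low-rank latent and apply the value up-projection afterwards. A sketch of the two identities this relies on, in generic notation (W_UK and W_UV are the per-head key and value up-projections, c_t the cached latent for token t, a_t the softmax weights; the symbols are illustrative, not names from the code):

$$ q^{\top}(W_{UK}\,c_t) = (W_{UK}^{\top}\,q)^{\top}\,c_t, \qquad \sum_t a_t\,(W_{UV}\,c_t) = W_{UV}\sum_t a_t\,c_t $$

The first identity lets the key up-projection be folded into the query (attn_k_b applied to the query below); the second lets the value up-projection (attn_v_b, passed as vmla) be applied after the softmax-weighted sum.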
@@ -230,7 +230,7 @@ type Tensor interface {
 // kqv := value.Mulmat(ctx, kq)
 // return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 type ScaledDotProductAttention interface {
-	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, scale float64) Tensor
+	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
 }
 
 type number interface {
@@ -1625,7 +1625,7 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	}
 }
 
-func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
+func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor {
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {
 		kqMask = mask.(*Tensor).t
@@ -1642,6 +1642,16 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
 			C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t)
 		}
 		C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
+
+		if vmla != nil {
+			var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
+			cur = cur.Permute(ctx, 0, 2, 1, 3)
+			cur = vmla.Mulmat(ctx, cur)
+			cur = cur.Permute(ctx, 0, 2, 1, 3)
+			cur = cur.Contiguous(ctx)
+			kqv = cur.(*Tensor).t
+		}
+
 		return &Tensor{b: t.b, t: kqv}
 	} else {
 		kq := key.MulmatFullPrec(ctx, query)
@@ -1654,6 +1664,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
 		}
 
 		kqv := value.Mulmat(ctx, kq)
+		if vmla != nil {
+			kqv = vmla.Mulmat(ctx, kqv)
+		}
+
 		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	}
 }
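In both the flash-attention path and the fallback above, vmla is multiplied into the attention output (kqv) rather than into the values before attention. A minimal, self-contained sketch of why the two orderings agree, using plain Go slices and a toy matmul (everything here is illustrative; nothing is taken from the ggml bindings):

// Toy check that attending over latent values and up-projecting afterwards
// matches up-projecting the values first. Shapes: A is a 1xT row of softmax
// weights, Vlat is a Txr matrix of latent values, Wuv is an rxd value
// up-projection. All names and numbers are made up for illustration.
package main

import (
	"fmt"
	"math"
)

// matmul multiplies an mxk matrix by a kxn matrix.
func matmul(a, b [][]float64) [][]float64 {
	m, k, n := len(a), len(b), len(b[0])
	out := make([][]float64, m)
	for i := range out {
		out[i] = make([]float64, n)
		for p := 0; p < k; p++ {
			for j := 0; j < n; j++ {
				out[i][j] += a[i][p] * b[p][j]
			}
		}
	}
	return out
}

func main() {
	A := [][]float64{{0.7, 0.2, 0.1}}           // softmax weights over 3 cached tokens
	Vlat := [][]float64{{1, 2}, {3, 4}, {5, 6}} // 3 tokens x 2 latent dims
	Wuv := [][]float64{{1, 0, 2}, {0, 1, 3}}    // 2 latent dims -> 3 output dims

	upFirst := matmul(A, matmul(Vlat, Wuv))     // project the values, then attend
	attendFirst := matmul(matmul(A, Vlat), Wuv) // attend in latent space, then project (the vmla path)

	for j := range upFirst[0] {
		if math.Abs(upFirst[0][j]-attendFirst[0][j]) > 1e-12 {
			panic("orderings disagree")
		}
	}
	fmt.Println("same result:", upFirst[0], attendFirst[0])
}

Applying the projection after attention keeps the per-token work in the small latent dimension, which is the point of passing vmla down instead of up-projecting the cached values.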
@@ -22,10 +22,14 @@ import (
 //
 // Attention output with shape [d_v, heads, seq_len_q]
 func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
-	return AttentionWithSinks(ctx, query, key, value, nil, scale, cache)
+	return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
 }
 
 func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+	return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
+}
+
+func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
 	ctx.Forward(query)
 	if key != nil && value != nil {
 		if query.Dim(0) != key.Dim(0) {
@@ -56,7 +60,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
 	// Only use the fast SDPA implementation if we have a cache, since that's what
 	// will do any expected backend-specific transformations for us
 	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
-		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, scale)
+		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale)
 	} else {
 		query = query.Permute(ctx, 0, 2, 1, 3)
 		key = key.Permute(ctx, 0, 2, 1, 3)
@@ -71,6 +75,11 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
 		kq = kq.Softmax(ctx)
 
 		kqv := value.Mulmat(ctx, kq)
+
+		if vmla != nil {
+			kqv = vmla.Mulmat(ctx, kqv)
+		}
+
 		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	}
 }
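With the wrappers above, Attention and AttentionWithSinks keep their existing signatures and simply delegate with a nil vmla, so call sites outside deepseek2 are untouched. A rough caller sketch under that assumption (selfAttention and its parameters are hypothetical; only the nn functions come from this change):

// A hypothetical helper showing how a model might opt into the vmla path.
package example

import (
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

// selfAttention passes the per-head value up-projection (vmla) through when a
// model uses MLA, and falls back to the unchanged API otherwise.
func selfAttention(ctx ml.Context, q, k, v, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
	if vmla == nil {
		// Existing call sites keep working: Attention and AttentionWithSinks
		// now just delegate to AttentionWithVMLA with vmla == nil.
		return nn.Attention(ctx, q, k, v, scale, cache)
	}
	return nn.AttentionWithVMLA(ctx, q, k, v, nil, vmla, scale, cache)
}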
@@ -3,6 +3,7 @@ package deepseek2
 // uses deepseek 2 architecture but written based on deepseek 3 model
 
 import (
+	"cmp"
 	"math"
 
 	"github.com/ollama/ollama/fs"
@@ -16,6 +17,7 @@ import (
 )
 
 type Options struct {
+	isMLA          bool
 	numExpertsUsed int
 	numExperts     int
 	normTopKProb   bool
@@ -32,8 +34,6 @@ type Options struct {
 	hiddenSize,
 	numHeads,
 	numKVHeads,
-	keyLength,
-	valueLength,
 	originalContextLength int
 
 	eps,
@@ -62,6 +62,9 @@ type Attention struct {
 	KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
 	KVB     *nn.Linear  `gguf:"attn_kv_b"`
 
+	KB *nn.Linear `gguf:"attn_k_b"`
+	VB *nn.Linear `gguf:"attn_v_b"`
+
 	Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
 }
 
@@ -69,7 +72,7 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 	seqLength := hiddenStates.Dim(1)
 
 	var query ml.Tensor
-	if opts.qLoraRank == 0 { // nil {
+	if opts.qLoraRank == 0 {
 		query = attn.Q.Forward(ctx, hiddenStates)
 	} else {
 		query = attn.QA.Forward(ctx, hiddenStates)
@@ -88,21 +91,35 @@
 		compressedKV.Stride(1), compressedKV.Dim(1),
 	)
 
-	kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
-	kPass = attn.KVB.Forward(ctx, kPass)
-
-	kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
-	kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
-
 	qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
 	kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
+	kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
 
-	kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
+	var attention ml.Tensor
 
-	query = qRot.Concat(ctx, queryChunks[0], 0)
-	key := kRot.Concat(ctx, kvChunks[0], 0)
+	if !opts.isMLA { // v3
+		kPass = attn.KVB.Forward(ctx, kPass)
+
+		kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
+		kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
+
+		kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
+		query = qRot.Concat(ctx, queryChunks[0], 0)
+		key := kRot.Concat(ctx, kvChunks[0], 0)
+		attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
+	} else { // v3.1
+		qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
+		qPassAbsorb := attn.KB.Forward(ctx, qPass)
+		qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3)
+
+		query = qRot.Concat(ctx, qPassAbsorb, 0)
+		kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
+		key := kRot.Concat(ctx, kPass, 0)
+		value := kPass
+
+		attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
+	}
 
-	attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
 	return attn.Output.Forward(ctx, attention)
 }
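The !opts.isMLA branch is the previous v3 behaviour: attn_kv_b expands the normalized latent back to full per-head keys and values. The v3.1 branch keeps attention in the compressed space: the query's non-rope chunk is pre-multiplied by attn_k_b (absorbing the key up-projection), the kvLoraRank-wide latent doubles as the value and as the non-rope part of a single-head key, and attn.VB.Weight is passed as vmla so the value up-projection happens after attention. A back-of-the-envelope comparison of the per-token key/value tensors each branch builds, using DeepSeek-V3-style head dimensions that are assumed here rather than read from the diff:

// Illustrative only: the dimensions are typical DeepSeek-V3 values (assumed),
// not taken from this change.
package main

import "fmt"

func main() {
	const (
		numKVHeads    = 128 // assumed
		kqNopeHeadDim = 128 // assumed
		qkRopeHeadDim = 64  // assumed
		vHeadDim      = 128 // assumed
		kvLoraRank    = 512 // assumed
	)

	// v3 branch: key = kRot.Concat(kvChunks[0]) per head, value = kvChunks[1] per head.
	v3KeyElems := numKVHeads * (kqNopeHeadDim + qkRopeHeadDim)
	v3ValueElems := numKVHeads * vHeadDim

	// v3.1 branch: key = kRot.Concat(kPass) with a single latent "head",
	// value is the same latent (kPass).
	mlaKeyElems := kvLoraRank + qkRopeHeadDim
	mlaValueElems := kvLoraRank

	fmt.Printf("per token: v3 K+V = %d elements, v3.1 K+V = %d elements (~%dx smaller)\n",
		v3KeyElems+v3ValueElems, mlaKeyElems+mlaValueElems,
		(v3KeyElems+v3ValueElems)/(mlaKeyElems+mlaValueElems))
}

Those tensors are what get handed to the attention/cache path for each layer, so the difference is roughly what the MLA formulation saves in KV-cache footprint.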
@@ -233,6 +250,10 @@ func New(c fs.Config) (model.Model, error) {
 	mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor"))))
 	kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length")))
 
+	isMLA := c.Uint("attention.key_length_mla") != 0 && c.Uint("attention.value_length_mla") != 0
+	keyLength := int(cmp.Or(c.Uint("attention.key_length_mla"), c.Uint("attention.key_length")))
+	valueLength := int(cmp.Or(c.Uint("attention.value_length_mla"), c.Uint("attention.value_length")))
+
 	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
 			&model.Vocabulary{
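isMLA keys off the MLA-specific metadata, and cmp.Or (hence the new "cmp" import above) returns the first non-zero value, so keyLength and valueLength prefer attention.key_length_mla / attention.value_length_mla when a repacked v3.1 GGUF provides them and fall back to the plain lengths otherwise, assuming c.Uint reads back 0 for an absent field. A tiny illustration of the stdlib behaviour (numbers made up):

// cmp.Or returns the first of its arguments that is not the zero value
// (Go 1.22+ standard library).
package main

import (
	"cmp"
	"fmt"
)

func main() {
	// MLA-specific field present in the metadata: it takes precedence.
	fmt.Println(cmp.Or(uint32(320), uint32(192))) // 320

	// Field absent (assumed to read back as 0): fall back to the generic length.
	fmt.Println(cmp.Or(uint32(0), uint32(192))) // 192
}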
@@ -254,11 +275,10 @@
 		),
 		Layers: layers,
 		Options: &Options{
+			isMLA:          isMLA,
 			hiddenSize:     int(c.Uint("embedding_length")),
 			numHeads:       int(c.Uint("attention.head_count")),
 			numKVHeads:     int(c.Uint("attention.head_count_kv")),
-			keyLength:      int(c.Uint("attention.key_length")),
-			valueLength:    int(c.Uint("attention.value_length")),
 			eps:            c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:       c.Float("rope.freq_base"),
 			ropeScale:      c.Float("rope.scaling.factor", 1),
@@ -266,13 +286,13 @@
 			numExpertsUsed: int(c.Uint("expert_used_count")),
 			normTopKProb:   c.Bool("expert_weights_norm", true),
 
-			qLoraRank:     int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal,
+			qLoraRank:     int(c.Uint("attention.q_lora_rank")),
 			kvLoraRank:    int(c.Uint("attention.kv_lora_rank")),
-			qkHeadDim:     int(c.Uint("attention.key_length")),
-			vHeadDim:      int(c.Uint("attention.value_length")),
+			qkHeadDim:     keyLength,
+			vHeadDim:      valueLength,
 			qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
-			qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
-			kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
+			qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
+			kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
 
 			routedScalingFactor:   c.Float("expert_weights_scale"),
 			originalContextLength: int(c.Uint("rope.scaling.original_context_length")),