mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 07:16:38 -05:00
model: add MLA absorption for glm4moelite (#13810)
* model: add MLA absorption for glm4moelite Split the combined KV_B tensor into separate K_B and V_B tensors during conversion, enabling MLA (Multi-head Latent Attention) absorption which compresses the KV cache for improved efficiency. * ggml: enable MLA flash attention for GLM-4.7-flash Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash uses head size 576 with gqa_ratio 4, which was previously only supported for gqa_ratio 16 (DeepSeek). Metal changes: - Enable head size 576 for flash attention - Increase simdgroups to 8 for large heads (>=512) - Add case 8 kernel dispatch for 8 simdgroups CUDA changes: - Add gqa_ratio 4 support for head 576/512 - Add tile configs for (576, 512, 4) and (576, 512, 8) - Add MMA config cases for ncols 4 - Add template instances for ncols2=4 * model: add compatibility validation for glm4moelite architecture
This commit is contained in:
@@ -39,6 +39,13 @@ type Model interface {
|
||||
Config() config
|
||||
}
|
||||
|
||||
// Validator is an optional interface that models can implement to perform
|
||||
// validation after tensors have been loaded. If validation fails, model
|
||||
// loading will fail with the returned error.
|
||||
type Validator interface {
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// MultimodalProcessor must be implemented by multimodal models.
|
||||
type MultimodalProcessor interface {
|
||||
// EncodeMultimodal processes a single input (such as an image) and
|
||||
@@ -116,6 +123,13 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
base := Base{b: b, config: m.Config()}
|
||||
v := reflect.ValueOf(m)
|
||||
v.Elem().Set(populateFields(base, v.Elem()))
|
||||
|
||||
if validator, ok := m.(Validator); ok {
|
||||
if err := validator.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package glm4moelite
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
@@ -11,6 +12,8 @@ import (
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
|
||||
|
||||
type Options struct {
|
||||
numExpertsUsed int
|
||||
numExperts int
|
||||
@@ -47,7 +50,9 @@ type Attention struct {
|
||||
|
||||
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
|
||||
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
|
||||
KVB *nn.Linear `gguf:"attn_kv_b"`
|
||||
|
||||
KB *nn.Linear `gguf:"attn_k_b"`
|
||||
VB *nn.Linear `gguf:"attn_v_b"`
|
||||
|
||||
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
|
||||
}
|
||||
@@ -78,15 +83,16 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
||||
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
|
||||
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
|
||||
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
|
||||
kPass = attn.KVB.Forward(ctx, kPass)
|
||||
|
||||
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
|
||||
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
|
||||
// MLA absorption: absorb K projection into query
|
||||
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
|
||||
qPassAbsorb := attn.KB.Forward(ctx, qPass).Permute(ctx, 0, 2, 1, 3)
|
||||
query = qRot.Concat(ctx, qPassAbsorb, 0)
|
||||
|
||||
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
|
||||
query = qRot.Concat(ctx, queryChunks[0], 0)
|
||||
key := kRot.Concat(ctx, kvChunks[0], 0)
|
||||
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
|
||||
kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
|
||||
key := kRot.Concat(ctx, kPass, 0)
|
||||
|
||||
attention := nn.AttentionWithVMLA(ctx, query, key, kPass, nil, attn.VB.Weight, opts.kqScale, cache)
|
||||
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
|
||||
return attn.Output.Forward(ctx, attention)
|
||||
@@ -217,8 +223,12 @@ func New(c fs.Config) (model.Model, error) {
|
||||
|
||||
keyLength := int(c.Uint("attention.key_length"))
|
||||
valueLength := int(c.Uint("attention.value_length"))
|
||||
kvLoraRank := int(c.Uint("attention.kv_lora_rank"))
|
||||
qkRopeHeadDim := int(c.Uint("rope.dimension_count"))
|
||||
|
||||
kqScale := 1.0 / math.Sqrt(float64(keyLength))
|
||||
// For MLA absorption, the effective key dimension is kvLoraRank + qkRopeHeadDim
|
||||
mlaKeyLength := kvLoraRank + qkRopeHeadDim
|
||||
kqScale := 1.0 / math.Sqrt(float64(mlaKeyLength))
|
||||
|
||||
var pre []string
|
||||
switch c.String("tokenizer.ggml.pre") {
|
||||
@@ -279,6 +289,15 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
for _, layer := range m.Layers {
|
||||
if layer.Attention != nil && (layer.Attention.KB == nil || layer.Attention.VB == nil) {
|
||||
return ErrOldModelFormat
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
|
||||
73
model/models/glm4moelite/model_test.go
Normal file
73
model/models/glm4moelite/model_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package glm4moelite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model *Model
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid model with KB and VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{KB: &nn.Linear{}, VB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing KB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{VB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{KB: &nn.Linear{}}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing both KB and VB",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: &Attention{}},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil Attention is ok",
|
||||
model: &Model{
|
||||
Layers: []Layer{
|
||||
{Attention: nil},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.model.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr && err != ErrOldModelFormat {
|
||||
t.Errorf("Validate() error = %v, want %v", err, ErrOldModelFormat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user