From e9f6ea232fd8858f0f4b1999db380a663b62bbd9 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Tue, 3 Mar 2026 16:39:22 -0800 Subject: [PATCH] Add qwen3.5-next-moe support to MLX runner and models (#14417) This change adds support for qwen3.5-next-moe models (qwen3-next/qwen3.5-next/qwen3-coder) to the MLX runner. It also: * introduces recurrent cache support and related MLX ops * updates pipeline/runner integration and adds tests * properly quantizes stacked expert tensors * a Gated Delta Metal kernel for fast SSM inference * adds new MLX calls for Conv1d, DepthwideConv1d, Contiguous, Exp, Log, SoftmaxAxis --- x/create/create.go | 27 +- x/create/create_test.go | 42 + x/mlxrunner/cache.go | 138 ++- x/mlxrunner/cache/cache.go | 18 +- x/mlxrunner/cache/recurrent.go | 161 ++++ x/mlxrunner/imports.go | 2 + x/mlxrunner/mlx/gated_delta.go | 370 +++++++ x/mlxrunner/mlx/mlx.go | 2 +- x/mlxrunner/mlx/ops_extra.go | 62 +- x/mlxrunner/pipeline.go | 40 +- x/models/nn/nn.go | 34 + x/models/qwen3_5/qwen3_5.go | 1387 +++++++++++++++++++++++++++ x/models/qwen3_5/qwen3_5_test.go | 159 +++ x/models/qwen3_5_moe/qwen3_5_moe.go | 16 + 14 files changed, 2407 insertions(+), 51 deletions(-) create mode 100644 x/mlxrunner/cache/recurrent.go create mode 100644 x/mlxrunner/mlx/gated_delta.go create mode 100644 x/models/qwen3_5/qwen3_5.go create mode 100644 x/models/qwen3_5/qwen3_5_test.go create mode 100644 x/models/qwen3_5_moe/qwen3_5_moe.go diff --git a/x/create/create.go b/x/create/create.go index 385efadab..46b4393b3 100644 --- a/x/create/create.go +++ b/x/create/create.go @@ -288,6 +288,18 @@ func normalizeQuantType(quantize string) string { } } +func isStackedExpertWeight(name string) bool { + // Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert) + // or "...proj" (pre-stacked packed tensor). + if strings.HasSuffix(name, ".bias") || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".qbias") { + return false + } + + return strings.Contains(name, ".mlp.switch_mlp.") || + strings.Contains(name, ".mlp.experts.") || + strings.Contains(name, ".mlp.shared_experts.") +} + // GetTensorQuantization returns the appropriate quantization type for a tensor. // Returns "" if the tensor should not be quantized. // This implements mixed-precision quantization: @@ -296,18 +308,25 @@ func normalizeQuantType(quantize string) string { // - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel) // - Norms, embeddings, biases, routing gates: no quantization func GetTensorQuantization(name string, shape []int32, quantize string) string { + stackedExpert := isStackedExpertWeight(name) + // Use basic name-based check first - if !ShouldQuantize(name, "") { + if !stackedExpert && !ShouldQuantize(name, "") { return "" } - // Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any) - if len(shape) != 2 { + // Quantize standard linear weights (2D). Also allow stacked expert weights (3D), + // e.g. qwen switch_mlp / experts combined tensors. + if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) { return "" } // Skip small tensors (less than 1024 elements) - not worth quantizing - if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 { + var elems int64 = 1 + for _, d := range shape { + elems *= int64(d) + } + if elems < 1024 { return "" } diff --git a/x/create/create_test.go b/x/create/create_test.go index fb48987d6..7d9e68956 100644 --- a/x/create/create_test.go +++ b/x/create/create_test.go @@ -557,6 +557,10 @@ func TestShouldQuantizeTensor(t *testing.T) { // 3D+ tensors should not be quantized {"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false}, {"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false}, + {"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true}, + {"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true}, + {"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true}, + {"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true}, // Embeddings should not be quantized regardless of shape {"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false}, @@ -619,6 +623,44 @@ func TestExpertGroupPrefix(t *testing.T) { } } +func TestGetTensorQuantization_StackedExpert3D(t *testing.T) { + gateUp := GetTensorQuantization( + "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", + []int32{64, 22016, 4096}, + "int4", + ) + if gateUp != "int4" { + t.Fatalf("gate_up_proj quantization = %q, want %q", gateUp, "int4") + } + + down := GetTensorQuantization( + "model.layers.1.mlp.experts.down_proj.weight", + []int32{64, 4096, 14336}, + "int4", + ) + if down != "int8" { + t.Fatalf("down_proj quantization = %q, want %q", down, "int8") + } + + combinedGateUp := GetTensorQuantization( + "model.language_model.layers.0.mlp.experts.gate_up_proj", + []int32{256, 1024, 2048}, + "int8", + ) + if combinedGateUp != "int8" { + t.Fatalf("combined gate_up_proj quantization = %q, want %q", combinedGateUp, "int8") + } + + combinedDown := GetTensorQuantization( + "model.language_model.layers.0.mlp.experts.down_proj", + []int32{256, 2048, 512}, + "int4", + ) + if combinedDown != "int8" { + t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8") + } +} + func TestCreateSafetensorsModel_WithQuantize(t *testing.T) { dir := t.TempDir() diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go index 48f953b2c..a9ff8904c 100644 --- a/x/mlxrunner/cache.go +++ b/x/mlxrunner/cache.go @@ -30,21 +30,80 @@ type cacheSession struct { remaining []int32 } +func appendCacheState(dst []*mlx.Array, c cache.Cache) []*mlx.Array { + if c == nil { + return dst + } + + keys, values := c.State() + if keys != nil && keys.Valid() { + dst = append(dst, keys) + } + if values != nil && values.Valid() { + dst = append(dst, values) + } + + return dst +} + +func (c *kvCache) free() { + for i, kv := range c.caches { + if kv == nil { + continue + } + kv.Free() + c.caches[i] = nil + } + c.caches = nil + c.tokens = nil +} + +func (c *kvCache) cachesCanTrim() bool { + for _, kv := range c.caches { + if kv == nil { + continue + } + if !kv.CanTrim() { + return false + } + } + return true +} + +func (c *kvCache) trimToPrefix(prefix int) { + for _, kv := range c.caches { + if kv == nil || !kv.CanTrim() { + continue + } + if trim := kv.Offset() - prefix; trim > 0 { + kv.Trim(trim) + } + } + if prefix < len(c.tokens) { + c.tokens = c.tokens[:prefix] + } +} + // begin prepares caches for a new request. It finds the nearest // matching cache or creates new caches if none match. func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession { - if len(c.caches) == 0 { + ensureCaches := func() { + if len(c.caches) != 0 { + return + } if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok { c.caches = cacheFactory.NewCaches() - } else { - c.caches = make([]cache.Cache, m.NumLayers()) - for i := range c.caches { - c.caches[i] = cache.NewKVCache() - } + return + } + c.caches = make([]cache.Cache, m.NumLayers()) + for i := range c.caches { + c.caches[i] = cache.NewKVCache() } } + ensureCaches() remaining := c.findRemaining(inputs) + ensureCaches() return &cacheSession{ cache: c, @@ -56,18 +115,36 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession { // close saves the token state if the forward pass ran. func (s *cacheSession) close() { - if offset := s.caches[0].Offset(); offset > 0 { - // Ensure that if we have run the forward pass and set the metadata - // that we also actually have the data - arrays := make([]*mlx.Array, 0, 2*len(s.caches)) - for _, c := range s.caches { - k, v := c.State() - arrays = append(arrays, k, v) - } - mlx.AsyncEval(arrays...) - - s.cache.tokens = append(s.inputs, s.outputs...)[:offset] + if len(s.caches) == 0 { + return } + + offset := -1 + arrays := make([]*mlx.Array, 0, 2*len(s.caches)) + for _, kv := range s.caches { + if kv == nil { + continue + } + // Mixed cache types (e.g. recurrent + KV) can transiently report different + // offsets, so use the minimum as the safe reusable token prefix. + if off := kv.Offset(); offset < 0 || off < offset { + offset = off + } + arrays = appendCacheState(arrays, kv) + } + if offset <= 0 { + return + } + + // Ensure that if we have run the forward pass and set the metadata + // that we also actually have the data. + mlx.AsyncEval(arrays...) + + stored := append(s.inputs, s.outputs...) + if offset > len(stored) { + offset = len(stored) + } + s.cache.tokens = stored[:offset] } // findRemaining finds the longest common prefix between tokens and the cached @@ -85,11 +162,13 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 { } if prefix < len(c.tokens) { - trim := len(c.tokens) - prefix - for _, kv := range c.caches { - kv.Trim(trim) + if c.cachesCanTrim() { + c.trimToPrefix(prefix) + } else { + c.free() + slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence") + return tokens } - c.tokens = c.tokens[:prefix] } if prefix == 0 { @@ -104,10 +183,21 @@ func (c *kvCache) log() { if len(c.caches) == 0 { return } + offset := -1 var totalBytes int for _, kv := range c.caches { - k, v := kv.State() - totalBytes += k.NumBytes() + v.NumBytes() + if kv == nil { + continue + } + if off := kv.Offset(); offset < 0 || off < offset { + offset = off + } + for _, a := range appendCacheState(nil, kv) { + totalBytes += a.NumBytes() + } } - logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes))) + if offset < 0 { + return + } + logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes))) } diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go index 3c1ff6011..7d0d0b060 100644 --- a/x/mlxrunner/cache/cache.go +++ b/x/mlxrunner/cache/cache.go @@ -9,7 +9,9 @@ import ( type Cache interface { Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array) + // State returns the cache-owned state roots that should be kept/evaluated. State() (keys, values *mlx.Array) + CanTrim() bool Trim(int) int Clone() Cache Free() @@ -60,13 +62,15 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { } func (c *KVCache) State() (*mlx.Array, *mlx.Array) { - if c.offset == c.keys.Dim(2) { - return c.keys, c.values + if c.keys == nil || c.values == nil { + return nil, nil } return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()) } +func (c *KVCache) CanTrim() bool { return true } + func (c *KVCache) Trim(n int) int { n = min(c.offset, n) c.offset -= n @@ -183,13 +187,15 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra } func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) { - if c.offset < c.keys.Dim(2) { - return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), - c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()) + if c.keys == nil || c.values == nil { + return nil, nil } - return c.keys, c.values + return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), + c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()) } +func (c *RotatingKVCache) CanTrim() bool { return true } + func (c *RotatingKVCache) Trim(n int) int { n = min(c.offset, n) c.offset -= n diff --git a/x/mlxrunner/cache/recurrent.go b/x/mlxrunner/cache/recurrent.go new file mode 100644 index 000000000..0cbbc01e2 --- /dev/null +++ b/x/mlxrunner/cache/recurrent.go @@ -0,0 +1,161 @@ +//go:build mlx + +package cache + +import "github.com/ollama/ollama/x/mlxrunner/mlx" + +// RecurrentCache stores state for linear-recurrent layers. +// +// Conv state shape: [B, convTail, convDim] +// Delta state shape: [B, numVHeads, headVDim, headKDim] +type RecurrentCache struct { + convState *mlx.Array + deltaState *mlx.Array + offset int + + convTail int + convDim int + numVHeads int + headVDim int + headKDim int +} + +func (c *RecurrentCache) setStateRaw(old, v *mlx.Array) *mlx.Array { + if v == nil || !v.Valid() { + return old + } + if old == v { + return old + } + + mlx.Pin(v) + if old != nil && old != v { + mlx.Unpin(old) + } + + return v +} + +func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bool) *mlx.Array { + if v == nil || !v.Valid() { + return old + } + if old == v { + return old + } + + root := v + if ensureContiguous { + root = mlx.Contiguous(v, false) + } + detached := root.Clone() + + mlx.Pin(detached) + if old != nil && old != detached { + mlx.Unpin(old) + } + + return detached +} + +func snapshotPinned(a *mlx.Array) *mlx.Array { + if a == nil || !a.Valid() { + return nil + } + snap := mlx.Copy(a) + mlx.Eval(snap) + mlx.Pin(snap) + return snap +} + +func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache { + return &RecurrentCache{ + convTail: int(convTail), + convDim: int(convDim), + numVHeads: int(numVHeads), + headVDim: int(headVDim), + headKDim: int(headKDim), + } +} + +func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) { + if batch <= 0 { + batch = 1 + } + + needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype || + c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim + needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != dtype || + c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim + if !needConv && !needDelta { + return + } + + if needConv { + c.convState = c.setStateRaw(c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim)) + } + if needDelta { + c.deltaState = c.setStateRaw(c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim)) + } +} + +func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array { + c.ensure(batch, dtype) + return c.convState +} + +func (c *RecurrentCache) SetConvState(v *mlx.Array) { + c.convState = c.setStateDetached(c.convState, v, true) +} + +func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array { + c.ensure(batch, dtype) + return c.deltaState +} + +func (c *RecurrentCache) SetDeltaState(v *mlx.Array) { + c.deltaState = c.setStateDetached(c.deltaState, v, false) +} + +func (c *RecurrentCache) Advance(n int) { + c.offset += n +} + +func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { + return keys, values +} + +func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) { + return c.convState, c.deltaState +} + +func (c *RecurrentCache) CanTrim() bool { return false } + +func (c *RecurrentCache) Trim(n int) int { + // Recurrent state is not directly trimmable. Divergent prefixes must drop the cache. + _ = n + return 0 +} + +func (c *RecurrentCache) Clone() Cache { + clone := &RecurrentCache{ + offset: c.offset, + convTail: c.convTail, + convDim: c.convDim, + numVHeads: c.numVHeads, + headVDim: c.headVDim, + headKDim: c.headKDim, + convState: snapshotPinned(c.convState), + deltaState: snapshotPinned(c.deltaState), + } + return clone +} + +func (c *RecurrentCache) Free() { + mlx.Unpin(c.convState, c.deltaState) + c.convState, c.deltaState = nil, nil + c.offset = 0 +} + +func (c *RecurrentCache) Offset() int { return c.offset } +func (c *RecurrentCache) Len() int { return c.offset } diff --git a/x/mlxrunner/imports.go b/x/mlxrunner/imports.go index 5e0cfe86d..9202dac34 100644 --- a/x/mlxrunner/imports.go +++ b/x/mlxrunner/imports.go @@ -7,4 +7,6 @@ import ( _ "github.com/ollama/ollama/x/models/glm4_moe_lite" _ "github.com/ollama/ollama/x/models/llama" _ "github.com/ollama/ollama/x/models/qwen3" + _ "github.com/ollama/ollama/x/models/qwen3_5" + _ "github.com/ollama/ollama/x/models/qwen3_5_moe" ) diff --git a/x/mlxrunner/mlx/gated_delta.go b/x/mlxrunner/mlx/gated_delta.go new file mode 100644 index 000000000..7ace1f6d3 --- /dev/null +++ b/x/mlxrunner/mlx/gated_delta.go @@ -0,0 +1,370 @@ +//go:build mlx + +package mlx + +// #include +// #include "generated.h" +import "C" + +import ( + "sync" + "unsafe" +) + +var ( + gatedDeltaMetalKernelOnce sync.Once + gatedDeltaMetalKernel C.mlx_fast_metal_kernel + gatedDeltaMetalDisabled bool +) + +const gatedDeltaMetalKernelSource = ` +auto n = thread_position_in_grid.z; +auto b_idx = n / Hv; +auto hv_idx = n % Hv; +auto hk_idx = hv_idx / (Hv / Hk); +constexpr int n_per_t = Dk / 32; + +// q, k: [B, T, Hk, Dk] +auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk; +auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk; + +// v, y: [B, T, Hv, Dv] +auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv; +y += b_idx * T * Hv * Dv + hv_idx * Dv; + +auto dk_idx = thread_position_in_threadgroup.x; +auto dv_idx = thread_position_in_grid.y; + +// state_in, state_out: [B, Hv, Dv, Dk] +auto i_state = state_in + (n * Dv + dv_idx) * Dk; +auto o_state = state_out + (n * Dv + dv_idx) * Dk; + +float state[n_per_t]; +for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = static_cast(i_state[s_idx]); +} + +// g: [B, T, Hv] +auto g_ = g + b_idx * T * Hv; +auto beta_ = beta + b_idx * T * Hv; + +for (int t = 0; t < T; ++t) { + float kv_mem = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] * g_[hv_idx]; + kv_mem += state[i] * k_[s_idx]; + } + kv_mem = simd_sum(kv_mem); + + auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx]; + + float out = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] + k_[s_idx] * delta; + out += state[i] * q_[s_idx]; + } + out = simd_sum(out); + if (thread_index_in_simdgroup == 0) { + y[dv_idx] = static_cast(out); + } + + q_ += Hk * Dk; + k_ += Hk * Dk; + v_ += Hv * Dv; + y += Hv * Dv; + g_ += Hv; + beta_ += Hv; +} + +for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + o_state[s_idx] = static_cast(state[i]); +} +` + +func cStringVector(values []string) (C.mlx_vector_string, func(), bool) { + vec := C.mlx_vector_string_new() + ok := true + for _, s := range values { + cs := C.CString(s) + if C.mlx_vector_string_append_value(vec, cs) != 0 { + ok = false + } + C.free(unsafe.Pointer(cs)) + if !ok { + break + } + } + cleanup := func() { + C.mlx_vector_string_free(vec) + } + return vec, cleanup, ok +} + +func initGatedDeltaMetalKernel() { + inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"}) + if !ok { + gatedDeltaMetalDisabled = true + freeInputs() + return + } + defer freeInputs() + + outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"}) + if !ok { + gatedDeltaMetalDisabled = true + freeOutputs() + return + } + defer freeOutputs() + + cName := C.CString("gated_delta_step") + defer C.free(unsafe.Pointer(cName)) + cSource := C.CString(gatedDeltaMetalKernelSource) + defer C.free(unsafe.Pointer(cSource)) + cHeader := C.CString("") + defer C.free(unsafe.Pointer(cHeader)) + + gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new( + cName, + inputs, + outputs, + cSource, + cHeader, + C.bool(true), + C.bool(false), + ) +} + +// gatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update. +// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure. +func gatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) { + if gatedDeltaMetalDisabled { + return nil, nil, false + } + if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil { + return nil, nil, false + } + + qd := q.Dims() + kd := k.Dims() + vd := v.Dims() + gd := g.Dims() + bd := beta.Dims() + sd := state.Dims() + if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 { + return nil, nil, false + } + + B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3] + if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 { + return nil, nil, false + } + if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk { + return nil, nil, false + } + Hv, Dv := vd[2], vd[3] + if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 { + return nil, nil, false + } + if gd[0] != B || gd[1] != T || gd[2] != Hv { + return nil, nil, false + } + if bd[0] != B || bd[1] != T || bd[2] != Hv { + return nil, nil, false + } + if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk { + return nil, nil, false + } + + dtype := q.DType() + if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype { + return nil, nil, false + } + + gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel) + if gatedDeltaMetalDisabled { + return nil, nil, false + } + + cfg := C.mlx_fast_metal_kernel_config_new() + defer C.mlx_fast_metal_kernel_config_free(cfg) + + cInT := C.CString("InT") + defer C.free(unsafe.Pointer(cInT)) + if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 { + gatedDeltaMetalDisabled = true + return nil, nil, false + } + for _, tpl := range []struct { + name string + value int + }{ + {name: "Dk", value: Dk}, + {name: "Dv", value: Dv}, + {name: "Hk", value: Hk}, + {name: "Hv", value: Hv}, + } { + cn := C.CString(tpl.name) + rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value)) + C.free(unsafe.Pointer(cn)) + if rc != 0 { + gatedDeltaMetalDisabled = true + return nil, nil, false + } + } + + yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)} + stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)} + if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 { + gatedDeltaMetalDisabled = true + return nil, nil, false + } + if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 { + gatedDeltaMetalDisabled = true + return nil, nil, false + } + if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 { + gatedDeltaMetalDisabled = true + return nil, nil, false + } + threadY := Dv + if threadY > 4 { + threadY = 4 + } + if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 { + gatedDeltaMetalDisabled = true + return nil, nil, false + } + + tScalar := FromValue(T) + inputs := []C.mlx_array{ + q.ctx, + k.ctx, + v.ctx, + g.ctx, + beta.ctx, + state.ctx, + tScalar.ctx, + } + inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs))) + defer C.mlx_vector_array_free(inVec) + + outVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(outVec) + if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 { + gatedDeltaMetalDisabled = true + return nil, nil, false + } + if int(C.mlx_vector_array_size(outVec)) < 2 { + return nil, nil, false + } + + y = New("GATED_DELTA_METAL_Y") + nextState = New("GATED_DELTA_METAL_STATE") + C.mlx_vector_array_get(&y.ctx, outVec, 0) + C.mlx_vector_array_get(&nextState.ctx, outVec, 1) + return y, nextState, true +} + +func repeatHeadsForGatedDelta(x *Array, repeatFactor int) *Array { + if repeatFactor <= 1 { + return x + } + shape := x.Dims() + x = ExpandDims(x, 3) + x = Tile(x, []int32{1, 1, 1, int32(repeatFactor), 1}) + return Reshape(x, int32(shape[0]), int32(shape[1]), int32(shape[2]*repeatFactor), int32(shape[3])) +} + +func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) { + if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil { + return nil, nil + } + + qd := q.Dims() + kd := k.Dims() + vd := v.Dims() + gd := g.Dims() + bd := beta.Dims() + sd := state.Dims() + if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 { + return nil, nil + } + + B, T, Hk, Dk := int32(qd[0]), int32(qd[1]), int32(qd[2]), int32(qd[3]) + Hv, Dv := int32(vd[2]), int32(vd[3]) + if T <= 0 || Hk <= 0 || Dk <= 0 || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 { + return nil, nil + } + if kd[0] != int(B) || kd[1] != int(T) || kd[2] != int(Hk) || kd[3] != int(Dk) { + return nil, nil + } + if vd[0] != int(B) || vd[1] != int(T) { + return nil, nil + } + if gd[0] != int(B) || gd[1] != int(T) || gd[2] != int(Hv) { + return nil, nil + } + if bd[0] != int(B) || bd[1] != int(T) || bd[2] != int(Hv) { + return nil, nil + } + if sd[0] != int(B) || sd[1] != int(Hv) || sd[2] != int(Dv) || sd[3] != int(Dk) { + return nil, nil + } + + repeatFactor := int(Hv / Hk) + q = repeatHeadsForGatedDelta(q, repeatFactor) + k = repeatHeadsForGatedDelta(k, repeatFactor) + + nextState = state + if T == 1 { + qt := Squeeze(q, 1) + kt := Squeeze(k, 1) + vt := Squeeze(v, 1) + gt := Squeeze(g, 1) + bt := Squeeze(beta, 1) + + nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1)) + kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false) + delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1)) + nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1))) + yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false) + return ExpandDims(yt, 1), nextState + } + + outs := make([]*Array, 0, T) + for t := int32(0); t < T; t++ { + qt := Squeeze(SliceStartStop(q, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1) + kt := Squeeze(SliceStartStop(k, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1) + vt := Squeeze(SliceStartStop(v, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dv}), 1) + gt := Squeeze(SliceStartStop(g, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1) + bt := Squeeze(SliceStartStop(beta, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1) + + nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1)) + kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false) + delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1)) + nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1))) + yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false) + outs = append(outs, ExpandDims(yt, 1)) + } + return Concatenate(outs, 1), nextState +} + +// GatedDelta runs the recurrent update operation. +// +// It uses the fused Metal kernel when available and otherwise falls back to a +// backend-agnostic MLX implementation with identical inputs/outputs. +func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) { + if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok { + return y, nextState + } + y, nextState = gatedDeltaFallback(q, k, v, g, beta, state) + if y == nil || nextState == nil { + panic("mlx.GatedDelta: fallback failed (invalid inputs or unsupported shapes)") + } + return y, nextState +} diff --git a/x/mlxrunner/mlx/mlx.go b/x/mlxrunner/mlx/mlx.go index 0bf43830c..962d192ae 100644 --- a/x/mlxrunner/mlx/mlx.go +++ b/x/mlxrunner/mlx/mlx.go @@ -19,7 +19,7 @@ func doEval(outputs []*Array, async bool) { defer C.mlx_vector_array_free(vector) for _, output := range outputs { - if output.Valid() { + if output != nil && output.Valid() { C.mlx_vector_array_append_value(vector, output.ctx) } } diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go index e83b77fb8..283a2141c 100644 --- a/x/mlxrunner/mlx/ops_extra.go +++ b/x/mlxrunner/mlx/ops_extra.go @@ -113,6 +113,35 @@ func Where(condition, a, b *Array) *Array { return out } +func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array { + out := New("CONV1D") + C.mlx_conv1d( + &out.ctx, + x.ctx, + weight.ctx, + C.int(stride), + C.int(padding), + C.int(dilation), + C.int(groups), + DefaultStream().ctx, + ) + if bias != nil && bias.Valid() { + out = Add(out, bias) + } + return out +} + +func Contiguous(a *Array, allowColMajor bool) *Array { + out := New("CONTIGUOUS") + C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx) + return out +} + +func DepthwiseConv1d(x, weight *Array, bias *Array) *Array { + groups := int32(x.Dim(x.NumDims() - 1)) + return Conv1d(x, weight, bias, 1, 0, 1, groups) +} + // Convenience wrappers (function-style for the model code) func Stack(arrays []*Array, axis int) *Array { @@ -271,6 +300,24 @@ func Sigmoid(a *Array) *Array { return a.Sigmoid() } +func Exp(a *Array) *Array { + out := New("EXP") + C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +func Log(a *Array) *Array { + out := New("LOG") + C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +func SoftmaxAxis(a *Array, axis int, precise bool) *Array { + out := New("SOFTMAX_AXIS") + C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx) + return out +} + func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array { mask := New("") sinks := New("") @@ -288,7 +335,11 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b func RMSNormFn(x, weight *Array, eps float32) *Array { out := New("FAST_RMSNORM") - C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx) + var w C.mlx_array + if weight != nil { + w = weight.ctx + } + C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx) return out } @@ -378,6 +429,15 @@ func Collect(v any) []*Array { return arrays } +func Copy(a *Array) *Array { + if a == nil || !a.Valid() { + return a + } + out := New("COPY") + C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) { if !v.IsValid() { return diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 9061029ae..dbf1e182d 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -16,6 +16,10 @@ import ( "github.com/ollama/ollama/x/mlxrunner/mlx" ) +func prefillChunkSize() int { + return 2 << 10 +} + func (r *Runner) TextGenerationPipeline(request Request) error { if r.Model == nil { return errors.New("model not loaded") @@ -31,7 +35,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { mlx.DisableCompile() } mlx.ResetPeakMemory() - + ctx := request.Ctx var ( sample, logprobs *mlx.Array nextSample, nextLogprobs *mlx.Array @@ -74,24 +78,30 @@ func (r *Runner) TextGenerationPipeline(request Request) error { defer session.close() caches := session.caches tokens := session.remaining + prefillChunk := prefillChunkSize() + + materializeCaches := func() { + state := make([]*mlx.Array, 0, 2*len(caches)) + for _, c := range caches { + state = appendCacheState(state, c) + } + if len(state) == 0 { + return + } + mlx.Eval(state...) + } now := time.Now() total, processed := len(tokens), 0 for total-processed > 1 { - if err := request.Ctx.Err(); err != nil { + if err := ctx.Err(); err != nil { return err } - n := min(2<<10, total-processed-1) + n := min(prefillChunk, total-processed-1) r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches) mlx.Sweep() - mlx.Eval(func() []*mlx.Array { - s := make([]*mlx.Array, 2*len(caches)) - for i, c := range caches { - s[2*i], s[2*i+1] = c.State() - } - return s - }()...) + materializeCaches() processed += n slog.Info("Prompt processing progress", "processed", processed, "total", total) mlx.ClearCache() @@ -118,7 +128,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1} for i := range request.Options.MaxTokens { - if err := request.Ctx.Err(); err != nil { + if err := ctx.Err(); err != nil { return err } @@ -140,8 +150,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } select { - case <-request.Ctx.Done(): - return request.Ctx.Err() + case <-ctx.Done(): + return ctx.Err() case request.Responses <- CompletionResponse{ Content: r.Decode(output, &b), }: @@ -158,8 +168,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error { final.EvalDuration = time.Since(now) select { - case <-request.Ctx.Done(): - return request.Ctx.Err() + case <-ctx.Done(): + return ctx.Err() case request.Responses <- final: return nil } diff --git a/x/models/nn/nn.go b/x/models/nn/nn.go index 3f57d483a..78f1b92b6 100644 --- a/x/models/nn/nn.go +++ b/x/models/nn/nn.go @@ -15,6 +15,40 @@ type LinearLayer interface { OutputDim() int32 } +// Conv1d applies 1D convolution over NLC input. +type Conv1d struct { + Weight *mlx.Array + Bias *mlx.Array + Stride int32 + Padding int32 + Dilation int32 + Groups int32 +} + +func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d { + if stride <= 0 { + stride = 1 + } + if dilation <= 0 { + dilation = 1 + } + if groups <= 0 { + groups = 1 + } + return &Conv1d{ + Weight: weight, + Bias: bias, + Stride: stride, + Padding: padding, + Dilation: dilation, + Groups: groups, + } +} + +func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array { + return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups) +} + // Linear applies an affine transformation: y = x @ W.T + b type Linear struct { Weight *mlx.Array diff --git a/x/models/qwen3_5/qwen3_5.go b/x/models/qwen3_5/qwen3_5.go new file mode 100644 index 000000000..fbee82b59 --- /dev/null +++ b/x/models/qwen3_5/qwen3_5.go @@ -0,0 +1,1387 @@ +//go:build mlx + +// Package qwen3_5 provides the Qwen 3.5 text and MoE implementation for MLX. +package qwen3_5 + +import ( + "encoding/json" + "fmt" + "math" + "strings" + + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model" + "github.com/ollama/ollama/x/mlxrunner/model/base" + "github.com/ollama/ollama/x/models/nn" + "github.com/ollama/ollama/x/tokenizer" +) + +func init() { + base.Register("Qwen3_5ForCausalLM", NewModel) + base.Register("Qwen3_5ForConditionalGeneration", NewModel) + base.Register("Qwen3NextForCausalLM", NewModel) + base.Register("Qwen3NextForConditionalGeneration", NewModel) +} + +// RopeParameters carries optional rope metadata embedded under rope_parameters. +type RopeParameters struct { + Type string `json:"type"` + RopeType string `json:"rope_type"` + RopeTheta float32 `json:"rope_theta"` + PartialRotaryFactor float32 `json:"partial_rotary_factor"` +} + +// Config holds Qwen 3.5 text config (top-level or nested text_config). +type Config struct { + ModelType string `json:"model_type"` + HiddenSize int32 `json:"hidden_size"` + IntermediateSize int32 `json:"intermediate_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + HeadDim int32 `json:"head_dim"` + RMSNormEps float32 `json:"rms_norm_eps"` + VocabSize int32 `json:"vocab_size"` + MaxPositionEmbeddings int32 `json:"max_position_embeddings"` + AttentionBias bool `json:"attention_bias"` + TieWordEmbeddings bool `json:"tie_word_embeddings"` + LayerTypes []string `json:"layer_types"` + + FullAttentionInterval int32 `json:"full_attention_interval"` + + LinearNumValueHeads int32 `json:"linear_num_value_heads"` + LinearNumKeyHeads int32 `json:"linear_num_key_heads"` + LinearKeyHeadDim int32 `json:"linear_key_head_dim"` + LinearValueHeadDim int32 `json:"linear_value_head_dim"` + LinearConvKernelDim int32 `json:"linear_conv_kernel_dim"` + DecoderSparseStep int32 `json:"decoder_sparse_step"` + SharedExpertGateRank int32 `json:"-"` + + NumExperts int32 `json:"num_experts"` + NumExpertsPerTok int32 `json:"num_experts_per_tok"` + SharedExpertIntermediateSize int32 `json:"shared_expert_intermediate_size"` + MoeIntermediateSize int32 `json:"moe_intermediate_size"` + NormTopKProb bool `json:"norm_topk_prob"` + MLPOnlyLayers []int32 `json:"mlp_only_layers"` + + RopeTheta float32 `json:"rope_theta"` + PartialRotaryFactor float32 `json:"partial_rotary_factor"` + RopeScaling map[string]any `json:"rope_scaling"` + RopeParameters *RopeParameters `json:"rope_parameters"` + + // Quantization metadata. + QuantGroupSize int `json:"-"` + QuantBits int `json:"-"` + QuantMode string `json:"-"` + TensorQuant map[string]*model.TensorQuantInfo `json:"-"` + + // Computed fields. + Scale float32 `json:"-"` + RopeDim int32 `json:"-"` +} + +// Model is the Qwen 3.5 model. +type Model struct { + EmbedTokens *nn.Embedding + Layers []*Layer + Norm *nn.RMSNorm + LMHead nn.LinearLayer + + tok *tokenizer.Tokenizer + *Config + + weightPrefix string +} + +// Layer is a transformer decoder layer. +type Layer struct { + InputNorm *nn.RMSNorm + PostAttentionNorm *nn.RMSNorm + + IsLinear bool + FullAttn *FullAttention + Linear *GatedDeltaNet + MLP MLPBlock +} + +// FullAttention is the full-attention branch used every N layers. +type FullAttention struct { + QProj nn.LinearLayer + KProj nn.LinearLayer + VProj nn.LinearLayer + OProj nn.LinearLayer + + QNorm *nn.RMSNorm + KNorm *nn.RMSNorm +} + +// GatedDeltaNet is the recurrent linear-attention branch. +type GatedDeltaNet struct { + InProjQKV nn.LinearLayer + InProjZ nn.LinearLayer + InProjB nn.LinearLayer + InProjA nn.LinearLayer + InProjQKVZ nn.LinearLayer + InProjBA nn.LinearLayer + OutProj nn.LinearLayer + + Conv1D *nn.Conv1d + ConvWeight *mlx.Array + NormWeight *mlx.Array + DtBias *mlx.Array + ALog *mlx.Array + AExp *mlx.Array +} + +// MLPBlock is the feed-forward interface for dense and MoE blocks. +type MLPBlock interface { + Forward(x *mlx.Array, cfg *Config) *mlx.Array +} + +// DenseMLP is SwiGLU feed-forward. +type DenseMLP struct { + GateProj nn.LinearLayer + UpProj nn.LinearLayer + DownProj nn.LinearLayer +} + +// SparseMoE is Qwen3.5's sparse MoE with shared expert. +type SparseMoE struct { + Gate nn.LinearLayer + SwitchMLP *SwitchMLP + SharedExpert *DenseMLP + SharedExpertGate nn.LinearLayer +} + +// SwitchMLP executes selected expert MLPs. +type SwitchMLP struct { + GateWeight *mlx.Array + UpWeight *mlx.Array + DownWeight *mlx.Array + + GateWeightQ, GateScales, GateBiases *mlx.Array + UpWeightQ, UpScales, UpBiases *mlx.Array + DownWeightQ, DownScales, DownBiases *mlx.Array + + GateBits int + UpBits int + DownBits int + + GateGroupSize int + UpGroupSize int + DownGroupSize int + + UseQuantized bool +} + +type stackedExpertWeights struct { + Weight *mlx.Array + Scales *mlx.Array + Biases *mlx.Array + Bits int + GroupSize int + Mode string +} + +func parseConfig(configData []byte) (Config, error) { + var rawTop map[string]json.RawMessage + if err := json.Unmarshal(configData, &rawTop); err != nil { + return Config{}, fmt.Errorf("parse config envelope: %w", err) + } + + var cfg Config + activeRaw := rawTop + if textRaw, ok := rawTop["text_config"]; ok { + if err := json.Unmarshal(textRaw, &cfg); err != nil { + return Config{}, fmt.Errorf("parse text_config: %w", err) + } + if err := json.Unmarshal(textRaw, &activeRaw); err != nil { + return Config{}, fmt.Errorf("parse text_config envelope: %w", err) + } + } else { + if err := json.Unmarshal(configData, &cfg); err != nil { + return Config{}, fmt.Errorf("parse config: %w", err) + } + } + + if cfg.HiddenSize <= 0 { + return Config{}, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize) + } + if cfg.NumHiddenLayers <= 0 { + return Config{}, fmt.Errorf("invalid num_hidden_layers: %d", cfg.NumHiddenLayers) + } + if cfg.NumAttentionHeads <= 0 { + return Config{}, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads) + } + if cfg.NumKeyValueHeads <= 0 { + cfg.NumKeyValueHeads = cfg.NumAttentionHeads + } + if cfg.HeadDim <= 0 { + if cfg.HiddenSize%cfg.NumAttentionHeads != 0 { + return Config{}, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads) + } + cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads + } + if cfg.HeadDim <= 0 { + return Config{}, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim) + } + + if cfg.RMSNormEps == 0 { + cfg.RMSNormEps = 1e-6 + } + if cfg.LinearConvKernelDim <= 0 { + cfg.LinearConvKernelDim = 4 + } + if cfg.LinearNumKeyHeads <= 0 || cfg.LinearNumValueHeads <= 0 || cfg.LinearKeyHeadDim <= 0 || cfg.LinearValueHeadDim <= 0 { + return Config{}, fmt.Errorf("invalid linear attention config (k_heads=%d v_heads=%d k_dim=%d v_dim=%d)", + cfg.LinearNumKeyHeads, cfg.LinearNumValueHeads, cfg.LinearKeyHeadDim, cfg.LinearValueHeadDim) + } + if cfg.LinearNumValueHeads%cfg.LinearNumKeyHeads != 0 { + return Config{}, fmt.Errorf("linear_num_value_heads (%d) must be divisible by linear_num_key_heads (%d)", cfg.LinearNumValueHeads, cfg.LinearNumKeyHeads) + } + + if cfg.RopeParameters != nil { + if cfg.RopeParameters.RopeTheta > 0 { + cfg.RopeTheta = cfg.RopeParameters.RopeTheta + } + if cfg.RopeParameters.PartialRotaryFactor > 0 { + cfg.PartialRotaryFactor = cfg.RopeParameters.PartialRotaryFactor + } + } + if cfg.RopeTheta == 0 { + cfg.RopeTheta = 100000.0 + } + if cfg.PartialRotaryFactor == 0 { + cfg.PartialRotaryFactor = 0.25 + } + if cfg.PartialRotaryFactor < 0 { + cfg.PartialRotaryFactor = 0.25 + } + ropeDim := int32(float32(cfg.HeadDim) * cfg.PartialRotaryFactor) + if ropeDim <= 0 { + ropeDim = cfg.HeadDim + } + if ropeDim > cfg.HeadDim { + ropeDim = cfg.HeadDim + } + cfg.RopeDim = ropeDim + + if cfg.FullAttentionInterval <= 0 { + for i, lt := range cfg.LayerTypes { + if strings.Contains(strings.ToLower(lt), "full") { + cfg.FullAttentionInterval = int32(i + 1) + break + } + } + if cfg.FullAttentionInterval <= 0 { + cfg.FullAttentionInterval = 4 + } + } + if cfg.FullAttentionInterval > cfg.NumHiddenLayers { + cfg.FullAttentionInterval = cfg.NumHiddenLayers + } + + if cfg.NumExperts > 0 { + if cfg.NumExpertsPerTok <= 0 { + cfg.NumExpertsPerTok = 1 + } + if cfg.MoeIntermediateSize <= 0 { + cfg.MoeIntermediateSize = cfg.IntermediateSize + } + if cfg.SharedExpertIntermediateSize <= 0 { + cfg.SharedExpertIntermediateSize = cfg.IntermediateSize + } + if _, ok := activeRaw["norm_topk_prob"]; !ok { + cfg.NormTopKProb = true + } + if cfg.DecoderSparseStep <= 0 { + cfg.DecoderSparseStep = 1 + } + } + + cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) + return cfg, nil +} + +type tensorPathLayout struct { + containerPrefix string + modelPrefix string +} + +func (l tensorPathLayout) modelPath(suffix string) string { + return l.containerPrefix + l.modelPrefix + suffix +} + +func resolveTensorPathLayout(tensors map[string]*mlx.Array) tensorPathLayout { + for _, layout := range []tensorPathLayout{ + {containerPrefix: "", modelPrefix: "model."}, + {containerPrefix: "language_model.", modelPrefix: "model."}, + {containerPrefix: "language_model.", modelPrefix: ""}, + {containerPrefix: "model.language_model.", modelPrefix: "model."}, + {containerPrefix: "model.language_model.", modelPrefix: ""}, + } { + if tensors[layout.modelPath("embed_tokens.weight")] != nil { + return layout + } + } + + return tensorPathLayout{modelPrefix: "model."} +} + +func layerIsLinear(cfg *Config, layer int32) bool { + if len(cfg.LayerTypes) == int(cfg.NumHiddenLayers) { + t := strings.ToLower(cfg.LayerTypes[layer]) + return !strings.Contains(t, "full") + } + if cfg.FullAttentionInterval <= 0 { + return true + } + return (layer+1)%cfg.FullAttentionInterval != 0 +} + +func layerUsesMoE(cfg *Config, layer int32) bool { + if cfg.NumExperts <= 0 { + return false + } + for _, l := range cfg.MLPOnlyLayers { + if l == layer { + return false + } + } + if cfg.DecoderSparseStep <= 1 { + return true + } + return (layer+1)%cfg.DecoderSparseStep == 0 +} + +// NewModel creates a Qwen 3.5 model from a manifest root. +func NewModel(root *model.Root) (base.Model, error) { + configData, err := root.Manifest.ReadConfig("config.json") + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } + + cfg, err := parseConfig(configData) + if err != nil { + return nil, err + } + + if qt := root.QuantType(); qt != "" { + cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt) + if gs := root.GroupSize(); gs > 0 { + cfg.QuantGroupSize = gs + } + } else { + cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("") + } + cfg.TensorQuant = root.AllTensorQuant() + + tokData, err := root.Manifest.ReadConfig("tokenizer.json") + if err != nil { + return nil, fmt.Errorf("load tokenizer config: %w", err) + } + + tokConfig := &tokenizer.TokenizerConfig{ConfigJSON: configData} + if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil { + tokConfig.GenerationConfigJSON = genConfigData + } + if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil { + tokConfig.TokenizerConfigJSON = tokConfigData + } + + tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig) + if err != nil { + return nil, fmt.Errorf("parse tokenizer: %w", err) + } + + m := &Model{ + Layers: make([]*Layer, cfg.NumHiddenLayers), + Config: &cfg, + tok: tok, + } + + for i := int32(0); i < cfg.NumHiddenLayers; i++ { + m.Layers[i] = &Layer{IsLinear: layerIsLinear(&cfg, i)} + } + + return m, nil +} + +func tensorAny(tensors map[string]*mlx.Array, keys ...string) (*mlx.Array, string) { + for _, k := range keys { + if v := tensors[k]; v != nil { + return v, k + } + } + return nil, "" +} + +func tensorByBase(tensors map[string]*mlx.Array, base string) (*mlx.Array, string) { + return tensorAny(tensors, base+".weight", base) +} + +func supportsGatherQMM(mode string, bits int) bool { + return mode == "affine" && (bits == 4 || bits == 8) +} + +func freeTensorKeys(tensors map[string]*mlx.Array, keys ...string) { + for _, k := range keys { + if k == "" { + continue + } + if t := tensors[k]; t != nil { + delete(tensors, k) + } + } +} + +func stackAndClone(parts []*mlx.Array) *mlx.Array { + if len(parts) == 0 { + return nil + } + stacked := mlx.Stack(parts, 0) + cloned := stacked.Clone() + mlx.Eval(cloned) + return cloned +} + +func transposeExpertWeightForGatherMM(w *mlx.Array) *mlx.Array { + if w == nil || !w.Valid() || w.NumDims() != 3 { + return w + } + t := mlx.Transpose(w, 0, 2, 1) + cloned := t.Clone() + mlx.Eval(cloned) + return cloned +} + +func describeMoEProjection(prefix string, w *stackedExpertWeights) string { + if w == nil { + return prefix + "=missing" + } + if w.Scales != nil { + return fmt.Sprintf("%s=qmm(mode=%s,bits=%d,gs=%d)", prefix, w.Mode, w.Bits, w.GroupSize) + } + if w.Bits > 0 || w.Mode != "" { + reason := "dequantized" + if !supportsGatherQMM(w.Mode, w.Bits) { + reason = "unsupported_gather_qmm" + } + return fmt.Sprintf("%s=%s(mode=%s,bits=%d,gs=%d)", prefix, reason, w.Mode, w.Bits, w.GroupSize) + } + return prefix + "=fp" +} + +func summarizeMoEFallbackReason(gateW, upW, downW *stackedExpertWeights) string { + for _, w := range []*stackedExpertWeights{gateW, upW, downW} { + if w == nil { + return "missing_projection" + } + if w.Scales != nil { + continue + } + if w.Bits > 0 || w.Mode != "" { + if !supportsGatherQMM(w.Mode, w.Bits) { + return fmt.Sprintf("unsupported_gather_qmm(mode=%s,bits=%d)", w.Mode, w.Bits) + } + return "dequantized_quant_weights" + } + } + return "unquantized_weights" +} + +func sliceStackedExpertAxis1(a *mlx.Array, start, stop int32) *mlx.Array { + if a == nil || !a.Valid() { + return nil + } + dims := a.Dims() + if len(dims) < 2 { + return nil + } + beg := make([]int32, len(dims)) + end := make([]int32, len(dims)) + for i, d := range dims { + end[i] = int32(d) + } + beg[1] = start + end[1] = stop + return mlx.SliceStartStop(a, beg, end) +} + +func loadStackedProjection(tensors map[string]*mlx.Array, cfg *Config, useQuantized bool, bases ...string) *stackedExpertWeights { + for _, base := range bases { + w, key := tensorByBase(tensors, base) + if w == nil { + continue + } + + scales := tensors[key+"_scale"] + if scales == nil { + return &stackedExpertWeights{Weight: w} + } + + qbiases := tensors[key+"_qbias"] + groupSize, bits, mode := model.ResolveLinearQuantParams( + cfg.QuantGroupSize, + cfg.QuantBits, + cfg.QuantMode, + cfg.TensorQuant, + key, + w, + scales, + ) + if useQuantized && supportsGatherQMM(mode, bits) { + return &stackedExpertWeights{ + Weight: w, + Scales: scales, + Biases: qbiases, + Bits: bits, + GroupSize: groupSize, + Mode: mode, + } + } + + return &stackedExpertWeights{ + Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode), + Bits: bits, + GroupSize: groupSize, + Mode: mode, + } + } + + return nil +} + +func collectPerExpertProjection(tensors map[string]*mlx.Array, cfg *Config, useQuantized bool, layerPrefix, proj string, numExperts int32) *stackedExpertWeights { + weights := make([]*mlx.Array, 0, numExperts) + scales := make([]*mlx.Array, 0, numExperts) + biases := make([]*mlx.Array, 0, numExperts) + consumedKeys := make([]string, 0, numExperts*3) + bits := 0 + groupSize := 0 + mode := cfg.QuantMode + + for e := int32(0); e < numExperts; e++ { + base := fmt.Sprintf("%s.mlp.experts.%d.%s", layerPrefix, e, proj) + w, key := tensorByBase(tensors, base) + if w == nil { + continue + } + consumedKeys = append(consumedKeys, key) + + s := tensors[key+"_scale"] + if s == nil { + weights = append(weights, w) + continue + } + consumedKeys = append(consumedKeys, key+"_scale") + qb := tensors[key+"_qbias"] + if qb != nil { + consumedKeys = append(consumedKeys, key+"_qbias") + } + gs, b, m := model.ResolveLinearQuantParams( + cfg.QuantGroupSize, + cfg.QuantBits, + cfg.QuantMode, + cfg.TensorQuant, + key, + w, + s, + ) + if bits == 0 { + bits = b + groupSize = gs + mode = m + } + if useQuantized && supportsGatherQMM(m, b) { + weights = append(weights, w) + scales = append(scales, s) + if qb != nil { + biases = append(biases, qb) + } + } else { + weights = append(weights, mlx.Dequantize(w, s, qb, gs, b, m)) + } + } + + if len(weights) == 0 { + return nil + } + + out := &stackedExpertWeights{Weight: stackAndClone(weights), Bits: bits, GroupSize: groupSize, Mode: mode} + if len(scales) == len(weights) { + out.Scales = stackAndClone(scales) + } + if len(biases) == len(weights) { + out.Biases = stackAndClone(biases) + } + freeTensorKeys(tensors, consumedKeys...) + return out +} + +func splitGateUpProjection(tensors map[string]*mlx.Array, cfg *Config, useQuantized bool, layerPrefix string) (gate, up, down *stackedExpertWeights) { + gateUp, key := tensorAny( + tensors, + layerPrefix+".mlp.experts.gate_up_proj.weight", + layerPrefix+".mlp.experts.gate_up_proj", + ) + if gateUp == nil { + return nil, nil, nil + } + + if scales := tensors[key+"_scale"]; scales != nil { + qbiases := tensors[key+"_qbias"] + groupSize, bits, mode := model.ResolveLinearQuantParams( + cfg.QuantGroupSize, + cfg.QuantBits, + cfg.QuantMode, + cfg.TensorQuant, + key, + gateUp, + scales, + ) + if useQuantized && supportsGatherQMM(mode, bits) { + gate = &stackedExpertWeights{ + Bits: bits, + GroupSize: groupSize, + Mode: mode, + } + up = &stackedExpertWeights{ + Bits: bits, + GroupSize: groupSize, + Mode: mode, + } + // Keep quantized packed tensor and split along the out-dim (axis=1). + // This assumes MLX quantization preserves the leading [experts, out, ...] layout. + if gateUp.NumDims() != 3 { + return nil, nil, nil + } + shape := gateUp.Dims() + nExperts, twoHidden, inHidden := int32(shape[0]), int32(shape[1]), int32(shape[2]) + _ = nExperts + _ = inHidden + mid := twoHidden / 2 + + gate.Weight = sliceStackedExpertAxis1(gateUp, 0, mid) + up.Weight = sliceStackedExpertAxis1(gateUp, mid, twoHidden) + gate.Scales = sliceStackedExpertAxis1(scales, 0, mid) + up.Scales = sliceStackedExpertAxis1(scales, mid, twoHidden) + if qbiases != nil { + gate.Biases = sliceStackedExpertAxis1(qbiases, 0, mid) + up.Biases = sliceStackedExpertAxis1(qbiases, mid, twoHidden) + } + } else { + gateUp = mlx.Dequantize(gateUp, scales, qbiases, groupSize, bits, mode) + gate = &stackedExpertWeights{Bits: bits, GroupSize: groupSize, Mode: mode} + up = &stackedExpertWeights{Bits: bits, GroupSize: groupSize, Mode: mode} + } + } + + if gateUp.NumDims() != 3 { + return nil, nil, nil + } + shape := gateUp.Dims() + nExperts, twoHidden, inHidden := int32(shape[0]), int32(shape[1]), int32(shape[2]) + mid := twoHidden / 2 + + if gate == nil { + gate = &stackedExpertWeights{} + } + if up == nil { + up = &stackedExpertWeights{} + } + if gate.Weight == nil { + gate.Weight = mlx.SliceStartStop(gateUp, []int32{0, 0, 0}, []int32{nExperts, mid, inHidden}) + } + if up.Weight == nil { + up.Weight = mlx.SliceStartStop(gateUp, []int32{0, mid, 0}, []int32{nExperts, twoHidden, inHidden}) + } + + downW, downKey := tensorAny( + tensors, + layerPrefix+".mlp.experts.down_proj.weight", + layerPrefix+".mlp.experts.down_proj", + ) + if downW == nil { + return gate, up, nil + } + if scales := tensors[downKey+"_scale"]; scales != nil { + qbiases := tensors[downKey+"_qbias"] + groupSize, bits, mode := model.ResolveLinearQuantParams( + cfg.QuantGroupSize, + cfg.QuantBits, + cfg.QuantMode, + cfg.TensorQuant, + downKey, + downW, + scales, + ) + if useQuantized && supportsGatherQMM(mode, bits) { + down = &stackedExpertWeights{ + Weight: downW, + Scales: scales, + Biases: qbiases, + Bits: bits, + GroupSize: groupSize, + Mode: mode, + } + return gate, up, down + } + downW = mlx.Dequantize(downW, scales, qbiases, groupSize, bits, mode) + down = &stackedExpertWeights{Bits: bits, GroupSize: groupSize, Mode: mode} + } + if down == nil { + down = &stackedExpertWeights{} + } + down.Weight = downW + return gate, up, down +} + +func sanitizeConvWeight(w *mlx.Array) *mlx.Array { + if w == nil { + return nil + } + if w.NumDims() == 3 { + if w.Dim(1) == 1 { + return mlx.Squeeze(w, 1) + } + if w.Dim(2) == 1 { + return mlx.Squeeze(w, 2) + } + } + return w +} + +func depthwiseConv1dKernelWeight(w *mlx.Array) *mlx.Array { + if w == nil { + return nil + } + switch w.NumDims() { + case 2: + // qwen3.5 manual path stores [C, K]; MLX grouped conv expects [Cout, K, Cin/groups]. + // For depthwise conv (groups=C), that is [C, K, 1]. + return mlx.ExpandDims(w, 2) + case 3: + switch { + case w.Dim(2) == 1: + // [C, K, 1] + return w + case w.Dim(1) == 1: + // [C, 1, K] -> [C, K, 1] + return mlx.Transpose(w, 0, 2, 1) + case w.Dim(0) == 1: + // [1, K, C] -> [C, K, 1] + return mlx.Transpose(w, 2, 1, 0) + } + } + return nil +} + +func shouldShiftNormKey(key string) bool { + for _, suffix := range []string{ + ".input_layernorm.weight", + ".post_attention_layernorm.weight", + "model.norm.weight", + ".self_attn.q_norm.weight", + ".self_attn.k_norm.weight", + } { + if strings.HasSuffix(key, suffix) { + return true + } + } + return false +} + +func maybeShiftNormWeight(key string, w *mlx.Array, shouldShift bool) *mlx.Array { + if !shouldShift || w == nil || w.NumDims() != 1 || !shouldShiftNormKey(key) { + return w + } + return mlx.AddScalar(w, 1.0) +} + +// LoadWeights assigns tensors to model fields. +func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { + layout := resolveTensorPathLayout(tensors) + m.weightPrefix = layout.containerPrefix + prefix := m.weightPrefix + modelPrefix := layout.containerPrefix + layout.modelPrefix + cfg := m.Config + + linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant) + + shouldShiftNormWeights := false + mtpKeys := make([]string, 0) + for name, t := range tensors { + if strings.Contains(name, "mtp.") { + shouldShiftNormWeights = true + mtpKeys = append(mtpKeys, name) + continue + } + if !shouldShiftNormWeights && strings.Contains(name, ".linear_attn.conv1d.weight") && t != nil && t.NumDims() == 3 && t.Dim(2) != 1 { + shouldShiftNormWeights = true + } + } + if len(mtpKeys) > 0 { + freeTensorKeys(tensors, mtpKeys...) + } + + embedKey := modelPrefix + "embed_tokens.weight" + embedWeight := tensors[embedKey] + if embedWeight == nil { + return fmt.Errorf("missing embedding weight: %sembed_tokens.weight", modelPrefix) + } + m.EmbedTokens = nn.NewEmbedding(embedWeight) + + normKey := modelPrefix + "norm.weight" + normWeight := maybeShiftNormWeight(normKey, tensors[normKey], shouldShiftNormWeights) + if normWeight == nil { + return fmt.Errorf("missing final norm weight: %snorm.weight", modelPrefix) + } + m.Norm = nn.NewRMSNorm(normWeight, cfg.RMSNormEps) + + if cfg.TieWordEmbeddings { + m.LMHead = nn.NewLinear(embedWeight, nil) + } else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil { + m.LMHead = lmHead + } else if lmHead := linears.Make("lm_head"); lmHead != nil { + m.LMHead = lmHead + } else { + m.LMHead = nn.NewLinear(embedWeight, nil) + } + + useQuantizedExperts := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits) + if !useQuantizedExperts && cfg.TensorQuant != nil { + for _, tq := range cfg.TensorQuant { + if tq == nil { + continue + } + _, bits, mode := model.QuantizationParams(tq.QuantType) + if supportsGatherQMM(mode, bits) { + useQuantizedExperts = true + break + } + } + } + moeLoadSummaries := make([]string, 0) + + for i := int32(0); i < cfg.NumHiddenLayers; i++ { + layerPrefix := fmt.Sprintf("%slayers.%d", modelPrefix, i) + layer := &Layer{IsLinear: layerIsLinear(cfg, i)} + + if w := maybeShiftNormWeight(layerPrefix+".input_layernorm.weight", tensors[layerPrefix+".input_layernorm.weight"], shouldShiftNormWeights); w != nil { + layer.InputNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) + } + if w := maybeShiftNormWeight(layerPrefix+".post_attention_layernorm.weight", tensors[layerPrefix+".post_attention_layernorm.weight"], shouldShiftNormWeights); w != nil { + layer.PostAttentionNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) + } + if layer.InputNorm == nil || layer.PostAttentionNorm == nil { + return fmt.Errorf("layer %d: missing layer norms", i) + } + + if layer.IsLinear { + lin := &GatedDeltaNet{} + lin.InProjQKV = linears.Make(layerPrefix + ".linear_attn.in_proj_qkv") + lin.InProjZ = linears.Make(layerPrefix + ".linear_attn.in_proj_z") + lin.InProjB = linears.Make(layerPrefix + ".linear_attn.in_proj_b") + lin.InProjA = linears.Make(layerPrefix + ".linear_attn.in_proj_a") + lin.InProjQKVZ = linears.Make(layerPrefix + ".linear_attn.in_proj_qkvz") + lin.InProjBA = linears.Make(layerPrefix + ".linear_attn.in_proj_ba") + lin.OutProj = linears.Make(layerPrefix + ".linear_attn.out_proj") + + lin.ConvWeight = sanitizeConvWeight(tensors[layerPrefix+".linear_attn.conv1d.weight"]) + if lin.ConvWeight == nil { + lin.ConvWeight = sanitizeConvWeight(tensors[layerPrefix+".linear_attn.conv1d"]) + } + lin.NormWeight, _ = tensorAny(tensors, + layerPrefix+".linear_attn.norm.weight", + layerPrefix+".linear_attn.norm", + ) + lin.DtBias, _ = tensorAny(tensors, + layerPrefix+".linear_attn.dt_bias", + layerPrefix+".linear_attn.dt_proj", + ) + lin.ALog, _ = tensorAny(tensors, + layerPrefix+".linear_attn.A_log", + layerPrefix+".linear_attn.a_log", + ) + if lin.ALog != nil { + lin.AExp = mlx.Exp(lin.ALog.AsType(mlx.DTypeFloat32)) + } + + hasSplit := lin.InProjQKV != nil && lin.InProjZ != nil && lin.InProjB != nil && lin.InProjA != nil + hasCombined := lin.InProjQKVZ != nil && lin.InProjBA != nil + if (!hasSplit && !hasCombined) || lin.OutProj == nil { + return fmt.Errorf("layer %d: missing linear attention projections", i) + } + if lin.ConvWeight == nil || lin.NormWeight == nil || lin.DtBias == nil || lin.ALog == nil || lin.AExp == nil { + return fmt.Errorf("layer %d: missing linear attention state tensors", i) + } + if lin.ConvWeight.NumDims() != 2 { + return fmt.Errorf("layer %d: conv1d weight must be 2D after sanitization, got %dD", i, lin.ConvWeight.NumDims()) + } + if convKernel := depthwiseConv1dKernelWeight(lin.ConvWeight); convKernel != nil { + lin.Conv1D = nn.NewConv1d(convKernel, nil, 1, 0, 1, int32(lin.ConvWeight.Dim(0))) + } + + layer.Linear = lin + } else { + attn := &FullAttention{} + attn.QProj = linears.Make(layerPrefix + ".self_attn.q_proj") + attn.KProj = linears.Make(layerPrefix + ".self_attn.k_proj") + attn.VProj = linears.Make(layerPrefix + ".self_attn.v_proj") + attn.OProj = linears.Make(layerPrefix + ".self_attn.o_proj") + + if w := maybeShiftNormWeight(layerPrefix+".self_attn.q_norm.weight", tensors[layerPrefix+".self_attn.q_norm.weight"], shouldShiftNormWeights); w != nil { + attn.QNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) + } + if w := maybeShiftNormWeight(layerPrefix+".self_attn.k_norm.weight", tensors[layerPrefix+".self_attn.k_norm.weight"], shouldShiftNormWeights); w != nil { + attn.KNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) + } + + if attn.QProj == nil || attn.KProj == nil || attn.VProj == nil || attn.OProj == nil { + return fmt.Errorf("layer %d: missing full attention projections", i) + } + if attn.QNorm == nil || attn.KNorm == nil { + return fmt.Errorf("layer %d: missing full attention q/k norms", i) + } + layer.FullAttn = attn + } + + if layerUsesMoE(cfg, i) { + moe := &SparseMoE{} + moe.Gate = linears.Make(layerPrefix + ".mlp.gate") + if moe.Gate == nil { + return fmt.Errorf("layer %d: missing moe gate", i) + } + + gateW := loadStackedProjection(tensors, cfg, useQuantizedExperts, + layerPrefix+".mlp.switch_mlp.gate_proj", + layerPrefix+".mlp.experts.gate_proj", + ) + upW := loadStackedProjection(tensors, cfg, useQuantizedExperts, + layerPrefix+".mlp.switch_mlp.up_proj", + layerPrefix+".mlp.experts.up_proj", + ) + downW := loadStackedProjection(tensors, cfg, useQuantizedExperts, + layerPrefix+".mlp.switch_mlp.down_proj", + layerPrefix+".mlp.experts.down_proj", + ) + if gateW == nil || upW == nil || downW == nil { + g2, u2, d2 := splitGateUpProjection(tensors, cfg, useQuantizedExperts, layerPrefix) + if gateW == nil { + gateW = g2 + } + if upW == nil { + upW = u2 + } + if downW == nil { + downW = d2 + } + } + if gateW == nil || upW == nil || downW == nil { + gateW = collectPerExpertProjection(tensors, cfg, useQuantizedExperts, layerPrefix, "gate_proj", cfg.NumExperts) + upW = collectPerExpertProjection(tensors, cfg, useQuantizedExperts, layerPrefix, "up_proj", cfg.NumExperts) + downW = collectPerExpertProjection(tensors, cfg, useQuantizedExperts, layerPrefix, "down_proj", cfg.NumExperts) + } + + if gateW == nil || upW == nil || downW == nil { + return fmt.Errorf("layer %d: missing switch expert weights", i) + } + + switchMLP := &SwitchMLP{} + if gateW.Scales != nil && upW.Scales != nil && downW.Scales != nil { + switchMLP.UseQuantized = true + switchMLP.GateWeightQ = gateW.Weight + switchMLP.GateScales = gateW.Scales + switchMLP.GateBiases = gateW.Biases + switchMLP.GateBits = gateW.Bits + switchMLP.GateGroupSize = gateW.GroupSize + switchMLP.UpWeightQ = upW.Weight + switchMLP.UpScales = upW.Scales + switchMLP.UpBiases = upW.Biases + switchMLP.UpBits = upW.Bits + switchMLP.UpGroupSize = upW.GroupSize + switchMLP.DownWeightQ = downW.Weight + switchMLP.DownScales = downW.Scales + switchMLP.DownBiases = downW.Biases + switchMLP.DownBits = downW.Bits + switchMLP.DownGroupSize = downW.GroupSize + } else { + switchMLP.GateWeight = transposeExpertWeightForGatherMM(gateW.Weight) + switchMLP.UpWeight = transposeExpertWeightForGatherMM(upW.Weight) + switchMLP.DownWeight = transposeExpertWeightForGatherMM(downW.Weight) + moeLoadSummaries = append(moeLoadSummaries, + fmt.Sprintf( + "layer=%d moe_fallback reason=%s %s %s %s", + i, + summarizeMoEFallbackReason(gateW, upW, downW), + describeMoEProjection("gate", gateW), + describeMoEProjection("up", upW), + describeMoEProjection("down", downW), + ), + ) + } + if switchMLP.UseQuantized { + moeLoadSummaries = append(moeLoadSummaries, + fmt.Sprintf( + "layer=%d moe_quantized %s %s %s", + i, + describeMoEProjection("gate", gateW), + describeMoEProjection("up", upW), + describeMoEProjection("down", downW), + ), + ) + } + moe.SwitchMLP = switchMLP + + sharedGateProj := linears.Make(layerPrefix + ".mlp.shared_expert.gate_proj") + sharedUpProj := linears.Make(layerPrefix + ".mlp.shared_expert.up_proj") + sharedDownProj := linears.Make(layerPrefix + ".mlp.shared_expert.down_proj") + if sharedGateProj != nil && sharedUpProj != nil && sharedDownProj != nil { + moe.SharedExpert = &DenseMLP{ + GateProj: sharedGateProj, + UpProj: sharedUpProj, + DownProj: sharedDownProj, + } + moe.SharedExpertGate = linears.Make(layerPrefix + ".mlp.shared_expert_gate") + } + + layer.MLP = moe + } else { + mlp := &DenseMLP{ + GateProj: linears.Make(layerPrefix + ".mlp.gate_proj"), + UpProj: linears.Make(layerPrefix + ".mlp.up_proj"), + DownProj: linears.Make(layerPrefix + ".mlp.down_proj"), + } + if mlp.GateProj == nil || mlp.UpProj == nil || mlp.DownProj == nil { + return fmt.Errorf("layer %d: missing dense mlp projections", i) + } + layer.MLP = mlp + } + + m.Layers[i] = layer + } + + return nil +} + +func softplus(x *mlx.Array) *mlx.Array { + return mlx.Log(mlx.AddScalar(mlx.Exp(x), 1.0)) +} + +func depthwiseCausalConv1d(x, w *mlx.Array, outLen int32) *mlx.Array { + if x == nil || w == nil { + return nil + } + if w.NumDims() != 2 { + return nil + } + B := int32(x.Dim(0)) + C := int32(w.Dim(0)) + K := int32(w.Dim(1)) + var out *mlx.Array + for i := int32(0); i < K; i++ { + seg := mlx.SliceStartStop(x, []int32{0, i, 0}, []int32{B, i + outLen, C}) + wi := mlx.SliceStartStop(w, []int32{0, i}, []int32{C, i + 1}) + wi = mlx.Reshape(wi, 1, 1, C) + term := mlx.Mul(seg, wi) + if out == nil { + out = term + } else { + out = mlx.Add(out, term) + } + } + return out +} + +func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k, v, z, b, a *mlx.Array) { + nk := cfg.LinearNumKeyHeads + nv := cfg.LinearNumValueHeads + dk := cfg.LinearKeyHeadDim + dv := cfg.LinearValueHeadDim + vPerK := nv / nk + + mixedQKVZ = mlx.Reshape(mixedQKVZ, B, L, nk, 2*dk+2*vPerK*dv) + q = mlx.SliceStartStop(mixedQKVZ, []int32{0, 0, 0, 0}, []int32{B, L, nk, dk}) + k = mlx.SliceStartStop(mixedQKVZ, []int32{0, 0, 0, dk}, []int32{B, L, nk, 2 * dk}) + v = mlx.SliceStartStop(mixedQKVZ, []int32{0, 0, 0, 2 * dk}, []int32{B, L, nk, 2*dk + vPerK*dv}) + z = mlx.SliceStartStop(mixedQKVZ, []int32{0, 0, 0, 2*dk + vPerK*dv}, []int32{B, L, nk, 2*dk + 2*vPerK*dv}) + + v = mlx.Reshape(v, B, L, nv, dv) + z = mlx.Reshape(z, B, L, nv, dv) + + mixedBA = mlx.Reshape(mixedBA, B, L, nk, 2*vPerK) + b = mlx.SliceStartStop(mixedBA, []int32{0, 0, 0, 0}, []int32{B, L, nk, vPerK}) + a = mlx.SliceStartStop(mixedBA, []int32{0, 0, 0, vPerK}, []int32{B, L, nk, 2 * vPerK}) + b = mlx.Reshape(b, B, L, nv) + a = mlx.Reshape(a, B, L, nv) + + return q, k, v, z, b, a +} + +func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + qg := a.QProj.Forward(x) + qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2) + q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim}) + gate := mlx.SliceStartStop(qg, []int32{0, 0, 0, cfg.HeadDim}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim * 2}) + gate = mlx.Reshape(gate, B, L, cfg.NumAttentionHeads*cfg.HeadDim) + + k := a.KProj.Forward(x) + v := a.VProj.Forward(x) + k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim) + v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim) + + q = a.QNorm.Forward(q, cfg.RMSNormEps) + k = a.KNorm.Forward(k, cfg.RMSNormEps) + + q = mlx.Transpose(q, 0, 2, 1, 3) + k = mlx.Transpose(k, 0, 2, 1, 3) + v = mlx.Transpose(v, 0, 2, 1, 3) + + offset := 0 + if c != nil { + offset = c.Offset() + } + q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset) + k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset) + + if c != nil { + k, v = c.Update(k, v) + } + + out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1) + out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) + out = mlx.Mul(out, mlx.Sigmoid(gate)) + out = a.OProj.Forward(out) + return out +} + +func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + var qkv, z, b, a *mlx.Array + useSplitProj := g.InProjQKV != nil && g.InProjZ != nil && g.InProjB != nil && g.InProjA != nil + if useSplitProj { + qkv = g.InProjQKV.Forward(x) + z = g.InProjZ.Forward(x) + z = mlx.Reshape(z, B, L, cfg.LinearNumValueHeads, cfg.LinearValueHeadDim) + b = g.InProjB.Forward(x) + a = g.InProjA.Forward(x) + } else { + mixedQKVZ := g.InProjQKVZ.Forward(x) + mixedBA := g.InProjBA.Forward(x) + var q, k, v *mlx.Array + q, k, v, z, b, a = splitQKVZBA(mixedQKVZ, mixedBA, cfg, B, L) + qkv = mlx.Concatenate([]*mlx.Array{ + mlx.Reshape(q, B, L, cfg.LinearNumKeyHeads*cfg.LinearKeyHeadDim), + mlx.Reshape(k, B, L, cfg.LinearNumKeyHeads*cfg.LinearKeyHeadDim), + mlx.Reshape(v, B, L, cfg.LinearNumValueHeads*cfg.LinearValueHeadDim), + }, -1) + } + + convTail := cfg.LinearConvKernelDim - 1 + var convState *mlx.Array + var rc *cache.RecurrentCache + if c != nil { + if typed, ok := c.(*cache.RecurrentCache); ok { + rc = typed + convState = rc.ConvState(int(B), x.DType()) + } + } + if convState == nil { + convState = mlx.Zeros(x.DType(), int(B), int(convTail), int(2*cfg.LinearNumKeyHeads*cfg.LinearKeyHeadDim+cfg.LinearNumValueHeads*cfg.LinearValueHeadDim)) + } + + convInput := mlx.Concatenate([]*mlx.Array{convState, qkv}, 1) + var convOut *mlx.Array + if g.Conv1D != nil { + convOut = g.Conv1D.Forward(convInput) + } else { + convOut = depthwiseCausalConv1d(convInput, g.ConvWeight, L) + } + convOut = mlx.SiLU(convOut) + if rc != nil { + total := int32(convInput.Dim(1)) + start := total - convTail + nextConv := mlx.SliceStartStop(convInput, []int32{0, start, 0}, []int32{B, total, int32(convInput.Dim(2))}) + rc.SetConvState(nextConv) + } + + keyDim := cfg.LinearNumKeyHeads * cfg.LinearKeyHeadDim + valueDim := cfg.LinearNumValueHeads * cfg.LinearValueHeadDim + q := mlx.SliceStartStop(convOut, []int32{0, 0, 0}, []int32{B, L, keyDim}) + k := mlx.SliceStartStop(convOut, []int32{0, 0, keyDim}, []int32{B, L, 2 * keyDim}) + v := mlx.SliceStartStop(convOut, []int32{0, 0, 2 * keyDim}, []int32{B, L, 2*keyDim + valueDim}) + q = mlx.Reshape(q, B, L, cfg.LinearNumKeyHeads, cfg.LinearKeyHeadDim) + k = mlx.Reshape(k, B, L, cfg.LinearNumKeyHeads, cfg.LinearKeyHeadDim) + v = mlx.Reshape(v, B, L, cfg.LinearNumValueHeads, cfg.LinearValueHeadDim) + invScale := float32(1.0 / math.Sqrt(float64(cfg.LinearKeyHeadDim))) + q = mlx.MulScalar(mlx.RMSNormFn(q, nil, 1e-6), invScale*invScale) + k = mlx.MulScalar(mlx.RMSNormFn(k, nil, 1e-6), invScale) + + aF32 := a.AsType(mlx.DTypeFloat32) + dtBiasF32 := g.DtBias.AsType(mlx.DTypeFloat32) + gDecay := softplus(mlx.Add(aF32, dtBiasF32)) + gDecay = mlx.Mul(gDecay, g.AExp) + gDecay = mlx.Exp(mlx.MulScalar(gDecay, -1)) + gDecay = gDecay.AsType(a.DType()) + + beta := mlx.Sigmoid(b) + + var state *mlx.Array + if rc != nil { + state = rc.DeltaState(int(B), x.DType()) + } + if state == nil { + state = mlx.Zeros(x.DType(), int(B), int(cfg.LinearNumValueHeads), int(cfg.LinearValueHeadDim), int(cfg.LinearKeyHeadDim)) + } + + out, state := mlx.GatedDelta(q, k, v, gDecay, beta, state) + out = mlx.RMSNormFn(out, g.NormWeight, cfg.RMSNormEps) + out = mlx.Mul(out, mlx.SiLU(z)) + out = mlx.Reshape(out, B, L, valueDim) + out = g.OutProj.Forward(out) + if rc != nil { + rc.SetDeltaState(state) + rc.Advance(int(L)) + } + return out +} + +func (m *DenseMLP) Forward(x *mlx.Array, _ *Config) *mlx.Array { + return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x))) +} + +func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array { + dims := x.Dims() + B, L := int32(dims[0]), int32(dims[1]) + topK := cfg.NumExpertsPerTok + + xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2) + xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize) + idxFlat := mlx.Reshape(indices, B*L, topK) + + doSort := B*L >= 64 + var invOrder *mlx.Array + n := B * L * topK + + if doSort { + idxAll := mlx.Flatten(idxFlat) + order := mlx.Argsort(idxAll, 0) + invOrder = mlx.Argsort(order, 0) + xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1) + idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1) + } + + var gate, up, hidden, down *mlx.Array + if s.UseQuantized { + gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases, + nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort) + up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases, + nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort) + hidden = mlx.Mul(mlx.SiLU(gate), up) + down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases, + nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort) + } else { + gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort) + up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort) + hidden = mlx.Mul(mlx.SiLU(gate), up) + down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort) + } + + if doSort { + down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize) + } else { + down = mlx.Squeeze(down, 2) + } + + return mlx.Reshape(down, B, L, topK, cfg.HiddenSize) +} + +func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array { + dims := x.Dims() + B, L := int32(dims[0]), int32(dims[1]) + + probs := mlx.SoftmaxAxis(m.Gate.Forward(x), -1, true) + neg := mlx.Neg(probs) + inds := mlx.Argpartition(neg, int(cfg.NumExpertsPerTok)-1, -1) + shape := inds.Dims() + inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), cfg.NumExpertsPerTok}) + + scores := mlx.TakeAlongAxis(probs, inds, -1) + if cfg.NormTopKProb && cfg.NumExpertsPerTok > 1 { + sumScores := mlx.Sum(scores, -1, true) + scores = mlx.Div(scores, sumScores) + } + + expertOut := m.SwitchMLP.Forward(x, inds, cfg) + y := mlx.Sum(mlx.Mul(expertOut, mlx.ExpandDims(scores, -1)), 2, false) + + if m.SharedExpert != nil { + shared := m.SharedExpert.Forward(x, cfg) + if m.SharedExpertGate != nil { + shared = mlx.Mul(shared, mlx.Sigmoid(m.SharedExpertGate.Forward(x))) + } + y = mlx.Add(y, shared) + } + + return mlx.Reshape(y, B, L, cfg.HiddenSize) +} + +func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + var r *mlx.Array + normed := l.InputNorm.Forward(x, cfg.RMSNormEps) + if l.IsLinear { + r = l.Linear.Forward(normed, c, B, L, cfg) + } else { + r = l.FullAttn.Forward(normed, c, B, L, cfg) + } + h := mlx.Add(x, r) + r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg) + return mlx.Add(h, r) +} + +func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + dims := tokens.Dims() + B, L := int32(dims[0]), int32(dims[1]) + + h := m.EmbedTokens.Forward(tokens) + for i, layer := range m.Layers { + var c cache.Cache + if caches != nil && i < len(caches) { + c = caches[i] + } + h = layer.Forward(h, c, B, L, m.Config) + } + out := m.Norm.Forward(h, m.RMSNormEps) + return out +} + +func (m *Model) Unembed(x *mlx.Array) *mlx.Array { + return m.LMHead.Forward(x) +} + +func (m *Model) NumLayers() int { + return len(m.Layers) +} + +func (m *Model) MaxContextLength() int { + return int(m.MaxPositionEmbeddings) +} + +func (m *Model) Tokenizer() *tokenizer.Tokenizer { + return m.tok +} + +func (m *Model) NewCaches() []cache.Cache { + caches := make([]cache.Cache, len(m.Layers)) + convTail := m.LinearConvKernelDim - 1 + convDim := 2*m.LinearNumKeyHeads*m.LinearKeyHeadDim + m.LinearNumValueHeads*m.LinearValueHeadDim + for i, layer := range m.Layers { + if layer.IsLinear { + caches[i] = cache.NewRecurrentCache(convTail, convDim, m.LinearNumValueHeads, m.LinearValueHeadDim, m.LinearKeyHeadDim) + } else { + caches[i] = cache.NewKVCache() + } + } + return caches +} diff --git a/x/models/qwen3_5/qwen3_5_test.go b/x/models/qwen3_5/qwen3_5_test.go new file mode 100644 index 000000000..0a70da189 --- /dev/null +++ b/x/models/qwen3_5/qwen3_5_test.go @@ -0,0 +1,159 @@ +//go:build mlx + +package qwen3_5 + +import ( + "testing" + + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +func TestParseConfigNestedDefaults(t *testing.T) { + data := []byte(`{ + "model_type": "Qwen3_5MoeForConditionalGeneration", + "text_config": { + "hidden_size": 4096, + "intermediate_size": 14336, + "num_hidden_layers": 8, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "head_dim": 128, + "linear_num_value_heads": 64, + "linear_num_key_heads": 16, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_conv_kernel_dim": 4, + "num_experts": 16, + "num_experts_per_tok": 4, + "moe_intermediate_size": 2048, + "shared_expert_intermediate_size": 4096, + "rope_parameters": { + "rope_theta": 500000, + "partial_rotary_factor": 0.5 + } + } + }`) + + cfg, err := parseConfig(data) + if err != nil { + t.Fatalf("parseConfig failed: %v", err) + } + + if cfg.RopeTheta != 500000 { + t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta) + } + if cfg.RopeDim != 64 { + t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim) + } + if cfg.FullAttentionInterval != 4 { + t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval) + } + if !cfg.NormTopKProb { + t.Fatalf("norm_topk_prob should default to true for MoE") + } +} + +func TestLayerSelectionHelpers(t *testing.T) { + cfg := &Config{ + NumHiddenLayers: 6, + FullAttentionInterval: 3, + NumExperts: 8, + DecoderSparseStep: 2, + MLPOnlyLayers: []int32{1}, + } + + if !layerIsLinear(cfg, 0) { + t.Fatalf("layer 0 should be linear") + } + if layerIsLinear(cfg, 2) { + t.Fatalf("layer 2 should be full attention") + } + + if layerUsesMoE(cfg, 1) { + t.Fatalf("layer 1 should be forced dense by mlp_only_layers") + } + if !layerUsesMoE(cfg, 3) { + t.Fatalf("layer 3 should use moe with decoder_sparse_step=2") + } +} + +func TestResolveTensorPathLayout(t *testing.T) { + dummy := mlx.New("dummy") + + tests := []struct { + name string + key string + wantContainer string + wantModel string + }{ + { + name: "standard", + key: "model.embed_tokens.weight", + wantContainer: "", + wantModel: "model.", + }, + { + name: "nested language model with inner model", + key: "model.language_model.model.embed_tokens.weight", + wantContainer: "model.language_model.", + wantModel: "model.", + }, + { + name: "nested language model without inner model", + key: "model.language_model.embed_tokens.weight", + wantContainer: "model.language_model.", + wantModel: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + layout := resolveTensorPathLayout(map[string]*mlx.Array{ + tt.key: dummy, + }) + + if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel { + t.Fatalf( + "resolveTensorPathLayout() = {%q %q}, want {%q %q}", + layout.containerPrefix, + layout.modelPrefix, + tt.wantContainer, + tt.wantModel, + ) + } + }) + } +} + +func TestNewCachesLayout(t *testing.T) { + m := &Model{ + Config: &Config{ + LinearConvKernelDim: 4, + LinearNumKeyHeads: 2, + LinearKeyHeadDim: 8, + LinearNumValueHeads: 4, + LinearValueHeadDim: 16, + }, + Layers: []*Layer{ + {IsLinear: true}, + {IsLinear: false}, + {IsLinear: true}, + }, + } + + caches := m.NewCaches() + if len(caches) != len(m.Layers) { + t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers)) + } + + if _, ok := caches[0].(*cache.RecurrentCache); !ok { + t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0]) + } + if _, ok := caches[1].(*cache.KVCache); !ok { + t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1]) + } + if _, ok := caches[2].(*cache.RecurrentCache); !ok { + t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2]) + } +} diff --git a/x/models/qwen3_5_moe/qwen3_5_moe.go b/x/models/qwen3_5_moe/qwen3_5_moe.go new file mode 100644 index 000000000..9e0be26be --- /dev/null +++ b/x/models/qwen3_5_moe/qwen3_5_moe.go @@ -0,0 +1,16 @@ +//go:build mlx + +// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases. +package qwen3_5_moe + +import ( + "github.com/ollama/ollama/x/mlxrunner/model/base" + "github.com/ollama/ollama/x/models/qwen3_5" +) + +func init() { + base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel) + base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel) + base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel) + base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel) +}