[GH-ISSUE #15865] mlxrunner: gated_delta_step kernel writes recurrent state in InT (bf16) instead of StT (fp32), corrupting Qwen3.5/3.6 GatedDeltaNet output #72170

Open
opened 2026-05-05 03:35:15 -05:00 by GiteaMirror · 2 comments
Owner

Originally created by @andreinknv on GitHub (Apr 28, 2026).
Original GitHub issue: https://github.com/ollama/ollama/issues/15865

mlxrunner: gated_delta_step kernel writes recurrent state in InT (bf16) instead of StT (fp32), corrupting Qwen3.5/3.6 GatedDeltaNet output

What is the issue?

The Metal/CUDA gated_delta_step kernel in x/mlxrunner/mlx/gated_delta.go casts the recurrent state output back to InT (bf16 for our model) before writing it to the next-step state buffer. The reference implementation in mlx_lm.models.gated_delta uses a separate StT template arg for the state dtype and keeps state in fp32 to preserve precision across the recurrent decode loop.

Result on Qwen3.5/Qwen3.6 35B-A3B (any of bf16/mxfp8/nvfp4 variants): the model load, MoE routing, full-attention layers, embedding, and lm_head all work correctly, but the linear-attention-layer recurrent state degrades each decode step. The model emits one or two semi-correct tokens, then collapses into degenerate repetition (<|im_start|><|im_start|> echo, or repeating the last sampled token like 4\n4\n4\n4...).

Reproducer

  • macOS 26.4.1, M4 Max 36 GB unified memory, Xcode 26.4
  • Ollama v0.22.0 plus PR #15793 (mlx 0.31.2) and PR #15760 (per-tensor quant overrides) applied — the panic and crash paths those fix are already resolved
  • Model: qwen3.6:35b-a3b-coding-nvfp4 from the official library
curl -s http://localhost:11434/api/chat -d '{
  "model": "qwen3.6:35b-a3b-coding-nvfp4",
  "messages":[{"role":"user","content":"What is 2+2?"}],
  "stream": false,
  "think": false,
  "options": {"num_predict": 50, "temperature": 0.0}
}'

Before this fix

content: '<|im_start|><|im_start|>'
eval_count: 2
done_reason: stop

The model runs the forward pass cleanly (all 1436 tensors loaded, peak 19.7 GiB VRAM) and reaches the sampler, but the first few token logits put <|im_start|> (id 248045) at the top instead of any sensible response.

After this fix

content: '\n\n4\n4\n4\n4\n4...'
eval_count: 50
done_reason: length

The model now correctly identifies 4 as the answer to 2+2. EOS detection and a couple of remaining linear-attention precision points (gDecay cast, conv state dtype) are likely additional issues — but the recurrent-state-precision bug alone gates the model from producing any meaningful tokens at all.

Root cause

Comparing to mlx-lm/main/mlx_lm/models/gated_delta.py:

# MLX-LM kernel template
return mx.fast.metal_kernel(
    name=f"gated_delta_step{suffix}",
    input_names=inputs,
    output_names=["y", "state_out"],
    source=source,  # source uses InT for y cast and StT for state cast
)
# MLX-LM kernel source — state is cast back as StT, NOT InT:
source = f"""
  ...
  for (int i = 0; i < n_per_t; ++i) {{
    auto s_idx = n_per_t * dk_idx + i;
    o_state[s_idx] = static_cast<StT>(state[i]);  # ← StT
  }}
"""
# MLX-LM gated_delta_update — state explicitly fp32:
if state is None:
    state = mx.zeros((B, Hv, Dv, Dk), dtype=mx.float32)

Ollama's kernel source has static_cast<InT>(state[i]) and Ollama's qwen3_5.go initializes state with x.DType() (the input bf16). The kernel template never declares StT. Net effect: every recurrent step round-trips state through bf16 (~7 bits of mantissa), losing ~0.4% relative precision per step. Decay terms close to 1.0 (typical) compound this drift across 100s of decode steps until state is meaningless.

Patch

Three files (115-line diff against v0.22.0). Summary:

  1. x/mlxrunner/mlx/gated_delta.go — kernel source (Metal and CUDA) uses StT for the state output cast; kernel config adds cStT template arg with state.DType(); state output_arg uses state.DType(); dtype validation no longer requires state.DType() == dtype.
  2. x/models/qwen3_5/qwen3_5.go — GatedDeltaNet.Forward initializes the recurrent state as mlx.DTypeFloat32 (matches MLX-LM's mx.zeros(..., dtype=mx.float32)).
  3. x/mlxrunner/cache/recurrent.go — split ensure into ensureConv and ensureDelta so the conv state can stay at inputs.dtype (bf16) while the delta/recurrent state is fp32.

Validation

Test (temperature=0)             Before                           After
What is 2+2? first token         <|im_start|>                     4 ✓
Token-level eval before done     2 (immediate stop)               50+ (proper inference)
MLX panic / load crash           none (PR #15793 already fixed)   none

There are remaining issues with multi-token coherence (the model gets the first token right then often loops) — those look like additional precision-leak points (conv state, gDecay) that I'm continuing to investigate, but the state-precision bug is the single biggest one and gates everything else.

Related

  • PR #15793 (MLX 0.31.2 + threading) — fixes load-time panic, complementary to this fix
  • PR #15759 / #15760 (per-tensor quant overrides) — fixes a separate SparseMoE.Forward panic, complementary to this fix
  • PR #14968 (mlx: qwen3.5 vision support) — touches GatedDeltaNet but is currently CONFLICTING with main; the patch above applies cleanly to v0.22.0 + #15793 + #15760 and could merge ahead of the broader vision work
  • Issue #15822 (MLX runner failed with qwen3.6:35b-a3b-coding-bf16 format=json) — possibly the same root cause, since bf16 variants also exercise the same kernel path

Environment

  • Ollama version: v0.22.0 + #15793 + #15759 + #15760 + this patch
  • OS: macOS 26.4.1
  • Hardware: Apple M4 Max, 36 GB unified
  • MLX: 0.31.2 (via #15793)

Patch

full diff
diff --git a/x/mlxrunner/mlx/gated_delta.go b/x/mlxrunner/mlx/gated_delta.go
index c691f05..ec8fe82 100644
--- a/x/mlxrunner/mlx/gated_delta.go
+++ b/x/mlxrunner/mlx/gated_delta.go
@@ -83,7 +83,7 @@ for (int t = 0; t < T; ++t) {
 
 for (int i = 0; i < n_per_t; ++i) {
   auto s_idx = n_per_t * dk_idx + i;
-  o_state[s_idx] = static_cast<InT>(state[i]);
+  o_state[s_idx] = static_cast<StT>(state[i]);
 }
 `
 
@@ -163,7 +163,7 @@ for (int t = 0; t < T_val; ++t) {
 
 for (int i = 0; i < n_per_t; ++i) {
   auto s_idx = n_per_t * dk_idx + i;
-  o_state[s_idx] = static_cast<InT>(state[i]);
+  o_state[s_idx] = static_cast<StT>(state[i]);
 }
 `
 
@@ -263,9 +263,11 @@ func gatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok b
 	}
 
 	dtype := q.DType()
-	if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
+	if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype {
 		return nil, nil, false
 	}
+	// state may have a different dtype (typically fp32) than q/k/v/g/beta (typically bf16)
+	// — this matches MLX-LM, where state stays fp32 to preserve recurrent precision.
 
 	gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
 	if gatedDeltaMetalDisabled {
@@ -281,6 +283,12 @@ func gatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok b
 		gatedDeltaMetalDisabled = true
 		return nil, nil, false
 	}
+	cStT := C.CString("StT")
+	defer C.free(unsafe.Pointer(cStT))
+	if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cStT, C.mlx_dtype(state.DType())) != 0 {
+		gatedDeltaMetalDisabled = true
+		return nil, nil, false
+	}
 	for _, tpl := range []struct {
 		name  string
 		value int
@@ -305,7 +313,7 @@ func gatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok b
 		gatedDeltaMetalDisabled = true
 		return nil, nil, false
 	}
-	if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
+	if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(state.DType())) != 0 {
 		gatedDeltaMetalDisabled = true
 		return nil, nil, false
 	}
@@ -517,9 +525,11 @@ func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Ar
 	}
 
 	dtype := q.DType()
-	if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
+	if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype {
 		return nil, nil, false
 	}
+	// state may have a different dtype (typically fp32) than q/k/v/g/beta (typically bf16)
+	// — this matches MLX-LM, where state stays fp32 to preserve recurrent precision.
 
 	gatedDeltaCUDAKernelOnce.Do(initGatedDeltaCUDAKernel)
 	if gatedDeltaCUDADisabled {
@@ -535,6 +545,12 @@ func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Ar
 		gatedDeltaCUDADisabled = true
 		return nil, nil, false
 	}
+	cStT := C.CString("StT")
+	defer C.free(unsafe.Pointer(cStT))
+	if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cStT, C.mlx_dtype(state.DType())) != 0 {
+		gatedDeltaCUDADisabled = true
+		return nil, nil, false
+	}
 	for _, tpl := range []struct {
 		name  string
 		value int
@@ -559,7 +575,7 @@ func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Ar
 		gatedDeltaCUDADisabled = true
 		return nil, nil, false
 	}
-	if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
+	if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(state.DType())) != 0 {
 		gatedDeltaCUDADisabled = true
 		return nil, nil, false
 	}
diff --git a/x/models/qwen3_5/qwen3_5.go b/x/models/qwen3_5/qwen3_5.go
index f29563f..6ff586d 100644
--- a/x/models/qwen3_5/qwen3_5.go
+++ b/x/models/qwen3_5/qwen3_5.go
@@ -1231,12 +1231,16 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
 
 	beta := mlx.Sigmoid(b)
 
+	// Recurrent state must be fp32 to match MLX-LM's reference and avoid bf16
+	// precision loss across many recurrent steps. This is the canonical
+	// gated-delta-rule precision contract; the kernel internally accumulates
+	// in float and now correctly casts to StT (fp32) when writing back.
 	var state *mlx.Array
 	if rc != nil {
-		state = rc.DeltaState(int(B), x.DType())
+		state = rc.DeltaState(int(B), mlx.DTypeFloat32)
 	}
 	if state == nil {
-		state = mlx.Zeros(x.DType(), int(B), int(cfg.LinearNumValueHeads), int(cfg.LinearValueHeadDim), int(cfg.LinearKeyHeadDim))
+		state = mlx.Zeros(mlx.DTypeFloat32, int(B), int(cfg.LinearNumValueHeads), int(cfg.LinearValueHeadDim), int(cfg.LinearKeyHeadDim))
 	}
 
 	out, state := mlx.GatedDelta(q, k, v, gDecay, beta, state)

@andreinknv commented on GitHub (Apr 29, 2026):

Update: a second fix is needed in RecurrentCache to make this patch work end-to-end

Following up on this — when I tested the kernel patch (StT cast + FP32 dtype request) end-to-end, the model still produced incoherent output. Tracing through, I found that RecurrentCache.ensure() shares one dtype parameter to validate both convState (which the call site passes BF16, matching activations) and deltaState (which after this fix passes FP32). Each ConvState(B, BF16) call sees deltaState.DType() != BF16 and silently re-zeros it; each DeltaState(B, FP32) call sees convState.DType() != FP32 and silently re-zeros it. Since GatedDeltaNet.Forward calls both per layer, the cache is destroyed every forward step. Output ends up like "public static -19999... Copyright ofusr =".

Splitting ensure() into per-state init fixes it:

func (c *RecurrentCache) ensureConv(batch int, dtype mlx.DType) {
    if batch <= 0 {
        batch = 1
    }
    if c.convState != nil && c.convState.Valid() && c.convState.DType() == dtype &&
        c.convState.Dim(0) == batch && c.convState.Dim(1) == c.convTail && c.convState.Dim(2) == c.convDim {
        return
    }
    c.convState = c.setState(c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim), false)
}

func (c *RecurrentCache) ensureDelta(batch int, dtype mlx.DType) {
    if batch <= 0 {
        batch = 1
    }
    if c.deltaState != nil && c.deltaState.Valid() && c.deltaState.DType() == dtype &&
        c.deltaState.Dim(0) == batch && c.deltaState.Dim(1) == c.numVHeads &&
        c.deltaState.Dim(2) == c.headVDim && c.deltaState.Dim(3) == c.headKDim {
        return
    }
    c.deltaState = c.setState(c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim), false)
}

func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
    c.ensureConv(batch, dtype)
    return c.convState
}

func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
    c.ensureDelta(batch, dtype)
    return c.deltaState
}

With both pieces applied (gated_delta.go StT cast + qwen3_5.go FP32 request + recurrent.go split), qwen3.6:35b-a3b (and the equivalent mlx-community/Qwen3.6-35B-A3B-4bit imported via ollama create --experimental) generate coherent output at ~110 tok/s on M4 Max, matching the mlx-lm reference.

I've updated the patch attached to this issue to include the cache split — applying just the kernel/dtype changes without the cache split would actually introduce the symptoms this issue is trying to fix, since the dtype mismatch then triggers cache zeroing that didn't happen before.

(No KVCache changes — full-attention layers are unaffected.)


@andreinknv commented on GitHub (Apr 29, 2026):

PR up: #15870 — bundles the kernel StT cast and the RecurrentCache deltaState fp32 change into one commit, since they have to land together to keep the model coherent.

Reference: github-starred/ollama#72170