[PR #15870] fix(mlxrunner): preserve fp32 precision in gated_delta_step recurrent state #62033

opened 2026-04-29 16:59:22 -05:00 by GiteaMirror · 0 comments

📋 Pull Request Information

Original PR: https://github.com/ollama/ollama/pull/15870
Author: @andreinknv
Created: 4/29/2026
Status: 🔄 Open

Base: main ← Head: fix/qwen3_5-recurrent-state-fp32-precision


📝 Commits (1)

  • 1a172e3 fix(mlxrunner): preserve fp32 precision in gated_delta_step recurrent state

📊 Changes

2 files changed (+32 additions, -8 deletions)


📝 x/mlxrunner/cache/recurrent.go (+10 -2)
📝 x/mlxrunner/mlx/gated_delta.go (+22 -6)

📄 Description

Summary

Fixes incoherent generation in Qwen3.5/3.6 GatedDeltaNet (linear-attention) layers by preserving the fp32 recurrent-state accumulator across kernel invocations, matching MLX-LM's reference. Refs #15865, #15866.

The gated_delta_step Metal/CUDA kernel computed state in float (fp32) inside the inner loop but cast it back to InT (the input dtype, typically bf16) when writing to o_state. That rounded the accumulator's 23-bit fp32 mantissa down to bf16's 7 bits, discarding 16 bits of precision, on every recurrent step. Across 30 linear-attention layers and the N tokens of a real prompt, the state degrades enough that generation becomes incoherent (e.g. "Copyright ofusr =" for a "What is 2+2?" prompt against mlx-community/Qwen3.6-35B-A3B-4bit).
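To see why that matters, here is a toy Go program (written for this writeup, not taken from the kernel) that round-trips an accumulator through bf16 after every step, mimicking the old static_cast<InT> write:

```go
// Toy illustration of per-step bf16 truncation of a recurrent accumulator.
package main

import (
	"fmt"
	"math"
)

// bf16Round rounds a float32 to the nearest bfloat16 value
// (round-to-nearest-even), mimicking the old static_cast<InT> state write.
func bf16Round(x float32) float32 {
	bits := math.Float32bits(x)
	bits += 0x7FFF + ((bits >> 16) & 1)
	return math.Float32frombits(bits &^ 0xFFFF)
}

func main() {
	const steps = 1000
	const delta = 1e-3
	fp32, bf16 := float32(1.0), float32(1.0)
	for i := 0; i < steps; i++ {
		fp32 += delta                  // state kept in fp32 (this PR)
		bf16 = bf16Round(bf16 + delta) // state truncated each step (old behavior)
	}
	fmt.Printf("fp32 state: %v\nbf16 state: %v\n", fp32, bf16)
	// Near 1.0, bf16's ulp is 2^-7 ≈ 0.0078, so 1e-3 increments round away
	// entirely: the truncated state never moves while fp32 reaches ~2.0.
}
```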

Changes

Two surgical changes in two files:

x/mlxrunner/mlx/gated_delta.go

  • Add an StT template arg to both Metal and CUDA kernels (separate from InT)
  • Cast state writes via static_cast<StT>(state[i]) instead of static_cast<InT>(state[i])
  • Loosen the state.DType() == dtype precondition so the kernel accepts fp32 state alongside bf16 inputs (a hedged sketch follows this list)
  • Set the state output_arg dtype to state.DType() instead of the input dtype
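
On the validation side, the loosened precondition could look roughly like the fragment below; the accessor and error wording here are hypothetical, not the PR's exact diff:

```go
// Sketch only. Previously the kernel required state.DType() == dtype;
// now an fp32 state is also accepted alongside bf16/fp16 inputs.
if state.DType() != dtype && state.DType() != mlx.Float32 {
	return fmt.Errorf("gated_delta_step: unsupported state dtype %v for input dtype %v",
		state.DType(), dtype)
}
```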

x/mlxrunner/cache/recurrent.go

  • Hardcode deltaState to fp32 in ensure(). Conv state continues to track the activation dtype (typically bf16); only the recurrent accumulator is widened (see the sketch after this list).
  • Document the rationale with a comment pointing at the kernel side and the MLX-LM reference.
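
A rough sketch of the cache-side change; every name here (the RecurrentCache fields, mlx.Zeros, the shape helpers) is an assumption for illustration rather than the PR's actual code:

```go
// Sketch only -- field and helper names are hypothetical.
func (c *RecurrentCache) ensure(b int, dtype mlx.DType) {
	if c.convState == nil {
		// Conv state keeps tracking the activation dtype (typically bf16).
		c.convState = mlx.Zeros(c.convShape(b), dtype)
	}
	if c.deltaState == nil {
		// The recurrent accumulator is always fp32 so gated_delta_step
		// never truncates it between steps; see gated_delta.go and the
		// MLX-LM reference, which allocates it as mx.float32.
		c.deltaState = mlx.Zeros(c.deltaShape(b), mlx.Float32)
	}
}
```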

No API changes, no qwen3_5.go changes -- call sites still pass a single dtype to RecurrentCache.Get(b, dtype), which now applies only to conv state. Full-attention layers and other models that don't use RecurrentCache are unaffected.

Why hardcode fp32 vs. add a parameter

MLX-LM's reference always allocates the recurrent state as mx.float32. There's no current model that wants a different precision for it, and adding a knob would just push the decision to every call site without a use case to justify it. Easier to revisit if a model ever needs bf16 state for memory reasons.

Memory cost

Delta state is [B, num_v_heads, head_v_dim, head_k_dim]. For Qwen3.6-35B-A3B (B=1, 32 v-heads × 128 × 128 per head, 30 linear layers) that is ~15.7M elements: ~63 MB at fp32 vs ~31 MB at bf16. Negligible relative to the model weights.
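
A quick Go check of that arithmetic, using the shape straight from the description:

```go
// Back-of-envelope check of the delta-state footprint quoted above.
package main

import "fmt"

func main() {
	// [B, num_v_heads, head_v_dim, head_k_dim] per layer, 30 linear layers.
	elems := 1 * 32 * 128 * 128 * 30
	fmt.Printf("fp32: ~%.0f MB, bf16: ~%.0f MB\n",
		float64(elems)*4/1e6, float64(elems)*2/1e6)
	// Prints: fp32: ~63 MB, bf16: ~31 MB
}
```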

Verification

  • go build ./x/... -- clean
  • go test ./x/mlxrunner/cache/... ./x/mlxrunner/mlx/... -- all pass

Functional verification (M4 Max / 36 GB / macOS 26.4) using mlx-community/Qwen3.6-35B-A3B-4bit imported via ollama create --experimental:

|  | Before | After |
|---|---|---|
| "What is 2+2?" | "\n\n// Copyright ofusr = ..." | "<think>\n\n</think>\n\nThe result of $2 + 2$ is **4**." |
| "Reverse a string in Go" | gibberish | clean idiomatic Go one-liner |
| Throughput | 110 tok/s | 110 tok/s |
| nomic-embed-text (regression check) | works | works |

mlx-lm's reference on the same checkpoint runs ~112 tok/s for comparison.

Test plan

  • Reviewer can reproduce on Apple Silicon by importing any mlx-community/Qwen3.* MoE checkpoint via ollama create --experimental and running a short prompt -- output should be coherent rather than the Copyright/static/-1999... pattern reported in #15866.
  • CI: existing cache + mlx tests should continue to pass (verified locally).

Generated with Claude Code


🔄 This issue represents a GitHub Pull Request. It cannot be merged through Gitea due to API limitations.

GiteaMirror added the pull-request label 2026-04-29 16:59:22 -05:00
Reference: github-starred/ollama#62033