[GH-ISSUE #15866] qwen3.6:35b-a3b-coding-nvfp4 model file has corrupted K-projection weights (layers 0-1 entirely zero), making linear attention output zero #72171

Open
opened 2026-05-05 03:35:20 -05:00 by GiteaMirror · 5 comments

Originally created by @andreinknv on GitHub (Apr 28, 2026).
Original GitHub issue: https://github.com/ollama/ollama/issues/15866

qwen3.6:35b-a3b-coding-nvfp4 model file has corrupted K-projection weights for early linear-attention layers

What is the issue?

The NVFP4 packaging of qwen3.6:35b-a3b-coding-nvfp4 shipped from registry.ollama.ai/library/qwen3.6 has the K-projection portion of linear_attn.in_proj_qkv.weight entirely zeroed out for layer 0 and mostly zeroed for layer 1 — making coherent generation impossible regardless of any runner-side fix.

This is a model-packaging issue, not a runtime bug. The mlx-community/Qwen3.6-35B-A3B-4bit model (same architecture, same source weights, MLX 4-bit affine quant instead of NVFP4) has the expected natural sparsity (3-5%) in the same regions.

Reproducer

The linear_attn.in_proj_qkv.weight tensor is laid out per HF transformers as [Q rows 0..2047 | K rows 2048..4095 | V rows 4096..8191]. A direct read of the safetensors blob shows:

layer  0  K-zero-rows = 2048 / 2048   (100% — entire K projection is zero!)
layer  1  K-zero-rows = 1797 / 2048   ( 88%)
layer  2  K-zero-rows =   57 / 2048   (  3%)
layer  3  ... (full attention layer, no in_proj_qkv)
layer  4+ K-zero-rows ≤ 2 / 2048      (~natural sparsity)

For comparison, on mlx-community/Qwen3.6-35B-A3B-4bit (4-bit affine, same architecture):

layer 0  K-zero-rows =  56 / 2048
layer 1  K-zero-rows = 102 / 2048
layer 2  K-zero-rows =  56 / 2048

The mlx-community model generates coherent text via mlx-lm generate (top next token for "What is 2+2?" → "Here", correctly leading into a thinking trace). The Ollama NVFP4 model cannot generate coherent text under any code fix because layers 0 and 1 of the linear-attention path produce K=0 → linear attention output = 0 → the residual stream loses the entire linear-attention contribution from those layers.

Verification commands

# Reproduce against the cached blob (replace the SHA with whichever blob holds the layer-0 in_proj_qkv tensor):
import struct, json, numpy as np
blob = "/Users/<you>/.ollama/models/blobs/sha256-..."   # blob carrying layers.0.linear_attn.in_proj_qkv
with open(blob, 'rb') as f:
    hsz = struct.unpack('<Q', f.read(8))[0]
    hdr = json.loads(f.read(hsz))
    body = 8 + hsz
    info = hdr['language_model.model.layers.0.linear_attn.in_proj_qkv.weight']
    s, e = info['data_offsets']
    f.seek(body + s)
    arr = np.frombuffer(f.read(e-s), dtype=np.uint32).reshape(info['shape'])
    print('K rows zero:', sum(1 for r in arr[2048:4096] if not r.any()), '/ 2048')
# → 2048 / 2048 on the broken model

Likely cause

The conversion from BF16 (or the source FP8/BF16 RedHatAI NVFP4) to Ollama's NVFP4 packaging picked a global scale per tensor that was too large to represent the small magnitudes in early-layer K projections — they all rounded to the FP4 zero codepoint.

For verification: the scales tensor for in_proj_qkv is non-zero across all rows (max-row-abs ≈ 1144 on K rows for layer 0), but the underlying FP4 nibbles are all zero, so dequantization yields zero regardless of the (non-zero) scale. That fingerprint matches "small-magnitude rows quantized below FP4 representable range" rather than "block of weights deleted in I/O".
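
To double-check that fingerprint, the verification script above can be extended to histogram the raw FP4 codes directly (reusing arr from that script). This is only a sketch under one assumption: Ollama's NVFP4 packing stores eight 4-bit codes per uint32 word; for a zero-vs-nonzero count the nibble order inside each word does not matter.

# Continuation of the verification script above: unpack the packed uint32 words
# into 4-bit codes and compare the K rows against the V rows.
import numpy as np
shifts = np.arange(8, dtype=np.uint32) * 4
nibbles = ((arr[:, :, None] >> shifts) & 0xF).reshape(arr.shape[0], -1)   # [8192, 2048] FP4 codes
k_codes = nibbles[2048:4096]
v_codes = nibbles[4096:8192]
print('K rows: zero codes =', round(100.0 * (k_codes == 0).mean(), 1), '%')
print('V rows: zero codes =', round(100.0 * (v_codes == 0).mean(), 1), '%')
# On the broken blob the K rows report 100% zero codes while the V rows show a
# normal spread of codepoints: non-zero scales, all-zero nibbles.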

Suggested fixes

  1. Re-quantize with per-tensor (or per-row block) scales chosen to preserve the K projection's dynamic range — or fall back to a higher-precision quant for early-layer linear-attention weights.
  2. Check the same artefact for the mxfp8 and other variants in the library (qwen3.6:35b-a3b-coding-mxfp8, qwen3.6:35b-a3b-nvfp4, qwen3.6:27b-coding-nvfp4).
  3. Re-run the conversion using the per-tensor quant overrides path (e.g. matching the pattern in #15760, where mlp.gate and mlp.shared_expert_gate get an 8-bit override) — applying an 8-bit override (or skipping quant entirely) to linear_attn.in_proj_qkv would avoid the zero-collapse.

Related

  • #15865 — independent recurrent-state precision bug in gated_delta_step (the kernel cast state to InT instead of StT). With only that fix applied and this corruption still in place, the model reaches the sampler, but the very first decode step gets wrong logits because layer-0 linear attention contributes zero. Together, both issues explain the "broken Qwen3.6 NVFP4 generation on Ollama".
  • #15822, #15834, #15700 — user reports of broken generation on this model family on macOS; this issue likely explains the coding-nvfp4 sub-cases.
  • mlx-community / Qwen3.6-35B-A3B-4bit on HF works correctly via mlx-lm, demonstrating the architecture is otherwise sound.

Environment

  • Ollama version: v0.22.0 (as of 2026-04-28)
  • OS: macOS 26.4.1, Apple M4 Max
  • Source model: registry.ollama.ai/library/qwen3.6:35b-a3b-coding-nvfp4 (digest sha256:cd2692a833e6...)

@andreinknv commented on GitHub (Apr 28, 2026):

Confirmed the corruption is introduced by Ollama's repackaging, not in the upstream source.

RedHatAI/Qwen3.6-35B-A3B-NVFP4 (the likely upstream source) keeps all linear_attn.* weights in BF16, with only the bulk MoE expert weights in NVFP4 (a mixed-precision approach):

model.language_model.layers.0.linear_attn.in_proj_qkv.weight  shape=[8192, 2048] dtype=BF16
model.language_model.layers.0.linear_attn.in_proj_z.weight    shape=[4096, 2048] dtype=BF16
model.language_model.layers.0.linear_attn.in_proj_a.weight    shape=[32, 2048]   dtype=BF16
model.language_model.layers.0.linear_attn.in_proj_b.weight    shape=[32, 2048]   dtype=BF16

A direct HTTP-range fetch of just the K-region (rows 2048-4095) from RedHat's model.safetensors confirms the source weights are fully populated:

RedHatAI source K rows zero: 0 / 2048
first-row first-8 floats: [0.0147, -0.0098, 0.0172, -0.0074, 0.0052, 0.0064, -0.0405, -0.0004]
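
For reference, the range-fetch can be reproduced with a short script along the lines of the sketch below. The file name (model.safetensors) and the tensor key are taken from the listing above and are assumptions; if the repo shards its weights, consult its safetensors index for the right file first. BF16 is widened to float32 by placing each stored 16-bit pattern into the high half of a uint32.

import json, struct
import numpy as np
import requests

URL = ("https://huggingface.co/RedHatAI/Qwen3.6-35B-A3B-NVFP4/resolve/main/"
       "model.safetensors")        # assumed single-file layout; adjust if sharded
KEY = "model.language_model.layers.0.linear_attn.in_proj_qkv.weight"

def fetch(start, end):             # inclusive byte range
    r = requests.get(URL, headers={"Range": f"bytes={start}-{end}"})
    r.raise_for_status()
    return r.content

hsz = struct.unpack("<Q", fetch(0, 7))[0]       # safetensors 8-byte header-size prefix
hdr = json.loads(fetch(8, 8 + hsz - 1))         # JSON header
info = hdr[KEY]                                 # dtype BF16, shape [8192, 2048]
cols = info["shape"][1]
base = 8 + hsz + info["data_offsets"][0]        # tensor start within the file
k0, k1 = 2048, 4096                             # K-projection rows
raw = fetch(base + k0 * cols * 2, base + k1 * cols * 2 - 1)
u16 = np.frombuffer(raw, dtype=np.uint16).reshape(k1 - k0, cols)
vals = (u16.astype(np.uint32) << 16).view(np.float32)
print("K rows zero:", int((vals == 0).all(axis=1).sum()), "/", k1 - k0)
print("first-row first-8 floats:", np.round(vals[0, :8], 4))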

Compare to Ollama's repackaged tensor (registry.ollama.ai/library/qwen3.6:35b-a3b-coding-nvfp4):

Ollama tensor: shape=[8192, 256] dtype=U32 (NVFP4 packed, group_size=16)
K rows zero: 2048 / 2048   ← every K row collapsed to zero on quantization

The values RedHat keeps in BF16 are typical small attention-projection magnitudes (~0.01-0.04). When NVFP4-quantized at group_size=16, those magnitudes fall below the smallest representable codepoint and round to zero across whole rows.
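
A deliberately simplified toy model of that mechanism is sketched below: round-to-nearest onto the E2M1 codepoints with a single global scale. The 0.5 tensor-wide max is an invented number, and real NVFP4 uses per-group FP8 scales plus a per-tensor scale, so this only illustrates the hypothesized failure mode, not Ollama's quantizer.

import numpy as np

FP4 = np.array([0, 0.5, 1, 1.5, 2, 3, 4, 6])            # positive E2M1 codepoints

def quantize(x, scale):
    # round |x|/scale to the nearest representable FP4 value, keep the sign
    idx = np.abs(np.abs(x) / scale - FP4[:, None]).argmin(axis=0)
    return np.sign(x) * FP4[idx] * scale

k_vals = np.array([0.0147, -0.0098, 0.0172, -0.0074])    # values from the K-row listing above
tensor_max = 0.5                                         # assumed max |w| elsewhere in the tensor
global_scale = tensor_max / 6                            # one scale mapping that max onto code 6
print(quantize(k_vals, global_scale))                    # -> all zeros (the hypothesized collapse)
group_scale = np.abs(k_vals).max() / 6                   # scale fitted to the small values themselves
print(quantize(k_vals, group_scale))                     # -> non-zero reconstructions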

Conclusion: Ollama's NVFP4 packaging tool is quantizing tensors that should remain BF16 (or at minimum should use a higher-bit override, mirroring how mlp.gate already gets an 8-bit override). The fix is to mirror RedHat's mixed-precision recipe: keep linear_attn.* projections in BF16 and quantize only the MoE expert weights to NVFP4.

This affects every Qwen3.6/3.5 35B-A3B *-nvfp4 tag in the official library; the same check should be re-run on 27b-coding-nvfp4, 35b-a3b-nvfp4 (non-coding), and 27b-nvfp4.
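
To re-run the check on those tags, the relevant blobs for an already-pulled tag can be located from Ollama's on-disk manifest; a sketch is below. The manifest path layout follows Ollama's storage format, but the mediaType filter is an assumption for the experimental safetensors-backed models; if it matches nothing, list all layers and probe which blobs begin with a safetensors header.

import json, os

def model_blobs(model="qwen3.6", tag="35b-a3b-coding-nvfp4",
                root=os.path.expanduser("~/.ollama/models")):
    manifest = os.path.join(root, "manifests", "registry.ollama.ai",
                            "library", model, tag)
    with open(manifest) as f:
        layers = json.load(f)["layers"]
    # digests look like "sha256:<hex>"; blob files are named "sha256-<hex>"
    return [os.path.join(root, "blobs", layer["digest"].replace(":", "-"))
            for layer in layers
            if layer["mediaType"] == "application/vnd.ollama.image.model"]

for blob in model_blobs(tag="35b-a3b-nvfp4"):
    print(blob)   # feed each path into the safetensors check from the issue body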

In the meantime, users on Apple Silicon can route around this by running mlx-community/Qwen3.6-35B-A3B-4bit via mlx-lm generate directly — that quant uses MLX 4-bit affine with group_size=64 and produces coherent output (verified: top next token for "What is 2+2?" is Here, leading into a thinking trace).


@andreinknv commented on GitHub (Apr 28, 2026):

Working around it on Apple Silicon today

For anyone hitting this and wanting a working setup right now, the same model architecture is available pre-packaged on Hugging Face as mlx-community/Qwen3.6-35B-A3B-4bit (4-bit affine, group_size 64). Running it through mlx-lm directly produces coherent output for both reasoning and coding prompts.

Verified on M4 Max / 36 GB / macOS 26.4.1:

  • Peak memory: ~19.7 GB
  • Throughput: ~112 tok/s
  • Coding output: clean, type-hinted Python with proper error handling at temp 0

Setup (one-time, ~5 min)

# 1) Need a recent Python; brew Python 3.14 in a venv works:
/opt/homebrew/bin/python3.14 -m venv ~/mlxlm-venv
source ~/mlxlm-venv/bin/activate

# 2) Install mlx-lm from main (the PyPI 0.29.x release doesn't ship qwen3_5 yet):
pip install --upgrade pip
pip install "git+https://github.com/ml-explore/mlx-lm@main"

# 3) Download the model (~21 GB):
python -c "from huggingface_hub import snapshot_download; \
  snapshot_download(repo_id='mlx-community/Qwen3.6-35B-A3B-4bit', \
                    local_dir='$HOME/qwen3_6_mlx')"

Daily-driver wrapper

#!/bin/bash
# qwen-chat.sh
set -e
source ~/mlxlm-venv/bin/activate
MODEL="${MODEL:-$HOME/qwen3_6_mlx}"
if [ "$1" = "--once" ]; then
  shift
  python -m mlx_lm generate --model "$MODEL" \
    --prompt "$*" --max-tokens "${MAX_TOKENS:-1000}" --temp "${TEMP:-0.3}"
else
  exec python -m mlx_lm chat --model "$MODEL"
fi

# Interactive chat:
./qwen-chat.sh

# One-shot:
MAX_TOKENS=1500 TEMP=0.0 ./qwen-chat.sh --once \
  "Write a Python function that parses ISO-8601 dates."

This leaves your Ollama install untouched — switch back any time with brew services start ollama.

Notes

  • Memory tier: 4-bit (~21 GB) is what fits 36 GB unified memory. The 8-bit MLX variant (unsloth/Qwen3.6-35B-A3B-MLX-8bit) weighs ~37.7 GB on disk and won't fit once OS + KV cache + activations are accounted for.
  • About the coding tag: there are no separate Qwen3.6-Coder weights on Hugging Face — the base Qwen3.6-35B-A3B does coding well on its own.
  • Mixed precision recipe: the working MLX-format models keep linear_attn.* projections in higher precision (BF16) and only quantize the bulk MoE expert weights, which avoids the magnitudes-too-small-for-FP4 case.
  • Code-side companion fix: #15865 is a small kernel-precision patch that variants of this model family on the MLX runner are likely to need regardless of which quant they use; it's mergeable independently of this issue.

@andreinknv commented on GitHub (Apr 29, 2026):

Update: the working mlx-community model can also run inside Ollama (with #15865 + a small cache fix)

Following up on the workaround above — for anyone who'd rather stay inside Ollama, importing mlx-community/Qwen3.6-35B-A3B-4bit via ollama create --experimental works if the local build also includes a small follow-up to #15865 (RecurrentCache ensure split — see my comment on #15865). Without that follow-up, the dtype change in #15865 silently zeros the cache between every step and you get the same gibberish this issue describes for the NVFP4 build; that is why I missed this path the first time around.

End-to-end import + run on M4 Max / 36 GB:

# 1) Build patched Ollama with #15865 (kernel + qwen3_5 + cache split).

# 2) Import the working MLX model:
cat > Modelfile <<'EOF'
FROM /path/to/mlx-community/Qwen3.6-35B-A3B-4bit
EOF
./dist/darwin-arm64/ollama create --experimental qwen3.6-mlx-4bit -f Modelfile

# 3) Run it:
./dist/darwin-arm64/ollama run qwen3.6-mlx-4bit

Verified throughput: ~110 tok/s (essentially identical to running the same checkpoint via mlx-lm directly, ~112 tok/s).

This is independent of the NVFP4 packaging fix this issue tracks — once the upstream NVFP4 tag is re-quantized with the mixed-precision recipe (skip linear_attn.*), the patched build should make qwen3.6:35b-a3b-coding-nvfp4 work too without needing the import dance.


@andreinknv commented on GitHub (Apr 29, 2026):

PR for the related precision/cache fix is up at #15870 — independent of the NVFP4 packaging issue this report tracks, but applying it is a prerequisite for the working mlx-community import path described in my earlier comment to actually run end-to-end.


@ArkaD171717 commented on GitHub (Apr 30, 2026):

Wrote PR: #15902

Reference: github-starred/ollama#72171