models: add nemotronh architecture support (#14356)

This commit is contained in:
Jeffrey Morgan
2026-02-22 15:09:14 -08:00
committed by GitHub
parent 06edabdde1
commit 0ade9205cc
22 changed files with 3196 additions and 4 deletions

View File

@@ -67,6 +67,7 @@ func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
func (f *fakeTensor) SSMScan(ctx ml.Context, _, _, _, _, _, _ ml.Tensor) ml.Tensor { return f }
func (m *fakeBackend) Get(name string) ml.Tensor {
if slices.Contains(m.names, name) {

View File

@@ -15,6 +15,7 @@ import (
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
_ "github.com/ollama/ollama/model/models/nemotronh"
_ "github.com/ollama/ollama/model/models/nomicbert"
_ "github.com/ollama/ollama/model/models/olmo3"
_ "github.com/ollama/ollama/model/models/qwen2"

View File

@@ -0,0 +1,88 @@
package nemotronh
import (
"fmt"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// Attention implements simple attention without RoPE for Nemotron-H.
// Unlike Qwen3Next, Nemotron-H attention has:
// - No RoPE (position info comes from Mamba2 layers)
// - Standard scaled dot-product attention
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
hiddenDim := hiddenStates.Dim(0)
nSeqTokens := hiddenStates.Dim(1)
switch hiddenStates.Dim(2) {
case 0:
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
case 1:
default:
return nil, ErrUnsupportedBatchLayout
}
// Nemotron-H is currently clamped to num_parallel=1.
if cache != nil && cache.IsSupportedForBatch() {
if cache.numSeqs() != 1 {
return nil, ErrUnsupportedBatchLayout
}
if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
return nil, ErrUnsupportedBatchLayout
}
}
batchSize := nSeqTokens
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize)
headDim := opts.getHeadDim()
if headDim <= 0 {
return nil, fmt.Errorf("nemotronh: invalid attention head dimension %d", headDim)
}
// Q projection
query := a.Query.Forward(ctx, hiddenStates)
if query.Dim(0)%headDim != 0 {
return nil, fmt.Errorf("nemotronh: query dim %d not divisible by head dim %d", query.Dim(0), headDim)
}
numHeads := query.Dim(0) / headDim
query = query.Reshape(ctx, headDim, numHeads, batchSize)
// K projection
key := a.Key.Forward(ctx, hiddenStates)
if key.Dim(0)%headDim != 0 {
return nil, fmt.Errorf("nemotronh: key dim %d not divisible by head dim %d", key.Dim(0), headDim)
}
numKVHeads := key.Dim(0) / headDim
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
// V projection
value := a.Value.Forward(ctx, hiddenStates)
if value.Dim(0)%headDim != 0 {
return nil, fmt.Errorf("nemotronh: value dim %d not divisible by head dim %d", value.Dim(0), headDim)
}
if value.Dim(0)/headDim != numKVHeads {
return nil, fmt.Errorf("nemotronh: key heads %d and value heads %d do not match", numKVHeads, value.Dim(0)/headDim)
}
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
// Standard attention computation (no RoPE)
scale := opts.attentionScale
if scale == 0 {
scale = 1.0 / math.Sqrt(float64(headDim))
}
attention := nn.Attention(ctx, query, key, value, scale, cache)
// Flatten heads
attention = attention.Reshape(ctx, headDim*numHeads, batchSize)
// Output projection
return a.Output.Forward(ctx, attention), nil
}

View File

@@ -0,0 +1,55 @@
package nemotronh
import (
"errors"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
// with the layer requirements.
var ErrUnsupportedBatchLayout = errors.New("nemotronh: unsupported batch layout")
var (
_ kvcache.Cache = (*HybridCache)(nil)
_ kvcache.CheckpointCache = (*HybridCache)(nil)
)
// HybridCache adapts the shared recurrent cache base for Nemotron-H naming.
type HybridCache struct {
*kvcache.Recurrent
}
func NewHybridCache(convDim, convChannels, ssmStateSize int) *HybridCache {
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
Shift: Shift,
ConvDim: convDim,
ConvChannels: convChannels,
RecurrentStateSize: ssmStateSize,
CheckpointLogPrefix: "nemotronh",
})
return &HybridCache{Recurrent: base}
}
// SSMState returns the SSM state for current batch sequences.
func (c *HybridCache) SSMState(ctx ml.Context, layer int, dState, headDim, nHead int) (ml.Tensor, error) {
return c.RecurrentState4D(ctx, layer, dState, headDim, nHead)
}
// UpdateSSMState writes a new SSM state for current batch sequences.
func (c *HybridCache) UpdateSSMState(ctx ml.Context, layer int, newState ml.Tensor) {
c.UpdateRecurrentState(ctx, layer, newState)
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.SlotsTensor()
}
func (c *HybridCache) seqTokens() int {
return c.SeqTokens()
}
func (c *HybridCache) numSeqs() int {
return c.NumSeqs()
}

View File

@@ -0,0 +1,197 @@
package nemotronh
import (
"log/slog"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// convKernel wraps the 1D convolution kernel tensor
type convKernel struct {
Weight ml.Tensor `gguf:"weight"`
}
// Mamba2 implements the Mamba2 SSM layer for Nemotron-H.
// The forward pass follows llama.cpp's build_mamba2_layer:
// 1. Input projection: zxBCdt = SSMIn @ hidden
// 2. Split: z, xBC, dt
// 3. Concat with conv state, apply SSMConv, save new conv state
// 4. Apply SiLU to convolved xBC
// 5. Split: x, B, C
// 6. Add dt bias
// 7. SSMScan: y = SSMScan(state, x, dt, A, B, C, ids)
// 8. D skip: y = y + x * D
// 9. Swiglu with z: y = z * silu(y)
// 10. Group RMSNorm
// 11. Output projection
type Mamba2 struct {
SSMIn *nn.Linear `gguf:"ssm_in"` // n_embd → d_in_proj (2*d_inner + 2*n_group*d_state + n_head)
SSMConv1D *convKernel `gguf:"ssm_conv1d"` // conv kernel
SSMConv1DB ml.Tensor `gguf:"ssm_conv1d.bias"`
SSMDtB ml.Tensor `gguf:"ssm_dt.bias"` // dt bias [n_head]
SSMA ml.Tensor `gguf:"ssm_a"` // A parameter [1, n_head]
SSMD ml.Tensor `gguf:"ssm_d"` // D skip connection [1, n_head]
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"` // group norm
SSMOut *nn.Linear `gguf:"ssm_out"` // output projection
Layer int
}
func (m *Mamba2) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
layer := m.Layer
hiddenDim := hiddenStates.Dim(0)
nSeqTokens := hiddenStates.Dim(1)
switch hiddenStates.Dim(2) {
case 0:
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
case 1:
default:
return nil, ErrUnsupportedBatchLayout
}
// Nemotron-H is currently clamped to num_parallel=1.
if cache != nil && cache.IsSupportedForBatch() {
if cache.numSeqs() != 1 {
return nil, ErrUnsupportedBatchLayout
}
if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
return nil, ErrUnsupportedBatchLayout
}
}
nSeqs := 1
dConv := opts.ssmDConv
dInner := opts.ssmDInner
dState := opts.ssmDState
nHead := opts.ssmNHead
headDim := dInner / nHead
nGroup := opts.ssmNGroup
// {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
// d_in_proj = 2*d_inner + 2*n_group*d_state + n_head
zxBCdt := m.SSMIn.Forward(ctx, hiddenStates)
// Split into z, xBC, dt
// z: [head_dim, n_head, n_seq_tokens, n_seqs]
z := zxBCdt.Slice(ctx, 0, 0, dInner, 1)
z = z.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
// xBC: [d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs]
xBCSize := dInner + 2*nGroup*dState
xBC := zxBCdt.Slice(ctx, 0, dInner, dInner+xBCSize, 1)
if nSeqTokens == 1 {
xBC = xBC.Reshape(ctx, xBCSize, 1, nSeqs)
}
// dt: [n_head, n_seq_tokens, n_seqs]
dt := zxBCdt.Slice(ctx, 0, 2*dInner+2*nGroup*dState, 2*dInner+2*nGroup*dState+nHead, 1)
if nSeqTokens == 1 {
dt = dt.Reshape(ctx, nHead, 1, nSeqs)
} else {
dt = dt.Contiguous(ctx, nHead, nSeqTokens, nSeqs)
}
// Get conv state from cache
convStates, err := cache.ConvState(ctx, layer)
if err != nil {
slog.Warn("nemotronh: failed to get conv state, using zeros", "layer", layer, "error", err)
convStates = ctx.Input().Zeros(ml.DTypeF32, dConv-1, xBCSize, nSeqs)
}
// Reshape conv states: [d_conv-1, xBCSize, n_seqs]
convStates = convStates.Reshape(ctx, dConv-1, xBCSize, nSeqs)
// For decode (n_seq_tokens == 1), reshape avoids a transpose/contiguous pair.
var xBCT ml.Tensor
if nSeqTokens == 1 {
xBCT = xBC.Reshape(ctx, 1, xBCSize, nSeqs)
} else {
// Prefill path: [xBCSize, n_seq_tokens, n_seqs] -> [n_seq_tokens, xBCSize, n_seqs]
xBCT = xBC.Permute(ctx, 1, 0, 2, 3)
}
// Concatenate with conv state: [d_conv-1 + n_seq_tokens, xBCSize, n_seqs]
convInput := convStates.Concat(ctx, xBCT, 0)
// Save new conv state (last d_conv-1 columns)
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+dConv-1, 1)
cache.UpdateConvState(ctx, layer, lastConvStates)
// Apply SSM convolution
xBC = convInput.SSMConv(ctx, m.SSMConv1D.Weight)
// Add conv bias
if m.SSMConv1DB != nil {
xBC = xBC.Add(ctx, m.SSMConv1DB)
}
// Apply SiLU
xBC = xBC.SILU(ctx)
// Split xBC into x, B, C
// x: [head_dim, n_head, n_seq_tokens, n_seqs]
x := xBC.Slice(ctx, 0, 0, dInner, 1)
x = x.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
// B: [d_state, n_group, n_seq_tokens, n_seqs]
B := xBC.Slice(ctx, 0, dInner, dInner+nGroup*dState, 1)
B = B.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)
// C: [d_state, n_group, n_seq_tokens, n_seqs]
C := xBC.Slice(ctx, 0, dInner+nGroup*dState, dInner+2*nGroup*dState, 1)
C = C.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)
// Add dt bias
dt = dt.Add(ctx, m.SSMDtB)
// Get SSM state from cache
state, err := cache.SSMState(ctx, layer, dState, headDim, nHead)
if err != nil {
slog.Warn("nemotronh: failed to get SSM state, using zeros", "layer", layer, "error", err)
state = ctx.Input().Zeros(ml.DTypeF32, dState, headDim, nHead, nSeqs)
}
// SSMScan
// state: [d_state, head_dim, n_head, n_seqs]
// returns: [head_dim, n_head, n_seq_tokens, n_seqs] concatenated with new state
ySsm := state.SSMScan(ctx, x, dt, m.SSMA, B, C, cache.slotsTensor())
// ySsm is a packed 1D buffer: [y (nSeqTokens*headDim*nHead*nSeqs), newState]
yElems := headDim * nHead * nSeqTokens * nSeqs
y := ySsm.View(ctx, 0, yElems).Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
stateOffsetBytes := yElems * x.Stride(0)
stateElems := dState * headDim * nHead * nSeqs
newState := ySsm.View(ctx, stateOffsetBytes, stateElems)
newState = newState.Reshape(ctx, dState, headDim, nHead, nSeqs)
// Update SSM state in cache
cache.UpdateSSMState(ctx, layer, newState)
// D skip connection: y = y + x * D
if m.SSMD != nil {
// SSMD shape: [1, n_head] -> broadcast to [head_dim, n_head, n_seq_tokens, n_seqs]
xD := x.Mul(ctx, m.SSMD)
y = y.Add(ctx, xD)
}
// Swiglu with z: y = z * silu(y)
y = z.SILU(ctx, y)
// Group RMSNorm
if m.SSMNorm != nil {
// Reshape for group norm: [d_inner/n_group, n_group, n_seq_tokens, n_seqs]
innerPerGroup := dInner / nGroup
y = y.Reshape(ctx, innerPerGroup, nGroup, nSeqTokens, nSeqs)
y = m.SSMNorm.Forward(ctx, y, opts.eps)
}
// Reshape back to [d_inner, n_seq_tokens, n_seqs]
y = y.Reshape(ctx, dInner, nSeqTokens, nSeqs)
// Output projection
out := m.SSMOut.Forward(ctx, y)
// Reshape to 2D for consistency with attention output
return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
}

View File

@@ -0,0 +1,417 @@
package nemotronh
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
// Options contains model configuration
type Options struct {
hiddenSize int
numHeads int // attention heads
numKVHeads int // KV heads for attention layers
headDim int
eps float32
// Mamba2 SSM config
ssmDConv int // conv kernel size
ssmDInner int // inner dimension (d_inner)
ssmDState int // state dimension
ssmNHead int // number of SSM heads (dt_rank)
ssmNGroup int // number of groups for B, C
// Per-layer configuration
isRecurrent []bool // true = Mamba2, false = attention or FFN
nFF []int // n_ff per layer (0 = attention-only)
// Attention scale
attentionScale float64
// MoE config
numExperts int
numExpertsUsed int
expertWeightsNorm bool
expertWeightsScale float32
expertWeightsNormClip float32
}
func (o Options) getHeadDim() int {
if o.headDim > 0 {
return o.headDim
}
if o.numHeads <= 0 {
return 0
}
return o.hiddenSize / o.numHeads
}
// Operator is the interface for layer operators (Mamba2 or Attention)
type Operator interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
}
// MLP is the interface for feedforward networks
type MLP interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
}
// Layer represents a single transformer layer
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Operator Operator // Mamba2, Attention, or nil (for FFN-only layers)
MLP MLP // Dense or MoE FFN, or nil
}
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
residual := hiddenStates
// Pre-layer norm
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
// Layer operator (Mamba2, Attention, or FFN)
if l.Operator != nil {
var err error
hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, cache, opts)
if err != nil {
return nil, err
}
} else if l.MLP != nil {
// FFN-only layer
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
}
// Output projection for last layer
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
// Residual connection
return hiddenStates.Add(ctx, residual), nil
}
// Model is the main Nemotron-H model
type Model struct {
model.Base
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Layers []Layer `gguf:"blk"`
*Options
}
// Shift is used for KV cache position shifting.
// Nemotron-H attention does not apply RoPE, so keys do not need to be transformed.
func Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
cache := m.Cache.(*HybridCache)
for i, layer := range m.Layers {
cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
var err error
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, outputs, cache, m.Options)
if err != nil {
return nil, err
}
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count"))
layers := make([]Layer, numLayers)
// Get per-layer configuration from GGUF metadata
// Use the same interface pattern as qwen3next
type perLayerConfig interface {
HeadCount() []uint64
HeadCountKV() []uint64
FFNLength() []uint64
}
var headCount []uint64
var headCountKV []uint64
var ffnLength []uint64
if plc, ok := c.(perLayerConfig); ok {
headCount = plc.HeadCount()
headCountKV = plc.HeadCountKV()
ffnLength = plc.FFNLength()
}
// Build per-layer arrays with defaults
isRecurrent := make([]bool, numLayers)
nFF := make([]int, numLayers)
for i := range numLayers {
// Get per-layer values
kvHeads := uint64(1) // Default non-zero
if i < len(headCountKV) {
kvHeads = headCountKV[i]
}
ff := uint64(0)
if i < len(ffnLength) {
ff = ffnLength[i]
}
nFF[i] = int(ff)
// A layer is recurrent IFF n_head_kv == 0 AND n_ff == 0
// This matches llama.cpp behavior for Nemotron-H
isRecurrent[i] = kvHeads == 0 && ff == 0
}
// Determine if MoE
isMoE := c.Uint("expert_count") > 0
for i := range layers {
if isRecurrent[i] {
// Mamba2 layer
layers[i].Operator = &Mamba2{Layer: i}
} else if nFF[i] == 0 {
// Attention-only layer (n_head_kv > 0, n_ff == 0)
layers[i].Operator = &Attention{}
} else {
// FFN layer (n_ff > 0)
if isMoE {
layers[i].MLP = &MoESparse{}
} else {
layers[i].MLP = &Dense{}
}
}
}
// Get attention head configuration
numHeads := int(c.Uint("attention.head_count"))
if numHeads == 0 {
for i := range numLayers {
if i < len(headCount) && i < len(headCountKV) && headCount[i] > 0 && headCountKV[i] > 0 {
numHeads = int(headCount[i])
break
}
}
}
numKVHeads := int(c.Uint("attention.head_count_kv"))
if numKVHeads == 0 {
for i := range numLayers {
if i < len(headCountKV) && i < len(ffnLength) && headCountKV[i] > 0 && ffnLength[i] == 0 {
numKVHeads = int(headCountKV[i])
break
}
}
if numKVHeads == 0 {
numKVHeads = numHeads
}
}
headDim := int(c.Uint("attention.head_dim"))
if headDim == 0 {
if keyLength := int(c.Uint("attention.key_length")); keyLength > 0 {
headDim = keyLength
} else if numHeads > 0 {
headDim = int(c.Uint("embedding_length")) / numHeads
}
}
if headDim <= 0 {
return nil, fmt.Errorf("nemotronh: invalid attention head dimension")
}
if numHeads <= 0 {
// Attention layers derive per-layer head counts from projection weights.
// Keep a non-zero default to avoid invalid option math.
numHeads = 1
}
numExperts := int(c.Uint("expert_count"))
numExpertsUsed := int(c.Uint("expert_used_count"))
if numExperts > 0 {
if numExpertsUsed <= 0 || numExpertsUsed > numExperts {
return nil, fmt.Errorf("nemotronh: invalid expert_used_count=%d for expert_count=%d", numExpertsUsed, numExperts)
}
}
opts := &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: numHeads,
numKVHeads: numKVHeads,
headDim: headDim,
eps: c.Float("attention.layer_norm_rms_epsilon"),
ssmDConv: int(c.Uint("ssm.conv_kernel")),
ssmDInner: int(c.Uint("ssm.inner_size")),
ssmDState: int(c.Uint("ssm.state_size")),
ssmNHead: int(c.Uint("ssm.time_step_rank")),
ssmNGroup: int(c.Uint("ssm.group_count")),
isRecurrent: isRecurrent,
nFF: nFF,
attentionScale: float64(c.Float("attention.scale")),
numExperts: numExperts,
numExpertsUsed: numExpertsUsed,
expertWeightsNorm: c.Bool("expert_weights_norm", false),
expertWeightsScale: c.Float("expert_weights_scale", 1.0),
expertWeightsNormClip: c.Float("expert_weights_norm_clip", 0),
}
// Calculate cache dimensions
convDim := max(0, opts.ssmDConv-1)
convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
ssmHeadDim := 0
if opts.ssmNHead > 0 {
ssmHeadDim = opts.ssmDInner / opts.ssmNHead
}
ssmStateSize := opts.ssmDState * ssmHeadDim * opts.ssmNHead
m := Model{
Tokenizer: tokenizer.NewBytePairEncoding(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
Layers: layers,
Options: opts,
}
m.Cache = NewHybridCache(convDim, convChannels, ssmStateSize)
return &m, nil
}
func init() {
model.Register("nemotron_h", New)
model.Register("nemotron_h_moe", New)
}
// Ensure Model implements model.Model
var _ model.Model = (*Model)(nil)
// Dense implements standard feedforward with ReLU-squared activation
type Dense struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (d *Dense) Forward(ctx ml.Context, x ml.Tensor, opts *Options) ml.Tensor {
// up -> ReLU-squared -> down
up := d.Up.Forward(ctx, x)
up = up.RELU(ctx)
up = up.Mul(ctx, up) // Square
return d.Down.Forward(ctx, up)
}
// MoESparse implements MoE with shared experts for Nemotron-H-MoE
type MoESparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
Bias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
LatentIn *nn.Linear `gguf:"ffn_latent_in"`
LatentOut *nn.Linear `gguf:"ffn_latent_out"`
// Shared experts
SharedUp *nn.Linear `gguf:"ffn_up_shexp"`
SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
}
func (m *MoESparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenDim := hiddenStates.Dim(0)
seqLen := hiddenStates.Dim(1)
batchSize := hiddenStates.Dim(2)
if batchSize == 0 {
batchSize = 1
}
hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, seqLen*batchSize)
// Router logits with sigmoid gating
routerLogits := m.Router.Forward(ctx, hiddenStates2D)
// Weights come from unbiased sigmoid probabilities.
probs := routerLogits.Sigmoid(ctx)
// Selection uses optional bias.
selectionProbs := probs
if m.Bias != nil {
selectionProbs = selectionProbs.Add(ctx, m.Bias)
}
// Select top-k experts
selectedExperts := selectionProbs.TopK(ctx, opts.numExpertsUsed)
routingWeights := probs.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)
// Normalize routing weights
if opts.expertWeightsNorm {
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
weightsSum := routingWeights.SumRows(ctx)
weightsSum = weightsSum.Clamp(ctx, 6.103515625e-5, float32(math.MaxFloat32))
routingWeights = routingWeights.Div(ctx, weightsSum)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
}
// Scale routing weights
if opts.expertWeightsScale != 1.0 {
routingWeights = routingWeights.Scale(ctx, float64(opts.expertWeightsScale))
}
routedInput := hiddenStates2D
if m.LatentIn != nil {
routedInput = m.LatentIn.Forward(ctx, routedInput)
}
hiddenStates3D := routedInput.Reshape(ctx, routedInput.Dim(0), 1, routedInput.Dim(1))
// Expert computation with ReLU-squared activation
upOut := m.Up.Forward(ctx, hiddenStates3D, selectedExperts)
upOut = upOut.RELU(ctx)
upOut = upOut.Mul(ctx, upOut) // Square
experts := m.Down.Forward(ctx, upOut, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
// Sum over experts
moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
if m.LatentOut != nil {
moeOut = m.LatentOut.Forward(ctx, moeOut)
}
// Add shared experts if present
if m.SharedUp != nil {
sharedUp := m.SharedUp.Forward(ctx, hiddenStates2D)
sharedUp = sharedUp.RELU(ctx)
sharedUp = sharedUp.Mul(ctx, sharedUp) // Square
sharedOut := m.SharedDown.Forward(ctx, sharedUp)
moeOut = moeOut.Add(ctx, sharedOut)
}
return moeOut
}