mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 07:16:38 -05:00
models: add nemotronh architecture support (#14356)
This commit is contained in:
@@ -67,6 +67,7 @@ func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor
|
||||
// Fill is a test stub: it ignores its arguments and hands back the receiver.
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor {
	return f
}
|
||||
// Repeat4D is a test stub: it ignores its arguments and hands back the receiver.
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor {
	return f
}
|
||||
// SolveTri is a test stub: it ignores its arguments and hands back the receiver.
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor {
	return f
}
|
||||
// SSMScan is a test stub: it ignores its arguments and hands back the receiver.
func (f *fakeTensor) SSMScan(ctx ml.Context, _, _, _, _, _, _ ml.Tensor) ml.Tensor {
	return f
}
|
||||
|
||||
func (m *fakeBackend) Get(name string) ml.Tensor {
|
||||
if slices.Contains(m.names, name) {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/llama4"
|
||||
_ "github.com/ollama/ollama/model/models/mistral3"
|
||||
_ "github.com/ollama/ollama/model/models/mllama"
|
||||
_ "github.com/ollama/ollama/model/models/nemotronh"
|
||||
_ "github.com/ollama/ollama/model/models/nomicbert"
|
||||
_ "github.com/ollama/ollama/model/models/olmo3"
|
||||
_ "github.com/ollama/ollama/model/models/qwen2"
|
||||
|
||||
88
model/models/nemotronh/attention.go
Normal file
88
model/models/nemotronh/attention.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
// Attention implements simple attention without RoPE for Nemotron-H.
|
||||
// Unlike Qwen3Next, Nemotron-H attention has:
|
||||
// - No RoPE (position info comes from Mamba2 layers)
|
||||
// - Standard scaled dot-product attention
|
||||
type Attention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
hiddenDim := hiddenStates.Dim(0)
|
||||
nSeqTokens := hiddenStates.Dim(1)
|
||||
switch hiddenStates.Dim(2) {
|
||||
case 0:
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
|
||||
case 1:
|
||||
default:
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
|
||||
// Nemotron-H is currently clamped to num_parallel=1.
|
||||
if cache != nil && cache.IsSupportedForBatch() {
|
||||
if cache.numSeqs() != 1 {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
}
|
||||
batchSize := nSeqTokens
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize)
|
||||
|
||||
headDim := opts.getHeadDim()
|
||||
if headDim <= 0 {
|
||||
return nil, fmt.Errorf("nemotronh: invalid attention head dimension %d", headDim)
|
||||
}
|
||||
|
||||
// Q projection
|
||||
query := a.Query.Forward(ctx, hiddenStates)
|
||||
if query.Dim(0)%headDim != 0 {
|
||||
return nil, fmt.Errorf("nemotronh: query dim %d not divisible by head dim %d", query.Dim(0), headDim)
|
||||
}
|
||||
numHeads := query.Dim(0) / headDim
|
||||
query = query.Reshape(ctx, headDim, numHeads, batchSize)
|
||||
|
||||
// K projection
|
||||
key := a.Key.Forward(ctx, hiddenStates)
|
||||
if key.Dim(0)%headDim != 0 {
|
||||
return nil, fmt.Errorf("nemotronh: key dim %d not divisible by head dim %d", key.Dim(0), headDim)
|
||||
}
|
||||
numKVHeads := key.Dim(0) / headDim
|
||||
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
|
||||
|
||||
// V projection
|
||||
value := a.Value.Forward(ctx, hiddenStates)
|
||||
if value.Dim(0)%headDim != 0 {
|
||||
return nil, fmt.Errorf("nemotronh: value dim %d not divisible by head dim %d", value.Dim(0), headDim)
|
||||
}
|
||||
if value.Dim(0)/headDim != numKVHeads {
|
||||
return nil, fmt.Errorf("nemotronh: key heads %d and value heads %d do not match", numKVHeads, value.Dim(0)/headDim)
|
||||
}
|
||||
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
|
||||
|
||||
// Standard attention computation (no RoPE)
|
||||
scale := opts.attentionScale
|
||||
if scale == 0 {
|
||||
scale = 1.0 / math.Sqrt(float64(headDim))
|
||||
}
|
||||
attention := nn.Attention(ctx, query, key, value, scale, cache)
|
||||
|
||||
// Flatten heads
|
||||
attention = attention.Reshape(ctx, headDim*numHeads, batchSize)
|
||||
|
||||
// Output projection
|
||||
return a.Output.Forward(ctx, attention), nil
|
||||
}
|
||||
55
model/models/nemotronh/cache.go
Normal file
55
model/models/nemotronh/cache.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
|
||||
// with the layer requirements.
|
||||
var ErrUnsupportedBatchLayout = errors.New("nemotronh: unsupported batch layout")
|
||||
|
||||
var (
|
||||
_ kvcache.Cache = (*HybridCache)(nil)
|
||||
_ kvcache.CheckpointCache = (*HybridCache)(nil)
|
||||
)
|
||||
|
||||
// HybridCache adapts the shared recurrent cache base for Nemotron-H naming.
|
||||
type HybridCache struct {
|
||||
*kvcache.Recurrent
|
||||
}
|
||||
|
||||
func NewHybridCache(convDim, convChannels, ssmStateSize int) *HybridCache {
|
||||
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
|
||||
Shift: Shift,
|
||||
ConvDim: convDim,
|
||||
ConvChannels: convChannels,
|
||||
RecurrentStateSize: ssmStateSize,
|
||||
CheckpointLogPrefix: "nemotronh",
|
||||
})
|
||||
return &HybridCache{Recurrent: base}
|
||||
}
|
||||
|
||||
// SSMState returns the SSM state for current batch sequences.
|
||||
func (c *HybridCache) SSMState(ctx ml.Context, layer int, dState, headDim, nHead int) (ml.Tensor, error) {
|
||||
return c.RecurrentState4D(ctx, layer, dState, headDim, nHead)
|
||||
}
|
||||
|
||||
// UpdateSSMState writes a new SSM state for current batch sequences.
|
||||
func (c *HybridCache) UpdateSSMState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
c.UpdateRecurrentState(ctx, layer, newState)
|
||||
}
|
||||
|
||||
func (c *HybridCache) slotsTensor() ml.Tensor {
|
||||
return c.SlotsTensor()
|
||||
}
|
||||
|
||||
func (c *HybridCache) seqTokens() int {
|
||||
return c.SeqTokens()
|
||||
}
|
||||
|
||||
func (c *HybridCache) numSeqs() int {
|
||||
return c.NumSeqs()
|
||||
}
|
||||
197
model/models/nemotronh/mamba2.go
Normal file
197
model/models/nemotronh/mamba2.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
// convKernel wraps the 1D convolution kernel tensor
|
||||
type convKernel struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
}
|
||||
|
||||
// Mamba2 implements the Mamba2 SSM layer for Nemotron-H.
|
||||
// The forward pass follows llama.cpp's build_mamba2_layer:
|
||||
// 1. Input projection: zxBCdt = SSMIn @ hidden
|
||||
// 2. Split: z, xBC, dt
|
||||
// 3. Concat with conv state, apply SSMConv, save new conv state
|
||||
// 4. Apply SiLU to convolved xBC
|
||||
// 5. Split: x, B, C
|
||||
// 6. Add dt bias
|
||||
// 7. SSMScan: y = SSMScan(state, x, dt, A, B, C, ids)
|
||||
// 8. D skip: y = y + x * D
|
||||
// 9. Swiglu with z: y = z * silu(y)
|
||||
// 10. Group RMSNorm
|
||||
// 11. Output projection
|
||||
type Mamba2 struct {
|
||||
SSMIn *nn.Linear `gguf:"ssm_in"` // n_embd → d_in_proj (2*d_inner + 2*n_group*d_state + n_head)
|
||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"` // conv kernel
|
||||
SSMConv1DB ml.Tensor `gguf:"ssm_conv1d.bias"`
|
||||
SSMDtB ml.Tensor `gguf:"ssm_dt.bias"` // dt bias [n_head]
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // A parameter [1, n_head]
|
||||
SSMD ml.Tensor `gguf:"ssm_d"` // D skip connection [1, n_head]
|
||||
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"` // group norm
|
||||
SSMOut *nn.Linear `gguf:"ssm_out"` // output projection
|
||||
Layer int
|
||||
}
|
||||
|
||||
func (m *Mamba2) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
layer := m.Layer
|
||||
hiddenDim := hiddenStates.Dim(0)
|
||||
nSeqTokens := hiddenStates.Dim(1)
|
||||
switch hiddenStates.Dim(2) {
|
||||
case 0:
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
|
||||
case 1:
|
||||
default:
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
|
||||
// Nemotron-H is currently clamped to num_parallel=1.
|
||||
if cache != nil && cache.IsSupportedForBatch() {
|
||||
if cache.numSeqs() != 1 {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
}
|
||||
nSeqs := 1
|
||||
|
||||
dConv := opts.ssmDConv
|
||||
dInner := opts.ssmDInner
|
||||
dState := opts.ssmDState
|
||||
nHead := opts.ssmNHead
|
||||
headDim := dInner / nHead
|
||||
nGroup := opts.ssmNGroup
|
||||
|
||||
// {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
|
||||
// d_in_proj = 2*d_inner + 2*n_group*d_state + n_head
|
||||
zxBCdt := m.SSMIn.Forward(ctx, hiddenStates)
|
||||
|
||||
// Split into z, xBC, dt
|
||||
// z: [head_dim, n_head, n_seq_tokens, n_seqs]
|
||||
z := zxBCdt.Slice(ctx, 0, 0, dInner, 1)
|
||||
z = z.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
|
||||
|
||||
// xBC: [d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs]
|
||||
xBCSize := dInner + 2*nGroup*dState
|
||||
xBC := zxBCdt.Slice(ctx, 0, dInner, dInner+xBCSize, 1)
|
||||
if nSeqTokens == 1 {
|
||||
xBC = xBC.Reshape(ctx, xBCSize, 1, nSeqs)
|
||||
}
|
||||
|
||||
// dt: [n_head, n_seq_tokens, n_seqs]
|
||||
dt := zxBCdt.Slice(ctx, 0, 2*dInner+2*nGroup*dState, 2*dInner+2*nGroup*dState+nHead, 1)
|
||||
if nSeqTokens == 1 {
|
||||
dt = dt.Reshape(ctx, nHead, 1, nSeqs)
|
||||
} else {
|
||||
dt = dt.Contiguous(ctx, nHead, nSeqTokens, nSeqs)
|
||||
}
|
||||
|
||||
// Get conv state from cache
|
||||
convStates, err := cache.ConvState(ctx, layer)
|
||||
if err != nil {
|
||||
slog.Warn("nemotronh: failed to get conv state, using zeros", "layer", layer, "error", err)
|
||||
convStates = ctx.Input().Zeros(ml.DTypeF32, dConv-1, xBCSize, nSeqs)
|
||||
}
|
||||
|
||||
// Reshape conv states: [d_conv-1, xBCSize, n_seqs]
|
||||
convStates = convStates.Reshape(ctx, dConv-1, xBCSize, nSeqs)
|
||||
|
||||
// For decode (n_seq_tokens == 1), reshape avoids a transpose/contiguous pair.
|
||||
var xBCT ml.Tensor
|
||||
if nSeqTokens == 1 {
|
||||
xBCT = xBC.Reshape(ctx, 1, xBCSize, nSeqs)
|
||||
} else {
|
||||
// Prefill path: [xBCSize, n_seq_tokens, n_seqs] -> [n_seq_tokens, xBCSize, n_seqs]
|
||||
xBCT = xBC.Permute(ctx, 1, 0, 2, 3)
|
||||
}
|
||||
|
||||
// Concatenate with conv state: [d_conv-1 + n_seq_tokens, xBCSize, n_seqs]
|
||||
convInput := convStates.Concat(ctx, xBCT, 0)
|
||||
|
||||
// Save new conv state (last d_conv-1 columns)
|
||||
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+dConv-1, 1)
|
||||
cache.UpdateConvState(ctx, layer, lastConvStates)
|
||||
|
||||
// Apply SSM convolution
|
||||
xBC = convInput.SSMConv(ctx, m.SSMConv1D.Weight)
|
||||
|
||||
// Add conv bias
|
||||
if m.SSMConv1DB != nil {
|
||||
xBC = xBC.Add(ctx, m.SSMConv1DB)
|
||||
}
|
||||
|
||||
// Apply SiLU
|
||||
xBC = xBC.SILU(ctx)
|
||||
|
||||
// Split xBC into x, B, C
|
||||
// x: [head_dim, n_head, n_seq_tokens, n_seqs]
|
||||
x := xBC.Slice(ctx, 0, 0, dInner, 1)
|
||||
x = x.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
|
||||
|
||||
// B: [d_state, n_group, n_seq_tokens, n_seqs]
|
||||
B := xBC.Slice(ctx, 0, dInner, dInner+nGroup*dState, 1)
|
||||
B = B.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)
|
||||
|
||||
// C: [d_state, n_group, n_seq_tokens, n_seqs]
|
||||
C := xBC.Slice(ctx, 0, dInner+nGroup*dState, dInner+2*nGroup*dState, 1)
|
||||
C = C.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)
|
||||
|
||||
// Add dt bias
|
||||
dt = dt.Add(ctx, m.SSMDtB)
|
||||
|
||||
// Get SSM state from cache
|
||||
state, err := cache.SSMState(ctx, layer, dState, headDim, nHead)
|
||||
if err != nil {
|
||||
slog.Warn("nemotronh: failed to get SSM state, using zeros", "layer", layer, "error", err)
|
||||
state = ctx.Input().Zeros(ml.DTypeF32, dState, headDim, nHead, nSeqs)
|
||||
}
|
||||
|
||||
// SSMScan
|
||||
// state: [d_state, head_dim, n_head, n_seqs]
|
||||
// returns: [head_dim, n_head, n_seq_tokens, n_seqs] concatenated with new state
|
||||
ySsm := state.SSMScan(ctx, x, dt, m.SSMA, B, C, cache.slotsTensor())
|
||||
|
||||
// ySsm is a packed 1D buffer: [y (nSeqTokens*headDim*nHead*nSeqs), newState]
|
||||
yElems := headDim * nHead * nSeqTokens * nSeqs
|
||||
y := ySsm.View(ctx, 0, yElems).Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
|
||||
|
||||
stateOffsetBytes := yElems * x.Stride(0)
|
||||
stateElems := dState * headDim * nHead * nSeqs
|
||||
newState := ySsm.View(ctx, stateOffsetBytes, stateElems)
|
||||
newState = newState.Reshape(ctx, dState, headDim, nHead, nSeqs)
|
||||
|
||||
// Update SSM state in cache
|
||||
cache.UpdateSSMState(ctx, layer, newState)
|
||||
|
||||
// D skip connection: y = y + x * D
|
||||
if m.SSMD != nil {
|
||||
// SSMD shape: [1, n_head] -> broadcast to [head_dim, n_head, n_seq_tokens, n_seqs]
|
||||
xD := x.Mul(ctx, m.SSMD)
|
||||
y = y.Add(ctx, xD)
|
||||
}
|
||||
|
||||
// Swiglu with z: y = z * silu(y)
|
||||
y = z.SILU(ctx, y)
|
||||
|
||||
// Group RMSNorm
|
||||
if m.SSMNorm != nil {
|
||||
// Reshape for group norm: [d_inner/n_group, n_group, n_seq_tokens, n_seqs]
|
||||
innerPerGroup := dInner / nGroup
|
||||
y = y.Reshape(ctx, innerPerGroup, nGroup, nSeqTokens, nSeqs)
|
||||
y = m.SSMNorm.Forward(ctx, y, opts.eps)
|
||||
}
|
||||
|
||||
// Reshape back to [d_inner, n_seq_tokens, n_seqs]
|
||||
y = y.Reshape(ctx, dInner, nSeqTokens, nSeqs)
|
||||
|
||||
// Output projection
|
||||
out := m.SSMOut.Forward(ctx, y)
|
||||
|
||||
// Reshape to 2D for consistency with attention output
|
||||
return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
|
||||
}
|
||||
417
model/models/nemotronh/model.go
Normal file
417
model/models/nemotronh/model.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// Options carries the hyperparameters shared by every layer of the model.
// All values are filled in by New from the model's metadata.
type Options struct {
	hiddenSize int
	numHeads   int // query heads of the attention layers
	numKVHeads int // key/value heads of the attention layers
	headDim    int
	eps        float32

	// Mamba2 state-space settings.
	ssmDConv  int // convolution kernel width
	ssmDInner int // d_inner
	ssmDState int // state size
	ssmNHead  int // SSM head count (read from the time-step rank)
	ssmNGroup int // group count for B and C

	// Per-layer settings, indexed by layer.
	isRecurrent []bool // true for Mamba2 layers, false for attention or FFN layers
	nFF         []int  // feed-forward width; 0 when the layer has no FFN

	// attentionScale overrides the attention scaling factor when non-zero.
	attentionScale float64

	// Mixture-of-experts settings.
	numExperts            int
	numExpertsUsed        int
	expertWeightsNorm     bool
	expertWeightsScale    float32
	expertWeightsNormClip float32
}
|
||||
|
||||
func (o Options) getHeadDim() int {
|
||||
if o.headDim > 0 {
|
||||
return o.headDim
|
||||
}
|
||||
if o.numHeads <= 0 {
|
||||
return 0
|
||||
}
|
||||
return o.hiddenSize / o.numHeads
|
||||
}
|
||||
|
||||
// Operator is the interface for layer operators (Mamba2 or Attention)
|
||||
type Operator interface {
|
||||
Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
|
||||
}
|
||||
|
||||
// MLP is the interface for feedforward networks
|
||||
type MLP interface {
|
||||
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
|
||||
}
|
||||
|
||||
// Layer represents a single transformer layer
|
||||
type Layer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
Operator Operator // Mamba2, Attention, or nil (for FFN-only layers)
|
||||
MLP MLP // Dense or MoE FFN, or nil
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
residual := hiddenStates
|
||||
|
||||
// Pre-layer norm
|
||||
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
// Layer operator (Mamba2, Attention, or FFN)
|
||||
if l.Operator != nil {
|
||||
var err error
|
||||
hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, cache, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if l.MLP != nil {
|
||||
// FFN-only layer
|
||||
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
|
||||
}
|
||||
|
||||
// Output projection for last layer
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
// Residual connection
|
||||
return hiddenStates.Add(ctx, residual), nil
|
||||
}
|
||||
|
||||
// Model is the main Nemotron-H model
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
Layers []Layer `gguf:"blk"`
|
||||
|
||||
*Options
|
||||
}
|
||||
|
||||
// Shift is used for KV cache position shifting.
|
||||
// Nemotron-H attention does not apply RoPE, so keys do not need to be transformed.
|
||||
func Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
cache := m.Cache.(*HybridCache)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
cache.SetLayer(i)
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
var err error
|
||||
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, outputs, cache, m.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
layers := make([]Layer, numLayers)
|
||||
|
||||
// Get per-layer configuration from GGUF metadata
|
||||
// Use the same interface pattern as qwen3next
|
||||
type perLayerConfig interface {
|
||||
HeadCount() []uint64
|
||||
HeadCountKV() []uint64
|
||||
FFNLength() []uint64
|
||||
}
|
||||
|
||||
var headCount []uint64
|
||||
var headCountKV []uint64
|
||||
var ffnLength []uint64
|
||||
|
||||
if plc, ok := c.(perLayerConfig); ok {
|
||||
headCount = plc.HeadCount()
|
||||
headCountKV = plc.HeadCountKV()
|
||||
ffnLength = plc.FFNLength()
|
||||
}
|
||||
|
||||
// Build per-layer arrays with defaults
|
||||
isRecurrent := make([]bool, numLayers)
|
||||
nFF := make([]int, numLayers)
|
||||
|
||||
for i := range numLayers {
|
||||
// Get per-layer values
|
||||
kvHeads := uint64(1) // Default non-zero
|
||||
if i < len(headCountKV) {
|
||||
kvHeads = headCountKV[i]
|
||||
}
|
||||
ff := uint64(0)
|
||||
if i < len(ffnLength) {
|
||||
ff = ffnLength[i]
|
||||
}
|
||||
nFF[i] = int(ff)
|
||||
|
||||
// A layer is recurrent IFF n_head_kv == 0 AND n_ff == 0
|
||||
// This matches llama.cpp behavior for Nemotron-H
|
||||
isRecurrent[i] = kvHeads == 0 && ff == 0
|
||||
}
|
||||
|
||||
// Determine if MoE
|
||||
isMoE := c.Uint("expert_count") > 0
|
||||
|
||||
for i := range layers {
|
||||
if isRecurrent[i] {
|
||||
// Mamba2 layer
|
||||
layers[i].Operator = &Mamba2{Layer: i}
|
||||
} else if nFF[i] == 0 {
|
||||
// Attention-only layer (n_head_kv > 0, n_ff == 0)
|
||||
layers[i].Operator = &Attention{}
|
||||
} else {
|
||||
// FFN layer (n_ff > 0)
|
||||
if isMoE {
|
||||
layers[i].MLP = &MoESparse{}
|
||||
} else {
|
||||
layers[i].MLP = &Dense{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get attention head configuration
|
||||
numHeads := int(c.Uint("attention.head_count"))
|
||||
if numHeads == 0 {
|
||||
for i := range numLayers {
|
||||
if i < len(headCount) && i < len(headCountKV) && headCount[i] > 0 && headCountKV[i] > 0 {
|
||||
numHeads = int(headCount[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
numKVHeads := int(c.Uint("attention.head_count_kv"))
|
||||
if numKVHeads == 0 {
|
||||
for i := range numLayers {
|
||||
if i < len(headCountKV) && i < len(ffnLength) && headCountKV[i] > 0 && ffnLength[i] == 0 {
|
||||
numKVHeads = int(headCountKV[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
if numKVHeads == 0 {
|
||||
numKVHeads = numHeads
|
||||
}
|
||||
}
|
||||
|
||||
headDim := int(c.Uint("attention.head_dim"))
|
||||
if headDim == 0 {
|
||||
if keyLength := int(c.Uint("attention.key_length")); keyLength > 0 {
|
||||
headDim = keyLength
|
||||
} else if numHeads > 0 {
|
||||
headDim = int(c.Uint("embedding_length")) / numHeads
|
||||
}
|
||||
}
|
||||
if headDim <= 0 {
|
||||
return nil, fmt.Errorf("nemotronh: invalid attention head dimension")
|
||||
}
|
||||
if numHeads <= 0 {
|
||||
// Attention layers derive per-layer head counts from projection weights.
|
||||
// Keep a non-zero default to avoid invalid option math.
|
||||
numHeads = 1
|
||||
}
|
||||
|
||||
numExperts := int(c.Uint("expert_count"))
|
||||
numExpertsUsed := int(c.Uint("expert_used_count"))
|
||||
if numExperts > 0 {
|
||||
if numExpertsUsed <= 0 || numExpertsUsed > numExperts {
|
||||
return nil, fmt.Errorf("nemotronh: invalid expert_used_count=%d for expert_count=%d", numExpertsUsed, numExperts)
|
||||
}
|
||||
}
|
||||
|
||||
opts := &Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: numHeads,
|
||||
numKVHeads: numKVHeads,
|
||||
headDim: headDim,
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ssmDConv: int(c.Uint("ssm.conv_kernel")),
|
||||
ssmDInner: int(c.Uint("ssm.inner_size")),
|
||||
ssmDState: int(c.Uint("ssm.state_size")),
|
||||
ssmNHead: int(c.Uint("ssm.time_step_rank")),
|
||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||
isRecurrent: isRecurrent,
|
||||
nFF: nFF,
|
||||
attentionScale: float64(c.Float("attention.scale")),
|
||||
numExperts: numExperts,
|
||||
numExpertsUsed: numExpertsUsed,
|
||||
expertWeightsNorm: c.Bool("expert_weights_norm", false),
|
||||
expertWeightsScale: c.Float("expert_weights_scale", 1.0),
|
||||
expertWeightsNormClip: c.Float("expert_weights_norm_clip", 0),
|
||||
}
|
||||
|
||||
// Calculate cache dimensions
|
||||
convDim := max(0, opts.ssmDConv-1)
|
||||
convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
|
||||
ssmHeadDim := 0
|
||||
if opts.ssmNHead > 0 {
|
||||
ssmHeadDim = opts.ssmDInner / opts.ssmNHead
|
||||
}
|
||||
ssmStateSize := opts.ssmDState * ssmHeadDim * opts.ssmNHead
|
||||
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: opts,
|
||||
}
|
||||
|
||||
m.Cache = NewHybridCache(convDim, convChannels, ssmStateSize)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("nemotron_h", New)
|
||||
model.Register("nemotron_h_moe", New)
|
||||
}
|
||||
|
||||
// Ensure Model implements model.Model
|
||||
var _ model.Model = (*Model)(nil)
|
||||
|
||||
// Dense implements standard feedforward with ReLU-squared activation
|
||||
type Dense struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (d *Dense) Forward(ctx ml.Context, x ml.Tensor, opts *Options) ml.Tensor {
|
||||
// up -> ReLU-squared -> down
|
||||
up := d.Up.Forward(ctx, x)
|
||||
up = up.RELU(ctx)
|
||||
up = up.Mul(ctx, up) // Square
|
||||
return d.Down.Forward(ctx, up)
|
||||
}
|
||||
|
||||
// MoESparse implements MoE with shared experts for Nemotron-H-MoE
|
||||
type MoESparse struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
Bias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
|
||||
|
||||
LatentIn *nn.Linear `gguf:"ffn_latent_in"`
|
||||
LatentOut *nn.Linear `gguf:"ffn_latent_out"`
|
||||
|
||||
// Shared experts
|
||||
SharedUp *nn.Linear `gguf:"ffn_up_shexp"`
|
||||
SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
|
||||
}
|
||||
|
||||
func (m *MoESparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenDim := hiddenStates.Dim(0)
|
||||
seqLen := hiddenStates.Dim(1)
|
||||
batchSize := hiddenStates.Dim(2)
|
||||
if batchSize == 0 {
|
||||
batchSize = 1
|
||||
}
|
||||
hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, seqLen*batchSize)
|
||||
|
||||
// Router logits with sigmoid gating
|
||||
routerLogits := m.Router.Forward(ctx, hiddenStates2D)
|
||||
|
||||
// Weights come from unbiased sigmoid probabilities.
|
||||
probs := routerLogits.Sigmoid(ctx)
|
||||
|
||||
// Selection uses optional bias.
|
||||
selectionProbs := probs
|
||||
if m.Bias != nil {
|
||||
selectionProbs = selectionProbs.Add(ctx, m.Bias)
|
||||
}
|
||||
|
||||
// Select top-k experts
|
||||
selectedExperts := selectionProbs.TopK(ctx, opts.numExpertsUsed)
|
||||
routingWeights := probs.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)
|
||||
|
||||
// Normalize routing weights
|
||||
if opts.expertWeightsNorm {
|
||||
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
|
||||
weightsSum := routingWeights.SumRows(ctx)
|
||||
weightsSum = weightsSum.Clamp(ctx, 6.103515625e-5, float32(math.MaxFloat32))
|
||||
routingWeights = routingWeights.Div(ctx, weightsSum)
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
|
||||
}
|
||||
|
||||
// Scale routing weights
|
||||
if opts.expertWeightsScale != 1.0 {
|
||||
routingWeights = routingWeights.Scale(ctx, float64(opts.expertWeightsScale))
|
||||
}
|
||||
|
||||
routedInput := hiddenStates2D
|
||||
if m.LatentIn != nil {
|
||||
routedInput = m.LatentIn.Forward(ctx, routedInput)
|
||||
}
|
||||
hiddenStates3D := routedInput.Reshape(ctx, routedInput.Dim(0), 1, routedInput.Dim(1))
|
||||
|
||||
// Expert computation with ReLU-squared activation
|
||||
upOut := m.Up.Forward(ctx, hiddenStates3D, selectedExperts)
|
||||
upOut = upOut.RELU(ctx)
|
||||
upOut = upOut.Mul(ctx, upOut) // Square
|
||||
experts := m.Down.Forward(ctx, upOut, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
// Sum over experts
|
||||
moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
if m.LatentOut != nil {
|
||||
moeOut = m.LatentOut.Forward(ctx, moeOut)
|
||||
}
|
||||
|
||||
// Add shared experts if present
|
||||
if m.SharedUp != nil {
|
||||
sharedUp := m.SharedUp.Forward(ctx, hiddenStates2D)
|
||||
sharedUp = sharedUp.RELU(ctx)
|
||||
sharedUp = sharedUp.Mul(ctx, sharedUp) // Square
|
||||
sharedOut := m.SharedDown.Forward(ctx, sharedUp)
|
||||
moeOut = moeOut.Add(ctx, sharedOut)
|
||||
}
|
||||
|
||||
return moeOut
|
||||
}
|
||||
Reference in New Issue
Block a user