chore: update models to use slice/chunk/chunksections (#12934)

* use slice/chunks

* bert

* llama4

* gemma3n

* gptoss

* mistral3

* qwen3vl

* qwen25vl

* deepseek2

* remove unused ops
This commit is contained in:
Michael Yang
2025-11-13 15:20:12 -08:00
committed by GitHub
parent c114987523
commit 333203d871
13 changed files with 59 additions and 135 deletions

View File

@@ -110,9 +110,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
for name, mxfp4 := range mxfp4s {
dims := mxfp4.blocks.Shape()
if !strings.HasSuffix(name, ".weight") {
name = name + ".weight"
}
if strings.Contains(name, "ffn_down_exps") {
out = append(out, &ggml.Tensor{
Name: name + ".weight",
Name: name,
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
WriterTo: mxfp4,
@@ -121,12 +124,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
Name: strings.Replace(name, "gate_up", "gate", 1),
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
}, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
Name: strings.Replace(name, "gate_up", "up", 1),
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),

View File

@@ -146,7 +146,6 @@ type Tensor interface {
FromFloats([]float32)
FromInts([]int32)
Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor
Mul(ctx Context, t2 Tensor) Tensor
@@ -185,7 +184,6 @@ type Tensor interface {
View(ctx Context, offset int, shape ...int) Tensor
Permute(ctx Context, shape ...int) Tensor
Contiguous(ctx Context, shape ...int) Tensor
Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
Pad(ctx Context, shape ...int) Tensor
@@ -209,7 +207,6 @@ type Tensor interface {
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor
Clamp(ctx Context, min, max float32) Tensor
}
// ScaledDotProductAttention implements a fused attention

View File

@@ -1137,13 +1137,6 @@ func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
}
}
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_neg(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1632,20 +1625,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
}
}
func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
var tt *C.struct_ggml_tensor
switch len(strides) {
case 0:
tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
case 1:
tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
default:
panic("unsupported number of dimensions")
}
return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor
if mask != nil {
@@ -1732,13 +1711,6 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
}
}
// Slice returns a view of the tensor sliced along dim from low to high in step steps.
// Slice panics if the dimension is invalid or the slice parameters are out of range.
// If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.

View File

@@ -32,10 +32,9 @@ func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
case TypeCLS:
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
return hiddenStates.Slice(ctx, 1, 0, 1, 1)
case TypeLast:
hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
return hiddenStates
return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
default:
panic("unknown pooling type")
}

View File

@@ -29,7 +29,7 @@ type Model struct {
// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.Slice(ctx, 1, 0, 1, 1))
hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromInts(batch.Positions, len(batch.Positions))))
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)

View File

@@ -78,44 +78,31 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
}
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
qPass := query.View(ctx, 0,
opts.qkNopeHeadDim, query.Stride(1),
query.Dim(1), query.Stride(2),
query.Dim(2))
qRot := query.View(ctx, opts.qkNopeHeadDim*query.Stride(0),
opts.qkRopeHeadDim, query.Stride(1),
query.Dim(1), query.Stride(2),
query.Dim(2))
queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
kPass := compressedKV.View(ctx, 0, opts.kvLoraRank, compressedKV.Stride(1), compressedKV.Dim(1))
kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0),
opts.qkRopeHeadDim, compressedKV.Stride(1),
1, compressedKV.Stride(1),
compressedKV.Dim(1))
kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
kRot := compressedKV.View(ctx,
opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
compressedKV.Stride(1), 1,
compressedKV.Stride(1), compressedKV.Dim(1),
)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kPass = attn.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kPass = kv.View(ctx, 0, opts.kqNopeHeadDim, kv.Stride(1), kv.Dim(1), kv.Stride(2), kv.Dim(2))
value := kv.View(ctx, opts.kqNopeHeadDim*kv.Stride(0),
opts.vHeadDim, kv.Stride(1),
kv.Dim(1), kv.Stride(2),
kv.Dim(2)).Contiguous(ctx)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
qRot = fast.RoPE(ctx, qRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kRot = kRot.Repeat(ctx, 1, qPass.Dim(1))
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, qPass, 0)
key := kRot.Concat(ctx, kPass, 0)
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache)
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
}
@@ -142,6 +129,7 @@ func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml
experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
experts = experts.Mul(ctx, topKWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))

View File

@@ -64,18 +64,18 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
cache.(*kvcache.WrapperCache).SetLayerType(layerType)
// inputPerLayer = inputsPerLayer[:, i, :]
inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2)).Contiguous(ctx)
// inputPerLayer = inputsPerLayer[:, i, :].squeeze(1)
inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
}
// hiddenStates = hiddenStates[:, :, 0]
hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1))
hiddenStates0 := hiddenStates.Slice(ctx, 2, 0, 1, 1)
targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx)
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
// hiddenState = hiddenStates[:, :, 1:]
hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1)
hiddenState = hiddenStates.Slice(ctx, 2, 1, hiddenStates.Dim(2), 1)
altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState)
altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
@@ -176,10 +176,10 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
// inactive := predictions[:, :, 1:]
inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1)
inactive := predictions.Slice(ctx, 2, 1, predictions.Dim(2), 1)
active = inactive.Add(ctx, active)
predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1))
predictions0 := predictions.Slice(ctx, 2, 0, 1, 1)
return predictions0.Concat(ctx, active, 2)
}
@@ -319,7 +319,7 @@ type TextOptions struct {
func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor {
// t[:, :, o.altupActiveIndex]
return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1))
return t.Slice(ctx, 2, o.altupActiveIndex, o.altupActiveIndex+1, 1)
}
func (o *TextOptions) headDim() int {

View File

@@ -121,30 +121,9 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T
var query, key, value ml.Tensor
if attn.QKV != nil {
qkv := attn.QKV.Forward(ctx, hiddenStates)
// query = qkv[..., : num_attention_heads * head_dim].reshape(batch_size, num_attention_heads, head_dim)
query = qkv.View(ctx,
0,
opts.headDim(), qkv.Stride(0)*opts.headDim(),
opts.numHeads, qkv.Stride(1),
batchSize,
)
// key = qkv[..., num_attention_heads * head_dim:(num_attention_heads + num_key_value_heads) * head_dim].reshape(batch_size, num_key_value_heads, head_dim)
key = qkv.View(ctx,
qkv.Stride(0)*opts.headDim()*opts.numHeads,
opts.headDim(), qkv.Stride(0)*opts.headDim(),
opts.numKVHeads, qkv.Stride(1),
batchSize,
)
// value = qkv[..., (num_attention_heads + num_key_value_heads) * head_dim:].reshape(batch_size, num_key_value_heads, head_dim)
value = qkv.View(ctx,
qkv.Stride(0)*opts.headDim()*(opts.numHeads+opts.numKVHeads),
opts.headDim(), qkv.Stride(0)*opts.headDim(),
opts.numKVHeads, qkv.Stride(1),
batchSize,
)
qkv = qkv.Reshape(ctx, opts.headDim(), -1, batchSize)
chunks := qkv.ChunkSections(ctx, 1, opts.numHeads, opts.numKVHeads, opts.numKVHeads)
query, key, value = chunks[0], chunks[1], chunks[2]
} else {
query = attn.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
@@ -195,15 +174,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
var gate, up ml.Tensor
if mlp.GateUp != nil {
hiddenStates = mlp.GateUp.Forward(ctx, hiddenStates, selectedExperts)
hiddenStates = hiddenStates.Reshape(ctx, 2, hiddenStates.Dim(0)/2, hiddenStates.Dim(1), hiddenStates.Dim(2))
dimStride := []int{hiddenStates.Dim(0) / 2, hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), hiddenStates.Dim(2), hiddenStates.Stride(3), hiddenStates.Dim(3)}
gate = hiddenStates.View(ctx, 0, dimStride...)
gate = gate.Contiguous(ctx, gate.Dim(0)*gate.Dim(1), gate.Dim(2), gate.Dim(3))
up = hiddenStates.View(ctx, hiddenStates.Stride(0), dimStride...)
up = up.Contiguous(ctx, up.Dim(0)*up.Dim(1), up.Dim(2), up.Dim(3))
gate = hiddenStates.Slice(ctx, 0, 0, hiddenStates.Dim(0), 2)
up = hiddenStates.Slice(ctx, 0, 1, hiddenStates.Dim(0), 2)
} else {
gate = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts)
up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)

View File

@@ -105,9 +105,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
for range aspectRatio.Y {
for x := range aspectRatio.X {
view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
projectedOutputs.Dim(0), projectedOutputs.Stride(1),
patchesPerChunk)
view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
var separator separator
if x < aspectRatio.X-1 {
separator.x = true // <|tile_x_separator|>
@@ -120,9 +118,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
}
}
view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
projectedOutputs.Dim(0), projectedOutputs.Stride(1),
patchesPerChunk)
view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}})
return multimodal, nil

View File

@@ -37,27 +37,23 @@ type VisionAttention struct {
func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)
t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))
// t1 = t[..., 0::2]
t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
t1 = t1.Reshape(ctx, width/2, height, channels, tiles)
t1 := t.Slice(ctx, 0, 0, t.Dim(0), 2)
// t2 = t[..., 1::2]
t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
t2 = t2.Reshape(ctx, width/2, height, channels, tiles)
t2 := t.Slice(ctx, 0, 1, t.Dim(0), 2)
// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)
cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, -1)
cosOut = cosOut.Permute(ctx, 1, 0, 2, 3)
cosOut = cosOut.Contiguous(ctx, width, height, channels, tiles)
// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)
sinOut := t2.Scale(ctx, -1).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, -1)
sinOut = sinOut.Permute(ctx, 1, 0, 2, 3)
sinOut = sinOut.Contiguous(ctx, width, height, channels, tiles)
return cosOut.Add(ctx, sinOut)
}

View File

@@ -11,9 +11,9 @@ import (
var batchSize int = 1
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
return x2.Neg(ctx).Concat(ctx, x1, 0)
x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
}
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {

View File

@@ -13,9 +13,9 @@ import (
var batchSize int = 1
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
return x2.Neg(ctx).Concat(ctx, x1, 0)
x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
}
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {

View File

@@ -18,8 +18,8 @@ type VisionAttention struct {
}
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1)
x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx)
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
}
@@ -160,10 +160,11 @@ func (m *VisionPositionEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor
positionEmbeds = positionEmbeds.Mul(ctx, weights)
positionEmbeds = positionEmbeds.Reshape(ctx, n, -1, 4)
positionEmbeds = positionEmbeds.View(ctx, 0, n, positionEmbeds.Stride(1), grid.Height*grid.Width).
Add(ctx, positionEmbeds.View(ctx, 1*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
Add(ctx, positionEmbeds.View(ctx, 2*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)).
Add(ctx, positionEmbeds.View(ctx, 3*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width))
positionEmbedsChunks := positionEmbeds.Chunk(ctx, 2, 1)
positionEmbeds = positionEmbedsChunks[0].
Add(ctx, positionEmbedsChunks[1]).
Add(ctx, positionEmbedsChunks[2]).
Add(ctx, positionEmbedsChunks[3])
positionEmbeds = positionEmbeds.Reshape(ctx, -1, grid.Width/opts.spatialMergeSize, opts.spatialMergeSize, grid.Height/opts.spatialMergeSize)
positionEmbeds = positionEmbeds.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, n, -1)