mirror of
https://github.com/ollama/ollama.git
synced 2026-04-30 16:08:07 -05:00
qwen3next: avoid inplace sigmoid for shared gate (#14077)
This commit is contained in:
@@ -175,6 +175,7 @@ type Tensor interface {
|
|||||||
SILU(ctx Context, up ...Tensor) Tensor
|
SILU(ctx Context, up ...Tensor) Tensor
|
||||||
RELU(ctx Context, up ...Tensor) Tensor
|
RELU(ctx Context, up ...Tensor) Tensor
|
||||||
Sigmoid(ctx Context) Tensor
|
Sigmoid(ctx Context) Tensor
|
||||||
|
SigmoidOut(ctx Context) Tensor
|
||||||
|
|
||||||
// SILUAlphaLimit is a variant of SILU that clamps the input to the range [-limit, limit]
|
// SILUAlphaLimit is a variant of SILU that clamps the input to the range [-limit, limit]
|
||||||
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
|
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
|
||||||
|
|||||||
@@ -1468,6 +1468,13 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SigmoidOut wraps the result of C.ggml_sigmoid applied to t's underlying
// tensor in a freshly constructed Tensor that carries over t's b field.
//
// ctx must be a *Context: the type assertion below is unchecked and panics
// for any other ml.Context implementation.
//
// NOTE(review): per the commit title this is meant as the non-in-place
// counterpart to Sigmoid; Sigmoid's body is not visible here — confirm that
// it uses the in-place op and that callers needing the input preserved use
// this method instead.
func (t *Tensor) SigmoidOut(ctx ml.Context) ml.Tensor {
|
||||||
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
|
t: C.ggml_sigmoid(ctx.(*Context).ctx, t.t),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||||
switch len(shape) {
|
switch len(shape) {
|
||||||
case 1:
|
case 1:
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
|
|||||||
// Apply shared expert gating
|
// Apply shared expert gating
|
||||||
if mlp.SharedGateInp != nil {
|
if mlp.SharedGateInp != nil {
|
||||||
sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D)
|
sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D)
|
||||||
sharedGateVal = sharedGateVal.Sigmoid(ctx)
|
sharedGateVal = sharedGateVal.SigmoidOut(ctx)
|
||||||
// Broadcast gate to match dimensions
|
// Broadcast gate to match dimensions
|
||||||
sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0))
|
sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0))
|
||||||
sharedOut = sharedOut.Mul(ctx, sharedGateVal)
|
sharedOut = sharedOut.Mul(ctx, sharedGateVal)
|
||||||
|
|||||||
Reference in New Issue
Block a user