fix tensor merge (#13053)

This commit is contained in:
Michael Yang
2025-11-13 15:32:34 -08:00
committed by GitHub
parent 333203d871
commit 12b174b10e
2 changed files with 66 additions and 0 deletions

View File

@@ -2,10 +2,12 @@ package convert
import (
"cmp"
"errors"
"io"
"iter"
"path"
"slices"
"strconv"
"strings"
"github.com/pdevine/tensor"
@@ -94,6 +96,26 @@ func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []
return matched
})
slices.SortStableFunc(matched, func(a, b Tensor) int {
x := strings.Split(a.Name(), ".")
y := strings.Split(b.Name(), ".")
if len(x) != len(y) {
return cmp.Compare(len(x), len(y))
}
vals := make([]int, len(x))
for i := range x {
vals[i] = strings.Compare(x[i], y[i])
m, err := strconv.ParseInt(x[i], 0, 0)
n, err2 := strconv.ParseInt(y[i], 0, 0)
if errors.Join(err, err2) == nil {
vals[i] = cmp.Compare(m, n)
}
}
return cmp.Or(vals...)
})
if len(matched) > 0 {
out = append(out, &ggml.Tensor{
Name: merges[i].name,

View File

@@ -3,8 +3,10 @@ package convert
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"iter"
"math/rand/v2"
"slices"
"strings"
"testing"
@@ -951,3 +953,45 @@ func TestMerge(t *testing.T) {
}
})
}
func TestMergeOrder(t *testing.T) {
for range 8 {
t.Run("", func(t *testing.T) {
tensors := make([]Tensor, 16)
for i := range tensors {
tensors[i] = &fakeTensor{
name: fmt.Sprintf("layer.%d.weight", i),
shape: []uint64{1},
data: []float32{float32(i)},
}
}
rand.Shuffle(len(tensors), func(i, j int) {
tensors[i], tensors[j] = tensors[j], tensors[i]
})
matched, unmatched := mergeTensors(tensors, merge{"layer.*.weight", "layer.weight"})
if len(unmatched) != 0 {
t.Error("expected no remaining tensors, got", len(unmatched))
}
if len(matched) != 1 {
t.Error("expected 1 merged tensor, got", len(matched))
}
var b bytes.Buffer
if _, err := matched[0].WriteTo(&b); err != nil {
t.Fatal(err)
}
var f32s [16]float32
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.IsSorted(f32s[:]) {
t.Errorf("merged tensor data is not in order: %+v", f32s)
}
})
}
}