mirror of
https://github.com/ollama/ollama.git
synced 2026-04-30 16:08:07 -05:00
ggml: ensure tensor size is valid (#14406)
When quantizing tensors during model creation, validate that the resulting sizes match what is expected based on the shape.
This commit is contained in:
@@ -245,7 +245,22 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
|||||||
padding := ggufPadding(offset, int64(alignment))
|
padding := ggufPadding(offset, int64(alignment))
|
||||||
llm.tensorOffset = uint64(offset + padding)
|
llm.tensorOffset = uint64(offset + padding)
|
||||||
|
|
||||||
|
// get file size to validate tensor bounds
|
||||||
|
fileSize, err := rs.Seek(0, io.SeekEnd)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to determine file size: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := rs.Seek(offset, io.SeekStart); err != nil {
|
||||||
|
return fmt.Errorf("failed to seek back after size check: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
for _, tensor := range llm.tensors {
|
for _, tensor := range llm.tensors {
|
||||||
|
tensorEnd := llm.tensorOffset + tensor.Offset + tensor.Size()
|
||||||
|
if tensorEnd > uint64(fileSize) {
|
||||||
|
return fmt.Errorf("tensor %q offset+size (%d) exceeds file size (%d)", tensor.Name, tensorEnd, fileSize)
|
||||||
|
}
|
||||||
|
|
||||||
offset, err := rs.Seek(0, io.SeekCurrent)
|
offset, err := rs.Seek(0, io.SeekCurrent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get current offset: %w", err)
|
return fmt.Errorf("failed to get current offset: %w", err)
|
||||||
|
|||||||
@@ -11,21 +11,21 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestWriteGGUF(t *testing.T) {
|
func TestWriteGGUF(t *testing.T) {
|
||||||
b := bytes.NewBuffer(make([]byte, 2*3))
|
tensorData := make([]byte, 2*3*4) // 6 F32 elements = 24 bytes
|
||||||
for range 8 {
|
for range 8 {
|
||||||
t.Run("shuffle", func(t *testing.T) {
|
t.Run("shuffle", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ts := []*Tensor{
|
ts := []*Tensor{
|
||||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewReader(tensorData)},
|
||||||
}
|
}
|
||||||
|
|
||||||
rand.Shuffle(len(ts), func(i, j int) {
|
rand.Shuffle(len(ts), func(i, j int) {
|
||||||
@@ -98,4 +98,32 @@ func TestWriteGGUF(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("truncated_tensor_data", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ts := []*Tensor{
|
||||||
|
{Name: "blk.0.attn.weight", Kind: 0, Shape: []uint64{512, 2}, WriterTo: bytes.NewBuffer(make([]byte, 32))},
|
||||||
|
}
|
||||||
|
|
||||||
|
w, err := os.CreateTemp(t.TempDir(), "truncated_*.bin")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer w.Close()
|
||||||
|
|
||||||
|
if err := WriteGGUF(w, KV{"general.architecture": "test"}, ts); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := os.Open(w.Name())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
if _, err := Decode(r, -1); err == nil {
|
||||||
|
t.Error("Decode should reject GGUF files where tensor data extends beyond file size")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,6 +33,9 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
|
|||||||
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
|
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
|
||||||
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
|
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
|
||||||
}
|
}
|
||||||
|
if uint64(len(data)) < q.from.Size() {
|
||||||
|
return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape)
|
||||||
|
}
|
||||||
var f32s []float32
|
var f32s []float32
|
||||||
newType := fsggml.TensorType(q.to.Kind)
|
newType := fsggml.TensorType(q.to.Kind)
|
||||||
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
|
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
|
||||||
|
|||||||
@@ -173,6 +173,7 @@ func TestQuantizeModel(t *testing.T) {
|
|||||||
tensors []*fsggml.Tensor
|
tensors []*fsggml.Tensor
|
||||||
newType string
|
newType string
|
||||||
expectedTensorTypes map[string]fsggml.TensorType
|
expectedTensorTypes map[string]fsggml.TensorType
|
||||||
|
expectErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "f16_q4_k",
|
name: "f16_q4_k",
|
||||||
@@ -253,6 +254,36 @@ func TestQuantizeModel(t *testing.T) {
|
|||||||
"output.weight": fsggml.TensorTypeQ8_0,
|
"output.weight": fsggml.TensorTypeQ8_0,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "f32_short_data",
|
||||||
|
kv: map[string]any{
|
||||||
|
"general.architecture": "foo",
|
||||||
|
},
|
||||||
|
tensors: []*fsggml.Tensor{
|
||||||
|
{
|
||||||
|
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF32),
|
||||||
|
Offset: uint64(0), Shape: []uint64{512, 2},
|
||||||
|
WriterTo: bytes.NewReader(make([]byte, 32)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
newType: "Q4_K",
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "f16_short_data",
|
||||||
|
kv: map[string]any{
|
||||||
|
"general.architecture": "foo",
|
||||||
|
},
|
||||||
|
tensors: []*fsggml.Tensor{
|
||||||
|
{
|
||||||
|
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
|
||||||
|
Offset: uint64(0), Shape: []uint64{512, 2},
|
||||||
|
WriterTo: bytes.NewReader(make([]byte, 32)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
newType: "Q4_K",
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
@@ -264,6 +295,9 @@ func TestQuantizeModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer fp.Close()
|
defer fp.Close()
|
||||||
meta, err := fsggml.Decode(fp, -1)
|
meta, err := fsggml.Decode(fp, -1)
|
||||||
|
if tt.expectErr && err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
}
|
}
|
||||||
@@ -283,6 +317,12 @@ func TestQuantizeModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = quantize(fp, tmp, meta, ftype, progress)
|
err = quantize(fp, tmp, meta, ftype, progress)
|
||||||
|
if tt.expectErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected quantize to return an error")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error during quantize: %s", err)
|
t.Fatalf("error during quantize: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user