diff --git a/server/quantization.go b/server/quantization.go index 56d882e84..5e7f45cee 100644 --- a/server/quantization.go +++ b/server/quantization.go @@ -21,33 +21,76 @@ type quantizer struct { progressFn func(n uint64) } +const quantizationChunkElements uint64 = 4 * 1024 * 1024 + func (q quantizer) WriteTo(w io.Writer) (int64, error) { quantize := q.from.Kind != q.to.Kind sr := io.NewSectionReader(q, int64(q.offset), int64(q.from.Size())) if !quantize { n, err := io.Copy(w, sr) - q.progressFn(q.from.Size()) + if q.progressFn != nil { + q.progressFn(q.from.Size()) + } return n, err } - data, err := io.ReadAll(sr) - if err != nil { - slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err) - return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err) + + if len(q.from.Shape) == 0 || q.from.Shape[0] == 0 { + return 0, fmt.Errorf("tensor %s has invalid shape %v", q.from.Name, q.from.Shape) } - if uint64(len(data)) < q.from.Size() { - return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape) + + fromType := fsggml.TensorType(q.from.Kind) + toType := fsggml.TensorType(q.to.Kind) + nPerRow := q.from.Shape[0] + totalElements := q.from.Elements() + if totalElements%nPerRow != 0 { + return 0, fmt.Errorf("tensor %s has non-row-aligned shape %v", q.from.Name, q.from.Shape) } - var f32s []float32 - newType := fsggml.TensorType(q.to.Kind) - if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 { - f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), q.from.Elements()) - } else { - f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements()) + + inRowSize := fromType.RowSize(nPerRow) + if inRowSize == 0 { + return 0, fmt.Errorf("tensor %s has unsupported source type %v", q.from.Name, fromType) } - data = ggml.Quantize(newType, f32s, q.from.Shape) - n, err := w.Write(data) - q.progressFn(q.from.Size()) - return int64(n), err + + totalRows := totalElements / nPerRow + rowsPerChunk := max(quantizationChunkElements/nPerRow, uint64(1)) + chunkBuf := make([]byte, inRowSize*rowsPerChunk) + var written int64 + + for row := uint64(0); row < totalRows; { + chunkRows := min(rowsPerChunk, totalRows-row) + chunkBytes := inRowSize * chunkRows + data := chunkBuf[:chunkBytes] + + if _, err := io.ReadFull(sr, data); err != nil { + slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err) + return written, fmt.Errorf("unable to read tensor %s from %s: %w", q.from.Name, q.Name(), err) + } + + var f32s []float32 + chunkElements := chunkRows * nPerRow + if fromType == fsggml.TensorTypeF32 { + f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), chunkElements) + } else { + f32s = ggml.ConvertToF32(data, q.from.Kind, chunkElements) + } + + quantized := ggml.Quantize(toType, f32s, []uint64{nPerRow, chunkRows}) + n, err := w.Write(quantized) + written += int64(n) + if err != nil { + return written, err + } + if n != len(quantized) { + return written, io.ErrShortWrite + } + + if q.progressFn != nil { + q.progressFn(chunkBytes) + } + row += chunkRows + } + + return written, nil } type quantizeState struct {