mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 03:12:11 -05:00
server: chunk quantization writes to reduce create memory usage
This commit is contained in:
@@ -21,33 +21,76 @@ type quantizer struct {
|
||||
progressFn func(n uint64)
|
||||
}
|
||||
|
||||
const quantizationChunkElements uint64 = 4 * 1024 * 1024
|
||||
|
||||
func (q quantizer) WriteTo(w io.Writer) (int64, error) {
|
||||
quantize := q.from.Kind != q.to.Kind
|
||||
sr := io.NewSectionReader(q, int64(q.offset), int64(q.from.Size()))
|
||||
if !quantize {
|
||||
n, err := io.Copy(w, sr)
|
||||
q.progressFn(q.from.Size())
|
||||
if q.progressFn != nil {
|
||||
q.progressFn(q.from.Size())
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
data, err := io.ReadAll(sr)
|
||||
if err != nil {
|
||||
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
|
||||
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
|
||||
|
||||
if len(q.from.Shape) == 0 || q.from.Shape[0] == 0 {
|
||||
return 0, fmt.Errorf("tensor %s has invalid shape %v", q.from.Name, q.from.Shape)
|
||||
}
|
||||
if uint64(len(data)) < q.from.Size() {
|
||||
return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape)
|
||||
|
||||
fromType := fsggml.TensorType(q.from.Kind)
|
||||
toType := fsggml.TensorType(q.to.Kind)
|
||||
nPerRow := q.from.Shape[0]
|
||||
totalElements := q.from.Elements()
|
||||
if totalElements%nPerRow != 0 {
|
||||
return 0, fmt.Errorf("tensor %s has non-row-aligned shape %v", q.from.Name, q.from.Shape)
|
||||
}
|
||||
var f32s []float32
|
||||
newType := fsggml.TensorType(q.to.Kind)
|
||||
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
|
||||
f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), q.from.Elements())
|
||||
} else {
|
||||
f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
|
||||
|
||||
inRowSize := fromType.RowSize(nPerRow)
|
||||
if inRowSize == 0 {
|
||||
return 0, fmt.Errorf("tensor %s has unsupported source type %v", q.from.Name, fromType)
|
||||
}
|
||||
data = ggml.Quantize(newType, f32s, q.from.Shape)
|
||||
n, err := w.Write(data)
|
||||
q.progressFn(q.from.Size())
|
||||
return int64(n), err
|
||||
|
||||
totalRows := totalElements / nPerRow
|
||||
rowsPerChunk := max(quantizationChunkElements/nPerRow, uint64(1))
|
||||
chunkBuf := make([]byte, inRowSize*rowsPerChunk)
|
||||
var written int64
|
||||
|
||||
for row := uint64(0); row < totalRows; {
|
||||
chunkRows := min(rowsPerChunk, totalRows-row)
|
||||
chunkBytes := inRowSize * chunkRows
|
||||
data := chunkBuf[:chunkBytes]
|
||||
|
||||
if _, err := io.ReadFull(sr, data); err != nil {
|
||||
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
|
||||
return written, fmt.Errorf("unable to read tensor %s from %s: %w", q.from.Name, q.Name(), err)
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
chunkElements := chunkRows * nPerRow
|
||||
if fromType == fsggml.TensorTypeF32 {
|
||||
f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), chunkElements)
|
||||
} else {
|
||||
f32s = ggml.ConvertToF32(data, q.from.Kind, chunkElements)
|
||||
}
|
||||
|
||||
quantized := ggml.Quantize(toType, f32s, []uint64{nPerRow, chunkRows})
|
||||
n, err := w.Write(quantized)
|
||||
written += int64(n)
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
if n != len(quantized) {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
|
||||
if q.progressFn != nil {
|
||||
q.progressFn(chunkBytes)
|
||||
}
|
||||
row += chunkRows
|
||||
}
|
||||
|
||||
return written, nil
|
||||
}
|
||||
|
||||
type quantizeState struct {
|
||||
|
||||
Reference in New Issue
Block a user