From a0407d07fa900a129afcec42fc222a19f1102d82 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Tue, 10 Feb 2026 11:29:17 -0800 Subject: [PATCH] safetensors quantization for mlx (#14184) This change includes: - changes to the safetensors metadata format - changes to the create command to properly create the blobs with the new format - changes to load the new format - fixes ollama show to properly show each tensor --- x/create/client/create.go | 105 +++--- x/create/client/quantize.go | 241 ++++++++++---- x/create/client/quantize_stub.go | 11 +- x/create/create.go | 187 +++++++---- x/create/create_test.go | 35 +- x/create/imagegen.go | 10 +- x/imagegen/docs/blob-format.md | 158 +++++++++ x/imagegen/manifest/manifest.go | 97 +++++- x/imagegen/manifest/weights.go | 226 +++++++------ x/imagegen/mlx/mlx.go | 47 +++ x/imagegen/safetensors/extractor.go | 92 +++++- x/imagegen/safetensors/loader.go | 2 +- x/server/show.go | 394 ++++++++++++++-------- x/server/show_test.go | 496 ++++++++++++++++++++++++++-- 14 files changed, 1640 insertions(+), 461 deletions(-) create mode 100644 x/imagegen/docs/blob-format.md diff --git a/x/create/client/create.go b/x/create/client/create.go index 36e7f164b..f89f9fc98 100644 --- a/x/create/client/create.go +++ b/x/create/client/create.go @@ -19,6 +19,7 @@ import ( "github.com/ollama/ollama/progress" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/x/create" + "github.com/ollama/ollama/x/imagegen/safetensors" ) // MinOllamaVersion is the minimum Ollama version required for safetensors models. @@ -35,7 +36,7 @@ type ModelfileConfig struct { type CreateOptions struct { ModelName string ModelDir string - Quantize string // "q4", "q8", "nvfp4", or "mxfp8" for quantization + Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization Modelfile *ModelfileConfig // template/system/license from Modelfile } @@ -94,6 +95,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error { newLayerCreator(), newTensorLayerCreator(), newManifestWriter(opts, capabilities, parserName, rendererName), progressFn, + newPackedTensorLayerCreator(), ) } else { err = create.CreateImageGenModel( @@ -141,60 +143,33 @@ func newTensorLayerCreator() create.QuantizingTensorLayerCreator { } } -// createQuantizedLayers quantizes a tensor and returns the resulting layers. +// createQuantizedLayers quantizes a tensor and returns a single combined layer. +// The combined blob contains data, scale, and optional bias tensors with metadata. 
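+// See x/imagegen/docs/blob-format.md for the combined blob layout.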
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) { if !QuantizeSupported() { return nil, fmt.Errorf("quantization requires MLX support") } - // Quantize the tensor - qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize) + // Quantize the tensor into a single combined blob + blobData, err := quantizeTensor(r, name, dtype, shape, quantize) if err != nil { return nil, fmt.Errorf("failed to quantize %s: %w", name, err) } - // Create layer for quantized weight - weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor) + // Create single layer for the combined blob + layer, err := manifest.NewLayer(bytes.NewReader(blobData), manifest.MediaTypeImageTensor) if err != nil { return nil, err } - // Create layer for scales - scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor) - if err != nil { - return nil, err - } - - layers := []create.LayerInfo{ + return []create.LayerInfo{ { - Digest: weightLayer.Digest, - Size: weightLayer.Size, - MediaType: weightLayer.MediaType, + Digest: layer.Digest, + Size: layer.Size, + MediaType: layer.MediaType, Name: name, }, - { - Digest: scalesLayer.Digest, - Size: scalesLayer.Size, - MediaType: scalesLayer.MediaType, - Name: name + "_scale", - }, - } - - // Add qbiases layer if present (affine mode) - if qbiasData != nil { - qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor) - if err != nil { - return nil, err - } - layers = append(layers, create.LayerInfo{ - Digest: qbiasLayer.Digest, - Size: qbiasLayer.Size, - MediaType: qbiasLayer.MediaType, - Name: name + "_qbias", - }) - } - - return layers, nil + }, nil } // createUnquantizedLayer creates a single tensor layer without quantization. @@ -214,6 +189,58 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error }, nil } +// newPackedTensorLayerCreator returns a PackedTensorLayerCreator callback for +// creating packed multi-tensor blob layers (used for expert groups). 
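+// If any tensor in the group requests quantization, the whole group is quantized
+// and written through MLX; otherwise the tensors are repacked directly into a
+// single multi-tensor safetensors blob.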
+func newPackedTensorLayerCreator() create.PackedTensorLayerCreator { + return func(groupName string, tensors []create.PackedTensorInput) (create.LayerInfo, error) { + // Check if any tensor in the group needs quantization + hasQuantize := false + for _, t := range tensors { + if t.Quantize != "" { + hasQuantize = true + break + } + } + + var blobReader io.Reader + if hasQuantize { + if !QuantizeSupported() { + return create.LayerInfo{}, fmt.Errorf("quantization requires MLX support") + } + blobData, err := quantizePackedGroup(tensors) + if err != nil { + return create.LayerInfo{}, fmt.Errorf("failed to quantize packed group %s: %w", groupName, err) + } + blobReader = bytes.NewReader(blobData) + } else { + // Build unquantized packed blob using streaming reader + // Extract raw tensor data from safetensors-wrapped readers + var tds []*safetensors.TensorData + for _, t := range tensors { + rawData, err := safetensors.ExtractRawFromSafetensors(t.Reader) + if err != nil { + return create.LayerInfo{}, fmt.Errorf("failed to extract tensor %s: %w", t.Name, err) + } + td := safetensors.NewTensorDataFromBytes(t.Name, t.Dtype, t.Shape, rawData) + tds = append(tds, td) + } + blobReader = safetensors.BuildPackedSafetensorsReader(tds) + } + + layer, err := manifest.NewLayer(blobReader, manifest.MediaTypeImageTensor) + if err != nil { + return create.LayerInfo{}, err + } + + return create.LayerInfo{ + Digest: layer.Digest, + Size: layer.Size, + MediaType: layer.MediaType, + Name: groupName, + }, nil + } +} + // newManifestWriter returns a ManifestWriter callback for writing the model manifest. func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter { return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error { diff --git a/x/create/client/quantize.go b/x/create/client/quantize.go index e69003f73..e47f1664d 100644 --- a/x/create/client/quantize.go +++ b/x/create/client/quantize.go @@ -3,128 +3,195 @@ package client import ( + "encoding/binary" + "encoding/json" "fmt" "io" "os" "path/filepath" + "strconv" + "github.com/ollama/ollama/x/create" "github.com/ollama/ollama/x/imagegen/mlx" ) -// quantizeTensor loads a tensor from safetensors format, quantizes it, -// and returns safetensors data for the quantized weights, scales, and biases. -// Supported quantization types: -// - "q4": affine 4-bit, group_size=32 (with qbiases) -// - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales) -// - "q8": affine 8-bit, group_size=64 (with qbiases) -// - "mxfp8": Microsoft MX FP8, group_size=32 (no qbiases, E4M3 scales) -// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights). -func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) { +// quantizeParams maps quantization type names to MLX quantize parameters. +var quantizeParams = map[string]struct { + groupSize int + bits int + mode string +}{ + "int4": {32, 4, "affine"}, + "nvfp4": {16, 4, "nvfp4"}, + "int8": {64, 8, "affine"}, + "mxfp8": {32, 8, "mxfp8"}, +} + +// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX, +// quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias) +// to the provided maps. If quantize is empty, the tensor is kept as-is. 
+// Returns any temp file paths created (caller must clean up) and arrays needing eval. +func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) { tmpDir := ensureTempDir() - // Read safetensors data to a temp file (LoadSafetensorsNative needs a path) - tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors") + tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors") if err != nil { - return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err) + return "", nil, nil, fmt.Errorf("failed to create temp file: %w", err) } - tmpPath := tmpFile.Name() - defer os.Remove(tmpPath) + tmpPath = tmpFile.Name() if _, err := io.Copy(tmpFile, r); err != nil { tmpFile.Close() - return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err) + return tmpPath, nil, nil, fmt.Errorf("failed to write temp file for %s: %w", name, err) } tmpFile.Close() - // Load the tensor using MLX's native loader st, err := mlx.LoadSafetensorsNative(tmpPath) if err != nil { - return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err) + return tmpPath, nil, nil, fmt.Errorf("failed to load safetensors for %s: %w", name, err) } - defer st.Free() - // Get the tensor (it's stored as "data" in our minimal safetensors format) - arr := st.Get("data") + // Find the tensor key (may differ from name for single-tensor blobs) + inputKey, err := findSafetensorsKey(tmpPath) + if err != nil { + st.Free() + return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err) + } + + arr := st.Get(inputKey) if arr == nil { - return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors") + st.Free() + return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey) } - // Convert to BFloat16 if needed (quantize expects float type) + if quantize == "" { + arr = mlx.Contiguous(arr) + arrays[name] = arr + return tmpPath, []*mlx.Array{arr}, st, nil + } + + // Convert to float type if needed (quantize expects float) if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 { arr = mlx.AsType(arr, mlx.DtypeBFloat16) mlx.Eval(arr) } - // Quantize based on quantization type - var qweight, scales, qbiases *mlx.Array - switch quantize { - case "q4": - // affine mode: group_size=32, bits=4 (with qbiases for zero-point offset) - qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine") - case "nvfp4": - // NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales) - qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4") - case "q8": - // affine mode: group_size=64, bits=8 (with qbiases for zero-point offset) - qweight, scales, qbiases = mlx.Quantize(arr, 64, 8, "affine") - case "mxfp8": - // Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbiases) - qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "mxfp8") - default: - return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize) + params, ok := quantizeParams[quantize] + if !ok { + st.Free() + return tmpPath, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize) } - // Eval and make contiguous for data access + qweight, scales, qbiases := mlx.Quantize(arr, params.groupSize, params.bits, params.mode) + qweight = mlx.Contiguous(qweight) scales = mlx.Contiguous(scales) + arrays[name] = qweight + arrays[name+".scale"] = 
scales + toEval = append(toEval, qweight, scales) + if qbiases != nil { qbiases = mlx.Contiguous(qbiases) - mlx.Eval(qweight, scales, qbiases) - } else { - mlx.Eval(qweight, scales) + arrays[name+".bias"] = qbiases + toEval = append(toEval, qbiases) } - // Get shapes - qweightShape = qweight.Shape() - scalesShape = scales.Shape() + return tmpPath, toEval, st, nil +} - // Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype) - qweightPath := filepath.Join(tmpDir, "qweight.safetensors") - defer os.Remove(qweightPath) - if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil { - return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err) +// quantizeTensor loads a tensor from safetensors format, quantizes it, +// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias. +// Tensor keys use the original tensor name: name, name.scale, name.bias. +// The blob includes __metadata__ with quant_type and group_size. +// Supported quantization types: "int4", "nvfp4", "int8", "mxfp8". +func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) { + arrays := make(map[string]*mlx.Array) + tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays) + if tmpPath != "" { + defer os.Remove(tmpPath) + } + if st != nil { + defer st.Free() } - qweightData, err = os.ReadFile(qweightPath) if err != nil { - return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err) + return nil, err } - // Save scales using MLX's native safetensors - scalesPath := filepath.Join(tmpDir, "scales.safetensors") - defer os.Remove(scalesPath) - if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil { - return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err) - } - scalesData, err = os.ReadFile(scalesPath) - if err != nil { - return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err) + mlx.Eval(toEval...) + + // Build metadata for single-tensor blobs + params := quantizeParams[quantize] + metadata := map[string]string{ + "quant_type": quantize, + "group_size": strconv.Itoa(params.groupSize), } - // Affine mode returns qbiases for zero-point offset - if qbiases != nil { - qbiasShape = qbiases.Shape() - qbiasPath := filepath.Join(tmpDir, "qbias.safetensors") - defer os.Remove(qbiasPath) - if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil { - return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err) + tmpDir := ensureTempDir() + outPath := filepath.Join(tmpDir, "combined.safetensors") + defer os.Remove(outPath) + if err := mlx.SaveSafetensorsWithMetadata(outPath, arrays, metadata); err != nil { + return nil, fmt.Errorf("failed to save combined blob: %w", err) + } + return os.ReadFile(outPath) +} + +// quantizePackedGroup quantizes multiple tensors and saves them all into a single +// combined safetensors blob. Used for packing expert groups. +// Each tensor may have a different quantization type (mixed-precision). +// Returns the blob bytes. No __metadata__ is added because different tensors +// may use different quantization types. 
+func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) { + allArrays := make(map[string]*mlx.Array) + var allToEval []*mlx.Array + var tmpPaths []string + var handles []*mlx.SafetensorsFile + + for _, input := range inputs { + tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays) + if tmpPath != "" { + tmpPaths = append(tmpPaths, tmpPath) + } + if st != nil { + handles = append(handles, st) } - qbiasData, err = os.ReadFile(qbiasPath) if err != nil { - return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err) + // Cleanup on error + for _, h := range handles { + h.Free() + } + for _, p := range tmpPaths { + os.Remove(p) + } + return nil, err } + allToEval = append(allToEval, toEval...) } - return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil + mlx.Eval(allToEval...) + + // Free native handles after eval + for _, h := range handles { + h.Free() + } + + // Save combined blob (no global metadata for mixed-precision packed blobs) + tmpDir := ensureTempDir() + outPath := filepath.Join(tmpDir, "packed-combined.safetensors") + defer os.Remove(outPath) + if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, nil); err != nil { + return nil, fmt.Errorf("failed to save packed blob: %w", err) + } + + blobData, err := os.ReadFile(outPath) + if err != nil { + return nil, fmt.Errorf("failed to read packed blob: %w", err) + } + + for _, p := range tmpPaths { + os.Remove(p) + } + + return blobData, nil } // QuantizeSupported returns true if quantization is supported (MLX build) @@ -138,3 +205,33 @@ func ensureTempDir() string { os.MkdirAll(tmpDir, 0755) return tmpDir } + +// findSafetensorsKey reads the first non-metadata tensor key from a safetensors file. 
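+// Go map iteration order is randomized, so the result is only deterministic for
+// single-tensor blobs (one tensor entry plus an optional __metadata__ key).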
+func findSafetensorsKey(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + + var headerSize uint64 + if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil { + return "", err + } + headerBytes := make([]byte, headerSize) + if _, err := io.ReadFull(f, headerBytes); err != nil { + return "", err + } + + var header map[string]json.RawMessage + if err := json.Unmarshal(headerBytes, &header); err != nil { + return "", err + } + + for k := range header { + if k != "__metadata__" { + return k, nil + } + } + return "", fmt.Errorf("no tensor found in safetensors header") +} diff --git a/x/create/client/quantize_stub.go b/x/create/client/quantize_stub.go index 3a85afcc7..7a75671a0 100644 --- a/x/create/client/quantize_stub.go +++ b/x/create/client/quantize_stub.go @@ -5,11 +5,18 @@ package client import ( "fmt" "io" + + "github.com/ollama/ollama/x/create" ) // quantizeTensor is not available without MLX -func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) { - return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)") +func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) { + return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)") +} + +// quantizePackedGroup is not available without MLX +func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) { + return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)") } // QuantizeSupported returns false when MLX is not available diff --git a/x/create/create.go b/x/create/create.go index 2474c8c66..385efadab 100644 --- a/x/create/create.go +++ b/x/create/create.go @@ -6,7 +6,9 @@ import ( "io" "os" "path/filepath" + "regexp" "slices" + "sort" "strings" "github.com/ollama/ollama/envconfig" @@ -228,7 +230,7 @@ type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error) type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error) // QuantizingTensorLayerCreator creates tensor layers with optional quantization. -// When quantize is non-empty (e.g., "q8"), returns multiple layers (weight + scales + biases). +// When quantize is non-empty (e.g., "int8"), returns multiple layers (weight + scales + biases). type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) // ManifestWriter writes the manifest file. @@ -264,19 +266,19 @@ func ShouldQuantize(name, component string) bool { // ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type. // This is a more detailed check that also considers tensor dimensions. -// The quantize parameter specifies the quantization type (e.g., "q4", "nvfp4", "q8", "mxfp8"). +// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "int8", "mxfp8"). func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool { return GetTensorQuantization(name, shape, quantize) != "" } // normalizeQuantType converts various quantization type aliases to canonical forms. 
-// Supports: q4/Q4/int4/INT4/fp4/FP4 -> q4, q8/Q8/int8/INT8/fp8/FP8 -> q8, nvfp4/NVFP4, mxfp8/MXFP8 +// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp8/MXFP8 func normalizeQuantType(quantize string) string { switch strings.ToUpper(quantize) { case "Q4", "INT4", "FP4": - return "q4" + return "int4" case "Q8", "INT8", "FP8": - return "q8" + return "int8" case "NVFP4": return "nvfp4" case "MXFP8": @@ -286,29 +288,12 @@ func normalizeQuantType(quantize string) string { } } -// getQuantGroupSize returns the group size for a given quantization type. -// These must match the values used in quantize.go when creating quantized models. -func getQuantGroupSize(quantize string) int { - switch normalizeQuantType(quantize) { - case "nvfp4": - return 16 - case "q4": - return 32 - case "mxfp8": - return 32 - case "q8": - return 64 - default: - return 32 - } -} - // GetTensorQuantization returns the appropriate quantization type for a tensor. // Returns "" if the tensor should not be quantized. // This implements mixed-precision quantization: // - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive) -// - Output projection, gate/up weights: q4 (less sensitive) -// - Down projection weights: q8 (more sensitive, would be Q6 in GGML but no MLX kernel) +// - Output projection, gate/up weights: int4 (less sensitive) +// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel) // - Norms, embeddings, biases, routing gates: no quantization func GetTensorQuantization(name string, shape []int32, quantize string) string { // Use basic name-based check first @@ -330,12 +315,12 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string { quantNorm := normalizeQuantType(quantize) // MLX quantization requires last dimension to be divisible by group size - // nvfp4: 16, q4/mxfp8: 32, q8: 64 + // nvfp4: 16, int4/mxfp8: 32, int8: 64 groupSize := int32(32) switch quantNorm { case "nvfp4": groupSize = 16 - case "q8": + case "int8": groupSize = 64 } if shape[len(shape)-1]%groupSize != 0 { @@ -363,13 +348,13 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string { return "" // No quantization - keep bf16 } - // Down projection weights - use Q8 (would be Q6_K in GGML, but MLX has no Q6 kernel) + // Down projection weights - use INT8 (would be Q6_K in GGML, but MLX has no Q6 kernel) // mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj if strings.Contains(name, "down_proj") { - return "q8" + return "int8" } - // Output projection, gate/up weights - use requested quantization (Q4) + // Output projection, gate/up weights - use requested quantization (INT4) // o_proj, gate_proj, up_proj if strings.Contains(name, "o_proj") || strings.Contains(name, "gate_proj") || @@ -386,14 +371,69 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string { return quantNorm } +// expertGroupRegexp matches expert tensor names and captures the group prefix. +// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes) +// Captures: model.layers.{L}.mlp.experts +var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`) + +// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together. 
+// For example: +// - "model.layers.1.mlp.experts.0.down_proj.weight" -> "model.layers.1.mlp.experts" +// - "model.layers.1.mlp.shared_experts.down_proj.weight" -> "model.layers.1.mlp.shared_experts" +// - "model.layers.0.mlp.down_proj.weight" -> "" (dense layer, no experts) +// - "model.layers.1.mlp.gate.weight" -> "" (routing gate, not an expert) +func ExpertGroupPrefix(tensorName string) string { + m := expertGroupRegexp.FindStringSubmatch(tensorName) + if m == nil { + return "" + } + return m[1] +} + +// PackedTensorInput holds metadata for a tensor that will be packed into a multi-tensor blob. +type PackedTensorInput struct { + Name string + Dtype string + Shape []int32 + Quantize string // per-tensor quantization type (may differ within group) + Reader io.Reader // safetensors-wrapped tensor data +} + +// PackedTensorLayerCreator creates a single blob layer containing multiple packed tensors. +// groupName is the group prefix (e.g., "model.layers.1.mlp.experts"). +type PackedTensorLayerCreator func(groupName string, tensors []PackedTensorInput) (LayerInfo, error) + // CreateSafetensorsModel imports a standard safetensors model from a directory. // This handles Hugging Face style models with config.json and *.safetensors files. // Stores each tensor as a separate blob for fine-grained deduplication. -// If quantize is non-empty (e.g., "q8"), eligible tensors will be quantized. -func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error { +// Expert tensors are packed into per-layer blobs when createPackedLayer is non-nil. +// If quantize is non-empty (e.g., "int8"), eligible tensors will be quantized. +func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string), createPackedLayer ...PackedTensorLayerCreator) error { var layers []LayerInfo var configLayer LayerInfo + // Resolve the optional packed layer creator + var packedCreator PackedTensorLayerCreator + if len(createPackedLayer) > 0 { + packedCreator = createPackedLayer[0] + } + + // Accumulate expert tensors by group prefix for packing. + // Readers reference file-backed SectionReaders, so we keep extractors + // open until each group is flushed to avoid buffering tensor data in memory. 
+ expertGroups := make(map[string][]PackedTensorInput) + var expertGroupOrder []string + + // Track open extractors so we can close them after flushing groups + var openExtractors []*safetensors.TensorExtractor + + closeExtractors := func() { + for _, ext := range openExtractors { + ext.Close() + } + openExtractors = nil + } + entries, err := os.ReadDir(modelDir) if err != nil { return fmt.Errorf("failed to read directory: %w", err) @@ -410,6 +450,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La // Extract individual tensors from safetensors file extractor, err := safetensors.OpenForExtraction(stPath) if err != nil { + closeExtractors() return fmt.Errorf("failed to open %s: %w", stPath, err) } @@ -420,10 +461,14 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La } fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg)) + // Track whether this extractor has expert tensors that need to stay open + hasExpertTensors := false + for _, tensorName := range tensorNames { td, err := extractor.GetTensor(tensorName) if err != nil { extractor.Close() + closeExtractors() return fmt.Errorf("failed to get tensor %s: %w", tensorName, err) } @@ -434,20 +479,65 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize) } - // Store as minimal safetensors format (88 bytes header overhead) - // This enables native mmap loading via mlx_load_safetensors - // createTensorLayer returns multiple layers if quantizing (weight + scales) - newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType) - if err != nil { - extractor.Close() - return fmt.Errorf("failed to create layer for %s: %w", tensorName, err) + // Check if this tensor belongs to an expert group for packing + groupPrefix := "" + if packedCreator != nil { + groupPrefix = ExpertGroupPrefix(tensorName) + } + + if groupPrefix != "" { + // Accumulate expert tensor for packed blob. + // The Reader uses a file-backed SectionReader, so we must + // keep the extractor open until this group is flushed. + hasExpertTensors = true + if _, exists := expertGroups[groupPrefix]; !exists { + expertGroupOrder = append(expertGroupOrder, groupPrefix) + } + expertGroups[groupPrefix] = append(expertGroups[groupPrefix], PackedTensorInput{ + Name: tensorName, + Dtype: td.Dtype, + Shape: td.Shape, + Quantize: quantizeType, + Reader: td.SafetensorsReader(), + }) + } else { + // Store as minimal safetensors format (88 bytes header overhead) + // This enables native mmap loading via mlx_load_safetensors + // createTensorLayer returns multiple layers if quantizing (weight + scales) + newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType) + if err != nil { + extractor.Close() + closeExtractors() + return fmt.Errorf("failed to create layer for %s: %w", tensorName, err) + } + layers = append(layers, newLayers...) } - layers = append(layers, newLayers...) 
} - extractor.Close() + if hasExpertTensors { + // Keep extractor open - readers still reference its file handle + openExtractors = append(openExtractors, extractor) + } else { + extractor.Close() + } } + // Process accumulated expert groups into packed blobs, then close extractors + if packedCreator != nil { + sort.Strings(expertGroupOrder) + for _, groupName := range expertGroupOrder { + tensors := expertGroups[groupName] + fn(fmt.Sprintf("packing %s (%d tensors)", groupName, len(tensors))) + layer, err := packedCreator(groupName, tensors) + if err != nil { + closeExtractors() + return fmt.Errorf("failed to create packed layer for %s: %w", groupName, err) + } + layers = append(layers, layer) + } + } + closeExtractors() + // Process all JSON config files for _, entry := range entries { if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") { @@ -487,23 +577,6 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La return fmt.Errorf("config.json not found in %s", modelDir) } - // Create model_index.json with quantization info if quantizing - if quantize != "" { - modelIndex := map[string]any{ - "quantization": strings.ToUpper(quantize), - "group_size": getQuantGroupSize(quantize), - } - indexData, err := json.MarshalIndent(modelIndex, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal model_index.json: %w", err) - } - indexLayer, err := createLayer(strings.NewReader(string(indexData)), "application/vnd.ollama.image.json", "model_index.json") - if err != nil { - return fmt.Errorf("failed to create model_index.json layer: %w", err) - } - layers = append(layers, indexLayer) - } - fn(fmt.Sprintf("writing manifest for %s", modelName)) if err := writeManifest(modelName, configLayer, layers); err != nil { diff --git a/x/create/create_test.go b/x/create/create_test.go index b5d0a7b34..fb48987d6 100644 --- a/x/create/create_test.go +++ b/x/create/create_test.go @@ -586,6 +586,39 @@ func TestShouldQuantizeTensor(t *testing.T) { } } +func TestExpertGroupPrefix(t *testing.T) { + tests := []struct { + name string + want string + }{ + // Expert tensors should return the group prefix + {"model.layers.1.mlp.experts.0.down_proj.weight", "model.layers.1.mlp.experts"}, + {"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"}, + {"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"}, + + // Shared expert tensors should return their own group prefix + {"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"}, + {"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"}, + + // Non-expert tensors should return empty string + {"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts + {"model.layers.1.mlp.gate.weight", ""}, // routing gate, not an expert + {"model.embed_tokens.weight", ""}, // embedding + {"model.layers.0.self_attn.q_proj.weight", ""}, // attention + {"model.norm.weight", ""}, // norm + {"lm_head.weight", ""}, // output head + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ExpertGroupPrefix(tt.name) + if got != tt.want { + t.Errorf("ExpertGroupPrefix(%q) = %q, want %q", tt.name, got, tt.want) + } + }) + } +} + func TestCreateSafetensorsModel_WithQuantize(t *testing.T) { dir := t.TempDir() @@ -751,7 +784,7 @@ func TestCreateImageGenModel_WithQuantize(t *testing.T) { progressFn := func(status string) {} - err := CreateImageGenModel("test-imagegen", dir, "q8", createLayer, 
createTensorLayer, writeManifest, progressFn)
+	err := CreateImageGenModel("test-imagegen", dir, "int8", createLayer, createTensorLayer, writeManifest, progressFn)
 	if err != nil {
 		t.Fatalf("CreateImageGenModel failed: %v", err)
 	}
diff --git a/x/create/imagegen.go b/x/create/imagegen.go
index 0da0e764a..6dbbcbfcc 100644
--- a/x/create/imagegen.go
+++ b/x/create/imagegen.go
@@ -15,15 +15,15 @@ import (
 // CreateImageGenModel imports an image generation model from a directory.
 // Stores each tensor as a separate blob for fine-grained deduplication.
 // If quantize is specified, linear weights in transformer/text_encoder are quantized.
-// Supported quantization types: q4, q8, nvfp4, mxfp8 (or empty for no quantization).
+// Supported quantization types: int4, int8, nvfp4, mxfp8 (or empty for no quantization).
 // Layer creation and manifest writing are done via callbacks to avoid import cycles.
 func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
 	// Validate quantization type
 	switch quantize {
-	case "", "q4", "q8", "nvfp4", "mxfp8":
+	case "", "int4", "int8", "nvfp4", "mxfp8":
 		// valid
 	default:
-		return fmt.Errorf("unsupported quantization type %q: supported types are q4, q8, nvfp4, mxfp8", quantize)
+		return fmt.Errorf("unsupported quantization type %q: supported types are int4, int8, nvfp4, mxfp8", quantize)
 	}
 
 	var layers []LayerInfo
@@ -214,7 +214,7 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
 // canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
 // MLX requires the last dimension to be divisible by the group size.
-// nvfp4: 16, q4/mxfp8: 32, q8: 64
+// nvfp4: 16, int4/mxfp8: 32, int8: 64
 func canQuantizeShape(shape []int32, quantize string) bool {
 	if len(shape) < 2 {
 		return false
@@ -223,7 +223,7 @@ func canQuantizeShape(shape []int32, quantize string) bool {
 	switch strings.ToUpper(quantize) {
 	case "NVFP4":
 		groupSize = 16
-	case "Q8":
+	case "INT8":
 		groupSize = 64
 	}
 	return shape[len(shape)-1]%groupSize == 0
diff --git a/x/imagegen/docs/blob-format.md b/x/imagegen/docs/blob-format.md
new file mode 100644
index 000000000..768f1c2f9
--- /dev/null
+++ b/x/imagegen/docs/blob-format.md
@@ -0,0 +1,158 @@
+# Tensor Blob Format
+
+Ollama stores model tensors as individual blobs in the safetensors format. Each blob contains a single logical tensor (possibly a quantized tensor combined with its scale/bias components), or a group of logical tensors (e.g., the experts for a given layer, along with the scale/bias components for each tensor).
+
+## Safetensors File Format
+
+Every blob follows the [safetensors](https://github.com/huggingface/safetensors) layout:
+
+```
+[8 bytes: header_size (uint64 LE)] [header_size bytes: JSON header] [tensor data region]
+```
+
+The JSON header maps tensor names to their dtype, shape, and byte offsets within the data region. A special `__metadata__` key holds string-to-string metadata.
+
+## Unquantized Blobs
+
+An unquantized blob stores a single tensor keyed by its name:
+
+```json
+{
+  "model.layers.0.self_attn.q_proj.weight": {
+    "dtype": "BF16",
+    "shape": [2560, 2560],
+    "data_offsets": [0, 13107200]
+  }
+}
+```
+
+The tensor key is the full tensor name. The dtype is typically `BF16` or `F32`.
+
+## Quantized Blobs (Combined Format)
+
+A quantized blob stores the packed weight, scaling factors, and optional zero-point biases in a single file. Tensor keys use the tensor name, with `.scale` and `.bias` suffixes for the auxiliary tensors:
+
+```json
+{
+  "__metadata__": {
+    "quant_type": "int4",
+    "group_size": "32"
+  },
+  "model.layers.0.mlp.up_proj.weight": {
+    "dtype": "U32",
+    "shape": [2560, 320],
+    "data_offsets": [0, 3276800]
+  },
+  "model.layers.0.mlp.up_proj.weight.scale": {
+    "dtype": "BF16",
+    "shape": [2560, 80],
+    "data_offsets": [3276800, 3686400]
+  },
+  "model.layers.0.mlp.up_proj.weight.bias": {
+    "dtype": "BF16",
+    "shape": [2560, 80],
+    "data_offsets": [3686400, 4096000]
+  }
+}
+```
+
+### Metadata Fields
+
+| Field | Description |
+|---|---|
+| `quant_type` | Quantization type: `int4`, `int8`, `nvfp4`, or `mxfp8` |
+| `group_size` | Number of elements per quantization group (e.g., `32`, `64`) |
+
+### Tensor Keys
+
+| Key | Description |
+|---|---|
+| `{name}` | Packed quantized weights (dtype `U32`) |
+| `{name}.scale` | Per-group scaling factors |
+| `{name}.bias` | Per-group zero-point offsets (affine modes only) |
+
+## Quantization Types
+
+| Type | Bits | Group Size | Mode | Has Bias |
+|---|---|---|---|---|
+| `int4` | 4 | 32 | affine | yes |
+| `int8` | 8 | 64 | affine | yes |
+| `nvfp4` | 4 | 16 | nvfp4 | no |
+| `mxfp8` | 8 | 32 | mxfp8 | no |
+
+**Affine modes** (`int4`, `int8`) use `scale` plus `bias` for dequantization: within each group, `w ≈ q * scale + bias`, where the bias tensor provides the zero-point offset.
+
+**Non-affine modes** (`nvfp4`, `mxfp8`) use only `scale`, with specialized E4M3 scale formats.
+
+### Packed Weight Shape
+
+Quantized weights are packed into `uint32` values:
+- **4-bit** (int4, nvfp4): 8 values per uint32, so `packed_cols = original_cols / 8`
+- **8-bit** (int8, mxfp8): 4 values per uint32, so `packed_cols = original_cols / 4`
+
+Scale shape: `[rows, original_cols / group_size]`
+
+## Manifest References
+
+Blobs are referenced from the model manifest as layers:
+
+```json
+{
+  "mediaType": "application/vnd.ollama.image.tensor",
+  "digest": "sha256:abc123...",
+  "size": 4096150,
+  "name": "model.layers.0.mlp.up_proj.weight"
+}
+```
+
+Each tensor (quantized or not) is one layer in the manifest. The layer name matches the tensor key in the blob header.
+
+## Packed Blobs (Expert Groups)
+
+For MoE (Mixture of Experts) models, expert tensors from the same layer are packed into a single blob to reduce blob count and improve loading efficiency. A packed blob is a standard safetensors file containing multiple tensor entries (`.bias` entries for affine-quantized tensors are elided here for brevity):
+
+```json
+{
+  "model.layers.1.mlp.experts.0.down_proj.weight": {
+    "dtype": "U32",
+    "shape": [2560, 640],
+    "data_offsets": [0, 6553600]
+  },
+  "model.layers.1.mlp.experts.0.down_proj.weight.scale": {
+    "dtype": "BF16",
+    "shape": [2560, 40],
+    "data_offsets": [6553600, 6758400]
+  },
+  "model.layers.1.mlp.experts.0.gate_proj.weight": {
+    "dtype": "U32",
+    "shape": [10240, 320],
+    "data_offsets": [6758400, 19865600]
+  },
+  "model.layers.1.mlp.experts.0.gate_proj.weight.scale": { "..." : "..." }
+}
+```
+
+### Grouping Rules
+
+- `model.layers.{L}.mlp.experts.*` tensors are packed into one blob per layer
+- `model.layers.{L}.mlp.shared_experts.*` tensors are packed into one blob per layer
+- All other tensors remain as individual blobs
+
+### Manifest Representation
+
+One manifest layer per packed group, using the group prefix as the layer name:
+
+```json
+{
+  "mediaType": "application/vnd.ollama.image.tensor",
+  "digest": "sha256:...",
+  "size": 123456789,
+  "name": "model.layers.1.mlp.experts"
+}
+```
+
+## Loading
+
+At load time, `mlx_load_safetensors` opens each blob via mmap for zero-copy access.
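+
+Any blob can be inspected without MLX by reading this header directly. The following is a minimal sketch using only the Go standard library (error handling elided for brevity; the `entry` type is an illustrative name, not part of the codebase):
+
+```go
+package main
+
+import (
+	"encoding/binary"
+	"encoding/json"
+	"fmt"
+	"io"
+	"os"
+)
+
+type entry struct {
+	Dtype       string   `json:"dtype"`
+	Shape       []int64  `json:"shape"`
+	DataOffsets [2]int64 `json:"data_offsets"`
+}
+
+func main() {
+	f, _ := os.Open(os.Args[1]) // path to a tensor blob
+	defer f.Close()
+
+	// 8-byte little-endian header size, then the JSON header itself
+	var headerSize uint64
+	binary.Read(f, binary.LittleEndian, &headerSize)
+	headerJSON := make([]byte, headerSize)
+	io.ReadFull(f, headerJSON)
+
+	var header map[string]json.RawMessage
+	json.Unmarshal(headerJSON, &header)
+
+	for key, raw := range header {
+		if key == "__metadata__" {
+			var meta map[string]string
+			json.Unmarshal(raw, &meta)
+			fmt.Println("metadata:", meta) // e.g. quant_type, group_size
+			continue
+		}
+		var e entry
+		json.Unmarshal(raw, &e)
+		fmt.Printf("%s dtype=%s shape=%v bytes=%d\n",
+			key, e.Dtype, e.Shape, e.DataOffsets[1]-e.DataOffsets[0])
+	}
+}
+```
+
+For a combined quantized blob this prints the quantization metadata followed by the weight, `.scale`, and (for affine modes) `.bias` entries; for a packed blob it prints every tensor in the group.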
+For combined quantized blobs, the loader extracts `{name}`, `{name}.scale`, and `{name}.bias` tensors and caches them as `name`, `name + "_scale"`, and `name + "_qbias"` respectively, maintaining compatibility with the weight loading interface.
+
+For packed blobs, if the manifest layer name (group prefix) is not found as a tensor key, the loader parses the blob header to discover all tensor names and loads each individually.
diff --git a/x/imagegen/manifest/manifest.go b/x/imagegen/manifest/manifest.go
index 8af0c4630..4de66644c 100644
--- a/x/imagegen/manifest/manifest.go
+++ b/x/imagegen/manifest/manifest.go
@@ -1,11 +1,13 @@
 package manifest
 
 import (
+	"encoding/binary"
 	"encoding/json"
 	"fmt"
 	"io"
 	"os"
 	"path/filepath"
+	"sort"
 	"strings"
 
 	"github.com/ollama/ollama/envconfig"
@@ -205,17 +207,12 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
 		}
 	}
 
-	// Fallback: detect quantization from tensor names if not in config
+	// Fallback: detect quantization from first tensor blob's __metadata__
 	if info.Quantization == "" {
-		for _, layer := range manifest.Manifest.Layers {
-			if strings.HasSuffix(layer.Name, ".weight_scale") {
-				info.Quantization = "Q8"
-				break
-			}
-		}
-		if info.Quantization == "" {
-			info.Quantization = "BF16"
-		}
+		info.Quantization = detectQuantizationFromBlobs(manifest)
+	}
+	if info.Quantization == "" {
+		info.Quantization = "BF16"
 	}
 
 	// Fallback: estimate parameter count if not in config
@@ -223,9 +220,7 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
 		var totalSize int64
 		for _, layer := range manifest.Manifest.Layers {
 			if layer.MediaType == "application/vnd.ollama.image.tensor" {
-				if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
-					totalSize += layer.Size
-				}
+				totalSize += layer.Size
 			}
 		}
 		// Assume BF16 (2 bytes/param) as rough estimate
@@ -234,3 +229,79 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
 
 	return info, nil
 }
+
+// detectQuantizationFromBlobs reads __metadata__ from the first tensor blob
+// to detect quantization type.
+func detectQuantizationFromBlobs(manifest *ModelManifest) string {
+	for _, layer := range manifest.Manifest.Layers {
+		if layer.MediaType != "application/vnd.ollama.image.tensor" {
+			continue
+		}
+		data, err := readBlobHeader(manifest.BlobPath(layer.Digest))
+		if err != nil {
+			continue
+		}
+		var header map[string]json.RawMessage
+		if json.Unmarshal(data, &header) != nil {
+			continue
+		}
+		if metaRaw, ok := header["__metadata__"]; ok {
+			var meta map[string]string
+			if json.Unmarshal(metaRaw, &meta) == nil {
+				if qt, ok := meta["quant_type"]; ok && qt != "" {
+					return strings.ToUpper(qt)
+				}
+			}
+		}
+		// Only check the first tensor blob
+		break
+	}
+	return ""
+}
+
+// ParseBlobTensorNames reads a safetensors blob and returns all "main" tensor names.
+// Filters out __metadata__, .scale, and .bias entries to return only primary weight tensors.
+func ParseBlobTensorNames(path string) ([]string, error) {
+	data, err := readBlobHeader(path)
+	if err != nil {
+		return nil, err
+	}
+
+	var header map[string]json.RawMessage
+	if err := json.Unmarshal(data, &header); err != nil {
+		return nil, err
+	}
+
+	var names []string
+	for k := range header {
+		if k == "__metadata__" || strings.HasSuffix(k, ".scale") || strings.HasSuffix(k, ".bias") {
+			continue
+		}
+		names = append(names, k)
+	}
+
+	sort.Strings(names)
+	return names, nil
+}
+
+// readBlobHeader reads the JSON header bytes from a safetensors blob file.
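+// The 8-byte little-endian size prefix is capped at 1 MiB to avoid unbounded
+// allocations when a blob is corrupt.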
+func readBlobHeader(path string) ([]byte, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var headerSize uint64 + if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil { + return nil, err + } + if headerSize > 1024*1024 { + return nil, fmt.Errorf("header too large: %d", headerSize) + } + data := make([]byte, headerSize) + if _, err := io.ReadFull(f, data); err != nil { + return nil, err + } + return data, nil +} diff --git a/x/imagegen/manifest/weights.go b/x/imagegen/manifest/weights.go index 63f7d5604..19a6ede07 100644 --- a/x/imagegen/manifest/weights.go +++ b/x/imagegen/manifest/weights.go @@ -5,6 +5,7 @@ package manifest import ( "fmt" "sort" + "strconv" "strings" "github.com/ollama/ollama/x/imagegen/mlx" @@ -18,6 +19,8 @@ type ManifestWeights struct { tensors map[string]ManifestLayer // name -> layer cache map[string]*mlx.Array // name -> loaded array nativeCache []*mlx.SafetensorsFile // keep native handles alive + quantType string // quantization type from blob metadata (e.g., "int4", "int8") + groupSize int // quantization group size from blob metadata } // LoadWeightsFromManifest creates a weight loader from manifest storage. @@ -54,43 +57,115 @@ func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*Manife // Load loads all tensor blobs using native mmap (zero-copy). // Blobs are stored in safetensors format for native mlx_load_safetensors mmap. -// If dtype is non-zero, tensors are converted to the specified dtype. +// Combined quantized blobs contain tensors keyed by name, name+".scale", and optional name+".bias" +// with quantization metadata. Scale and bias are stored in cache as name+"_scale" +// and name+"_qbias" for compatibility with downstream loading code. +// Packed blobs (e.g., for expert groups) contain multiple tensors; the manifest name +// is a group prefix and individual tensors are loaded by their actual names from the blob. +// If dtype is non-zero, non-quantized tensors are converted to the specified dtype. 
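+// Tensors that share a blob digest (packed expert groups) are grouped so each
+// blob is mmapped and parsed only once.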
func (mw *ManifestWeights) Load(dtype mlx.Dtype) error { // Track native handles to free after batch eval nativeHandles := make([]*mlx.SafetensorsFile, 0, len(mw.tensors)) arrays := make([]*mlx.Array, 0, len(mw.tensors)) + // Group tensors by digest to avoid loading the same blob multiple times + type blobEntry struct { + name string + layer ManifestLayer + } + blobGroups := make(map[string][]blobEntry) for name, layer := range mw.tensors { - path := mw.manifest.BlobPath(layer.Digest) + blobGroups[layer.Digest] = append(blobGroups[layer.Digest], blobEntry{name, layer}) + } + + for digest, entries := range blobGroups { + path := mw.manifest.BlobPath(digest) // Load blob as safetensors (native mmap, zero-copy) sf, err := mlx.LoadSafetensorsNative(path) if err != nil { - // Free any handles we've accumulated for _, h := range nativeHandles { h.Free() } - return fmt.Errorf("load %s: %w", name, err) + return fmt.Errorf("load %s: %w", entries[0].name, err) } nativeHandles = append(nativeHandles, sf) - // Blob contains single tensor named "data" - arr := sf.Get("data") - if arr == nil { - for _, h := range nativeHandles { - h.Free() + // Read quantization metadata from blob + if qt := sf.GetMetadata("quant_type"); qt != "" && mw.quantType == "" { + mw.quantType = qt + if gs := sf.GetMetadata("group_size"); gs != "" { + mw.groupSize, _ = strconv.Atoi(gs) } - return fmt.Errorf("tensor 'data' not found in blob for %s", name) } - // Convert dtype if needed - if dtype != 0 && arr.Dtype() != dtype { - arr = mlx.AsType(arr, dtype) + for _, entry := range entries { + name := entry.name + + // Try to get tensor by manifest name + arr := sf.Get(name) + if arr != nil { + // Single-tensor blob or tensor found by name + if dtype != 0 && arr.Dtype() != dtype { + arr = mlx.AsType(arr, dtype) + } + arr = mlx.Contiguous(arr) + mw.cache[name] = arr + arrays = append(arrays, arr) + + // Check for scale tensor + if scale := sf.Get(name + ".scale"); scale != nil { + scale = mlx.Contiguous(scale) + mw.cache[name+"_scale"] = scale + arrays = append(arrays, scale) + } + + // Check for bias tensor + if bias := sf.Get(name + ".bias"); bias != nil { + bias = mlx.Contiguous(bias) + mw.cache[name+"_qbias"] = bias + arrays = append(arrays, bias) + } + } else { + // Packed blob: manifest name is a group prefix, not a tensor name. + // Load all individual tensors from the blob. 
+ tensorNames, err := ParseBlobTensorNames(path) + if err != nil { + for _, h := range nativeHandles { + h.Free() + } + return fmt.Errorf("parse packed blob for %s: %w", name, err) + } + + for _, tensorName := range tensorNames { + tArr := sf.Get(tensorName) + if tArr == nil { + continue + } + + if dtype != 0 && tArr.Dtype() != dtype { + tArr = mlx.AsType(tArr, dtype) + } + tArr = mlx.Contiguous(tArr) + mw.cache[tensorName] = tArr + arrays = append(arrays, tArr) + + // Check for scale tensor + if scale := sf.Get(tensorName + ".scale"); scale != nil { + scale = mlx.Contiguous(scale) + mw.cache[tensorName+"_scale"] = scale + arrays = append(arrays, scale) + } + + // Check for bias tensor + if bias := sf.Get(tensorName + ".bias"); bias != nil { + bias = mlx.Contiguous(bias) + mw.cache[tensorName+"_qbias"] = bias + arrays = append(arrays, bias) + } + } + } } - // Make contiguous copy to ensure independence from mmap - arr = mlx.Contiguous(arr) - mw.cache[name] = arr - arrays = append(arrays, arr) } // Batch evaluate all tensors at once (much faster than one at a time) @@ -117,30 +192,50 @@ func (mw *ManifestWeights) GetTensor(name string) (*mlx.Array, error) { } // ListTensors returns all tensor names in sorted order. +// Includes both manifest tensor names and scale/bias entries from combined blobs. func (mw *ManifestWeights) ListTensors() []string { - names := make([]string, 0, len(mw.tensors)) + seen := make(map[string]bool, len(mw.tensors)+len(mw.cache)) for name := range mw.tensors { + seen[name] = true + } + // Also include cache entries (scale/bias from combined blobs) + for name := range mw.cache { + seen[name] = true + } + names := make([]string, 0, len(seen)) + for name := range seen { names = append(names, name) } sort.Strings(names) return names } -// HasTensor checks if a tensor exists. +// HasTensor checks if a tensor exists in the manifest or cache. func (mw *ManifestWeights) HasTensor(name string) bool { - _, ok := mw.tensors[name] - return ok + if _, ok := mw.tensors[name]; ok { + return true + } + // Also check cache for scale/bias entries from combined blobs + if _, ok := mw.cache[name]; ok { + return true + } + return false } -// Quantization returns the model's quantization type from model_index.json. +// Quantization returns the model's quantization type. +// Returns the quant_type from blob metadata (e.g., "int4", "int8", "nvfp4", "mxfp8"). // Returns empty string if not quantized. -// Falls back to detecting from tensor names and shapes if not in config. +// Falls back to model_index.json for image gen models. 
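+// The result is upper-cased for display, e.g. "int4" is reported as "INT4".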
func (mw *ManifestWeights) Quantization() string { + if mw.quantType != "" { + return strings.ToUpper(mw.quantType) + } + if mw.manifest == nil { return "" } - // Try to read from model_index.json first + // Fallback: read from model_index.json (for image gen models) var index struct { Quantization string `json:"quantization"` } @@ -148,89 +243,22 @@ func (mw *ManifestWeights) Quantization() string { return index.Quantization } - // Fallback: detect from tensor names - // Check if any tensors have _scale suffix (indicates quantization) - hasScales := false - hasQBias := false - for name := range mw.tensors { - if strings.HasSuffix(name, ".weight_scale") { - hasScales = true - } - if strings.HasSuffix(name, ".weight_qbias") { - hasQBias = true - } - } - - if !hasScales { - // No scales = not quantized - return "" - } - - // Has scales but no qbias = NVFP4 (or other non-affine mode) - if !hasQBias { - return "NVFP4" - } - - // Has both scales and qbias = affine mode - // Need to determine FP4 vs FP8 from tensor shapes - // FP4: weight last dim is 1/8 of scales last dim * group_size - // FP8: weight last dim is 1/4 of scales last dim * group_size - // - // For affine mode with group_size=32: - // - FP4 (4 bits): 8 elements packed per uint32, so weight_dim = orig_dim / 8 - // - FP8 (8 bits): 4 elements packed per uint32, so weight_dim = orig_dim / 4 - // scales_dim = orig_dim / group_size - // So: weight_dim / scales_dim = group_size / pack_factor - // FP4: ratio = 32/8 = 4 - // FP8: ratio = 32/4 = 8 - - // Find a weight/scale pair to check the ratio - for name := range mw.tensors { - if !strings.HasSuffix(name, ".weight") || strings.Contains(name, "_scale") || strings.Contains(name, "_qbias") { - continue - } - scaleName := name + "_scale" - if _, ok := mw.tensors[scaleName]; !ok { - continue - } - - // Load both tensors to check shapes - weightLayer := mw.tensors[name] - scaleLayer := mw.tensors[scaleName] - - // Get shapes from manifest layer metadata if available - // For now, default to FP4 since it's more common - // The actual shape check would require loading the tensor - - // Simple heuristic: check if scale tensor is ~4x smaller than weight - // FP4: weight is packed 8 per uint32, scales are 1 per group (32) - // So scale size should be ~weight_size * 8 / 32 = weight_size / 4 - // FP8: weight is packed 4 per uint32, scales are 1 per group (32) - // So scale size should be ~weight_size * 4 / 32 = weight_size / 8 - - // Rough size heuristic (assuming float16 scales) - // Q4: scale_bytes ≈ weight_bytes / 4 * 2 / 4 = weight_bytes / 8 - // Q8: scale_bytes ≈ weight_bytes / 8 * 2 / 4 = weight_bytes / 16 - ratio := float64(weightLayer.Size) / float64(scaleLayer.Size) - if ratio < 12 { - // Closer to 8 = Q4 - return "Q4" - } - // Closer to 16 = Q8 - return "Q8" - } - - // Default to Q4 for affine mode (most common) - return "Q4" + return "" } -// GroupSize returns the quantization group size from model_index.json. +// GroupSize returns the quantization group size. +// Returns the group_size from blob metadata. // Returns 0 if not specified (caller should use default based on quantization type). 
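+// Typical values: 16 (nvfp4), 32 (int4, mxfp8), 64 (int8).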
func (mw *ManifestWeights) GroupSize() int { + if mw.groupSize > 0 { + return mw.groupSize + } + if mw.manifest == nil { return 0 } + // Fallback: read from model_index.json (for image gen models) var index struct { GroupSize int `json:"group_size"` } diff --git a/x/imagegen/mlx/mlx.go b/x/imagegen/mlx/mlx.go index 2232e482b..cf3e51572 100644 --- a/x/imagegen/mlx/mlx.go +++ b/x/imagegen/mlx/mlx.go @@ -1544,6 +1544,18 @@ func (s *SafetensorsFile) Count() int { return 0 } +// GetMetadata retrieves a metadata value by key from the safetensors file +func (s *SafetensorsFile) GetMetadata(key string) string { + cKey := C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + + var cValue *C.char + if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 { + return "" + } + return C.GoString(cValue) +} + // Free releases the safetensors file func (s *SafetensorsFile) Free() { C.mlx_map_string_to_array_free(s.arrays) @@ -1578,6 +1590,41 @@ func SaveSafetensors(path string, arrays map[string]*Array) error { return nil } +// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata key/value pairs. +// This is like SaveSafetensors but inserts metadata into the __metadata__ section. +func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error { + cPath := C.CString(path) + defer C.free(unsafe.Pointer(cPath)) + + // Create the array map + cArrays := C.mlx_map_string_to_array_new() + defer C.mlx_map_string_to_array_free(cArrays) + + for name, arr := range arrays { + cName := C.CString(name) + C.mlx_map_string_to_array_insert(cArrays, cName, arr.c) + C.free(unsafe.Pointer(cName)) + } + + // Create metadata map + cMeta := C.mlx_map_string_to_string_new() + defer C.mlx_map_string_to_string_free(cMeta) + + for key, value := range metadata { + cKey := C.CString(key) + cValue := C.CString(value) + C.mlx_map_string_to_string_insert(cMeta, cKey, cValue) + C.free(unsafe.Pointer(cKey)) + C.free(unsafe.Pointer(cValue)) + } + + // Save + if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 { + return fmt.Errorf("failed to save safetensors: %s", path) + } + return nil +} + // ============ NPY Loading ============ // LoadNpy loads a numpy array from an npy file diff --git a/x/imagegen/safetensors/extractor.go b/x/imagegen/safetensors/extractor.go index b14dfe965..65a4f6da0 100644 --- a/x/imagegen/safetensors/extractor.go +++ b/x/imagegen/safetensors/extractor.go @@ -41,13 +41,11 @@ func (td *TensorData) Reader() io.Reader { return td.reader } -// SafetensorsReader returns a reader that outputs the tensor wrapped in -// minimal safetensors format. This allows using mlx_load_safetensors on -// individual tensor blobs for native zero-copy loading. -func (td *TensorData) SafetensorsReader() io.Reader { - // Build minimal safetensors header with tensor named "data" - header := map[string]tensorInfo{ - "data": { +// safetensorsHeader builds the JSON header for a minimal safetensors blob +// containing a single tensor keyed by its name. +func (td *TensorData) safetensorsHeader() []byte { + header := map[string]any{ + td.Name: tensorInfo{ Dtype: td.Dtype, Shape: td.Shape, DataOffsets: [2]int{0, int(td.Size)}, @@ -58,6 +56,15 @@ func (td *TensorData) SafetensorsReader() io.Reader { // Pad header to 8-byte alignment padding := (8 - len(headerJSON)%8) % 8 headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...) + return headerJSON +} + +// SafetensorsReader returns a reader that outputs the tensor wrapped in +// minimal safetensors format. 
This allows using mlx_load_safetensors on +// individual tensor blobs for native zero-copy loading. +// The tensor is keyed by its name in the safetensors header. +func (td *TensorData) SafetensorsReader() io.Reader { + headerJSON := td.safetensorsHeader() // Build header with size prefix headerBuf := new(bytes.Buffer) @@ -71,16 +78,77 @@ func (td *TensorData) SafetensorsReader() io.Reader { // SafetensorsSize returns the total size of the safetensors-wrapped tensor. func (td *TensorData) SafetensorsSize() int64 { - header := map[string]tensorInfo{ - "data": { + headerJSON := td.safetensorsHeader() + return 8 + int64(len(headerJSON)) + td.Size +} + +// NewTensorDataFromBytes creates a TensorData from raw tensor bytes. +// This is useful for constructing packed blobs from already-extracted data. +func NewTensorDataFromBytes(name, dtype string, shape []int32, rawData []byte) *TensorData { + return &TensorData{ + Name: name, + Dtype: dtype, + Shape: shape, + Size: int64(len(rawData)), + reader: io.NewSectionReader(bytes.NewReader(rawData), 0, int64(len(rawData))), + } +} + +// ExtractRawFromSafetensors reads a safetensors-wrapped reader and extracts +// the raw tensor data bytes (stripping the header). +func ExtractRawFromSafetensors(r io.Reader) ([]byte, error) { + // Read header size (8 bytes, little endian) + var headerSize uint64 + if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil { + return nil, fmt.Errorf("failed to read header size: %w", err) + } + + // Skip header + if _, err := io.CopyN(io.Discard, r, int64(headerSize)); err != nil { + return nil, fmt.Errorf("failed to skip header: %w", err) + } + + // Read remaining bytes (the raw tensor data) + return io.ReadAll(r) +} + +// BuildPackedSafetensorsReader builds a streaming io.Reader that outputs a valid +// safetensors file containing multiple tensors. Used for packing expert tensors +// into a single blob without loading all data into memory. +// Each TensorData must have been obtained from GetTensor. +func BuildPackedSafetensorsReader(tensors []*TensorData) io.Reader { + // Build the header with sequential data offsets + header := make(map[string]tensorInfo, len(tensors)) + var offset int + for _, td := range tensors { + header[td.Name] = tensorInfo{ Dtype: td.Dtype, Shape: td.Shape, - DataOffsets: [2]int{0, int(td.Size)}, - }, + DataOffsets: [2]int{offset, offset + int(td.Size)}, + } + offset += int(td.Size) } + headerJSON, _ := json.Marshal(header) + + // Pad header to 8-byte alignment padding := (8 - len(headerJSON)%8) % 8 - return 8 + int64(len(headerJSON)) + int64(padding) + td.Size + headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...) + + // Build header with size prefix + headerBuf := new(bytes.Buffer) + binary.Write(headerBuf, binary.LittleEndian, uint64(len(headerJSON))) + headerBuf.Write(headerJSON) + + // Build multi-reader: header + all tensor data readers + readers := make([]io.Reader, 0, 1+len(tensors)) + readers = append(readers, headerBuf) + for _, td := range tensors { + td.reader.Seek(0, io.SeekStart) + readers = append(readers, td.reader) + } + + return io.MultiReader(readers...) } // OpenForExtraction opens a safetensors file for tensor extraction. 
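For reference, the blob layout these extractor helpers read and write is the standard safetensors container: an 8-byte little-endian header length, a JSON header padded with spaces to 8-byte alignment, then the concatenated raw tensor bytes. The sketch below is not part of the patch (the tensor name and quant values are illustrative); it builds and re-parses such a blob the same way SafetensorsReader and ExtractRawFromSafetensors do:

```go
package main

import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
)

// tensorInfo mirrors the header entry shape used by the extractor.
type tensorInfo struct {
	Dtype       string  `json:"dtype"`
	Shape       []int64 `json:"shape"`
	DataOffsets [2]int  `json:"data_offsets"`
}

func main() {
	raw := make([]byte, 16) // stand-in payload: a 2x2 U32 tensor (4 bytes each)

	// Header keyed by the tensor's full name; quant info goes in __metadata__.
	header := map[string]any{
		"__metadata__": map[string]string{"quant_type": "int4", "group_size": "32"},
		"example.weight": tensorInfo{ // hypothetical tensor name
			Dtype:       "U32",
			Shape:       []int64{2, 2},
			DataOffsets: [2]int{0, len(raw)},
		},
	}
	headerJSON, err := json.Marshal(header)
	if err != nil {
		panic(err)
	}
	// Pad the header to 8-byte alignment, as safetensorsHeader does.
	padding := (8 - len(headerJSON)%8) % 8
	headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)

	// Blob = uint64 header length (little endian) + header + raw tensor data.
	var blob bytes.Buffer
	binary.Write(&blob, binary.LittleEndian, uint64(len(headerJSON)))
	blob.Write(headerJSON)
	blob.Write(raw)

	// Reading it back follows ExtractRawFromSafetensors: length, header, data.
	r := bytes.NewReader(blob.Bytes())
	var n uint64
	binary.Read(r, binary.LittleEndian, &n)
	hdr := make([]byte, n)
	io.ReadFull(r, hdr)
	data, _ := io.ReadAll(r)
	fmt.Printf("header (%d bytes): %s\npayload: %d bytes\n", n, hdr, len(data))
}
```

This also makes the SafetensorsSize accounting concrete: the total is 8 bytes for the length prefix, plus the padded header, plus the raw data. Since the header now carries the tensor's full name and optional __metadata__, per-tensor overhead grows with the name length, which is why the estimate in x/server/show.go moves from a fixed 88 bytes to ~150.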
diff --git a/x/imagegen/safetensors/loader.go b/x/imagegen/safetensors/loader.go index 4c1d0a9af..d0426a2ef 100644 --- a/x/imagegen/safetensors/loader.go +++ b/x/imagegen/safetensors/loader.go @@ -17,7 +17,7 @@ type WeightSource interface { GetTensor(name string) (*mlx.Array, error) ListTensors() []string HasTensor(name string) bool - Quantization() string // Returns "NVFP4", "Q4", "Q8", or "" + Quantization() string // Returns "NVFP4", "INT4", "INT8", or "" GroupSize() int // Returns quantization group size, or 0 if not specified } diff --git a/x/server/show.go b/x/server/show.go index 652293e77..ec6df2d3d 100644 --- a/x/server/show.go +++ b/x/server/show.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "os" + "sort" "strings" "github.com/ollama/ollama/api" @@ -105,9 +106,9 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map bytesPerParam = 1 } - // Subtract safetensors header overhead (88 bytes per tensor file) - // Each tensor is stored as a minimal safetensors file - totalBytes := totalTensorBytes - tensorCount*88 + // Subtract safetensors header overhead per tensor blob. + // Headers include __metadata__ with the tensor name, so overhead is ~150 bytes on average. + totalBytes := totalTensorBytes - tensorCount*150 paramCount := totalBytes / bytesPerParam @@ -163,24 +164,103 @@ func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) { // getTensorInfoFromManifest extracts tensor info from a manifest. // This is separated for testability. -// For quantized models, groups weight/scale/qbias into single entries with detected quantization type. +// For quantized tensors, reads quant_type from blob __metadata__. +// For packed blobs (multiple tensors per blob), enumerates all tensors in the blob. func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) { var tensors []api.Tensor - // First pass: collect all tensor info and identify scale tensors - type tensorData struct { - info *safetensorsTensorInfo - digest string - } - tensorMap := make(map[string]*tensorData) - scaleMap := make(map[string]*tensorData) // base name -> scale tensor info - for _, layer := range mf.Layers { if layer.MediaType != manifest.MediaTypeImageTensor { continue } - // Read the safetensors header from the blob + // Read all tensor entries from the safetensors header + blobPath, err := manifest.BlobsPath(layer.Digest) + if err != nil { + continue + } + + f, err := os.Open(blobPath) + if err != nil { + continue + } + + allInfos, err := parseSafetensorsAllHeaders(f) + f.Close() + if err != nil { + continue + } + + // Determine if this is a packed blob (multiple main tensors) + isPacked := len(allInfos) > 1 + + for _, info := range allInfos { + tensorName := layer.Name + if isPacked { + // For packed blobs, use the tensor name from the header + tensorName = info.Name + } + + if info.QuantType != "" { + quantType := strings.ToUpper(info.QuantType) + + shape := make([]uint64, len(info.Shape)) + for i, s := range info.Shape { + shape[i] = uint64(s) + } + + var packFactor int64 + switch strings.ToLower(info.QuantType) { + case "int4", "nvfp4": + packFactor = 8 + case "int8", "mxfp8": + packFactor = 4 + } + if packFactor > 0 && len(shape) >= 2 { + shape[len(shape)-1] = uint64(info.Shape[len(info.Shape)-1] * packFactor) + } + + tensors = append(tensors, api.Tensor{ + Name: tensorName, + Type: quantType, + Shape: shape, + }) + } else { + shape := make([]uint64, len(info.Shape)) + for i, s := range info.Shape { + shape[i] = uint64(s) + } + + tensors = append(tensors, 
api.Tensor{ + Name: tensorName, + Type: info.Dtype, + Shape: shape, + }) + } + } + } + + sort.Slice(tensors, func(i, j int) bool { + return tensors[i].Name < tensors[j].Name + }) + + return tensors, nil +} + +// GetSafetensorsDtype returns the quantization type for a safetensors model. +// Reads quant_type from the first tensor blob's __metadata__. +// Falls back to torch_dtype from config.json if no quant metadata. +func GetSafetensorsDtype(name model.Name) (string, error) { + mf, err := manifest.ParseNamedManifest(name) + if err != nil { + return "", fmt.Errorf("failed to load manifest: %w", err) + } + + // Check first tensor blob for quant_type metadata + for _, layer := range mf.Layers { + if layer.MediaType != manifest.MediaTypeImageTensor { + continue + } blobPath, err := manifest.BlobsPath(layer.Digest) if err != nil { continue @@ -189,131 +269,11 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) { if err != nil { continue } - - td := &tensorData{info: info, digest: layer.Digest} - - if strings.HasSuffix(layer.Name, "_scale") { - baseName := strings.TrimSuffix(layer.Name, "_scale") - scaleMap[baseName] = td - } else if strings.HasSuffix(layer.Name, "_qbias") { - // Skip qbias tensors - they're included with the quantized weight - continue - } else { - tensorMap[layer.Name] = td + if info.QuantType != "" { + return strings.ToUpper(info.QuantType), nil } - } - - // Second pass: build tensor list with quantization info - for _, layer := range mf.Layers { - if layer.MediaType != manifest.MediaTypeImageTensor { - continue - } - - // Skip scale and qbias tensors - if strings.HasSuffix(layer.Name, "_scale") || strings.HasSuffix(layer.Name, "_qbias") { - continue - } - - td := tensorMap[layer.Name] - if td == nil { - continue - } - - // Check if this tensor has a corresponding scale tensor (quantized) - scaleTd := scaleMap[layer.Name] - if scaleTd != nil && len(td.info.Shape) >= 2 && len(scaleTd.info.Shape) >= 2 { - // Quantized tensor - detect bits from shapes - weightCols := td.info.Shape[len(td.info.Shape)-1] - scaleCols := scaleTd.info.Shape[len(scaleTd.info.Shape)-1] - - // Detect quantization: Q4 has pack_factor=8, Q8 has pack_factor=4 - // Q4 uses group_size=32: weightCols * 8 / scaleCols = 32 - // Q8 uses group_size=64: weightCols * 4 / scaleCols = 64 - var bits int - var quantType string - if weightCols*8/scaleCols == 32 { - bits = 4 - quantType = "Q4" - } else if weightCols*4/scaleCols == 64 { - bits = 8 - quantType = "Q8" - } else { - // Unknown quantization, show raw - quantType = td.info.Dtype - } - - // Calculate unpacked shape - shape := make([]uint64, len(td.info.Shape)) - for i, s := range td.info.Shape { - shape[i] = uint64(s) - } - if bits > 0 { - packFactor := int64(32 / bits) - shape[len(shape)-1] = uint64(td.info.Shape[len(td.info.Shape)-1] * packFactor) - } - - tensors = append(tensors, api.Tensor{ - Name: layer.Name, - Type: quantType, - Shape: shape, - }) - } else { - // Non-quantized tensor - shape := make([]uint64, len(td.info.Shape)) - for i, s := range td.info.Shape { - shape[i] = uint64(s) - } - - tensors = append(tensors, api.Tensor{ - Name: layer.Name, - Type: td.info.Dtype, - Shape: shape, - }) - } - } - - return tensors, nil -} - -// GetSafetensorsDtype returns the quantization type for a safetensors model. -// Reads from model_index.json first, falls back to detection from tensor names. -// Otherwise returns the torch_dtype from config.json. 
-func GetSafetensorsDtype(name model.Name) (string, error) { - mf, err := manifest.ParseNamedManifest(name) - if err != nil { - return "", fmt.Errorf("failed to load manifest: %w", err) - } - - // First try to read quantization from model_index.json - var modelIndex struct { - Quantization string `json:"quantization"` - } - if err := mf.ReadConfigJSON("model_index.json", &modelIndex); err == nil && modelIndex.Quantization != "" { - return modelIndex.Quantization, nil - } - - // Fallback: detect from tensor names - hasScales := false - hasQBias := false - for _, layer := range mf.Layers { - if layer.MediaType == manifest.MediaTypeImageTensor { - if strings.HasSuffix(layer.Name, "_scale") { - hasScales = true - } - if strings.HasSuffix(layer.Name, "_qbias") { - hasQBias = true - } - } - } - - if hasScales { - if hasQBias { - // Affine mode (has scale + qbias) - could be Q4 or Q8 - // Default to Q4 as it's more common - return "Q4", nil - } - // No qbias = NVFP4 - return "NVFP4", nil + // Only check the first tensor blob + break } // Not quantized - return torch_dtype from config.json @@ -329,8 +289,11 @@ func GetSafetensorsDtype(name model.Name) (string, error) { // safetensorsTensorInfo holds metadata about a tensor from a safetensors header type safetensorsTensorInfo struct { - Dtype string `json:"dtype"` - Shape []int64 `json:"shape"` + Name string // tensor name from the header key + Dtype string `json:"dtype"` + Shape []int64 `json:"shape"` + QuantType string // from __metadata__.quant_type (e.g., "int4", "int8", "nvfp4", "mxfp8") + GroupSize string // from __metadata__.group_size (e.g., "32", "64") } // readSafetensorsHeader reads the JSON header from a safetensors file to get tensor metadata. @@ -347,6 +310,7 @@ func readSafetensorsHeader(path string) (*safetensorsTensorInfo, error) { // parseSafetensorsHeader parses a safetensors header from a reader. // This is separated for testability. +// Parses __metadata__ for quant_type and group_size if present. 
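+// For example, a quantized single-tensor blob header looks like (offsets
+// elided; shapes illustrative, taken from the tests in this change):
+//
+//	{"__metadata__": {"quant_type": "int4", "group_size": "32"},
+//	 "model.layers.0.mlp.up_proj.weight":       {"dtype": "U32",  "shape": [2560, 320]},
+//	 "model.layers.0.mlp.up_proj.weight.scale": {"dtype": "BF16", "shape": [2560, 80]},
+//	 "model.layers.0.mlp.up_proj.weight.bias":  {"dtype": "BF16", "shape": [2560, 80]}}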
func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) { // Read header size (8 bytes, little endian) var headerSize uint64 @@ -371,7 +335,31 @@ func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) { return nil, fmt.Errorf("failed to parse header: %w", err) } - // Find the first (and should be only) tensor entry + // Parse metadata if present + var quantType, groupSize string + if metaRaw, ok := header["__metadata__"]; ok { + var meta map[string]string + if json.Unmarshal(metaRaw, &meta) == nil { + quantType = meta["quant_type"] + groupSize = meta["group_size"] + } + } + + // Find the main tensor entry (not __metadata__, .scale, or .bias) + for name, raw := range header { + if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") { + continue + } + var info safetensorsTensorInfo + if err := json.Unmarshal(raw, &info); err != nil { + return nil, fmt.Errorf("failed to parse tensor info: %w", err) + } + info.QuantType = quantType + info.GroupSize = groupSize + return &info, nil + } + + // Fall back to first non-metadata tensor entry for name, raw := range header { if name == "__metadata__" { continue @@ -380,8 +368,134 @@ func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) { if err := json.Unmarshal(raw, &info); err != nil { return nil, fmt.Errorf("failed to parse tensor info: %w", err) } + info.QuantType = quantType + info.GroupSize = groupSize return &info, nil } return nil, fmt.Errorf("no tensor found in header") } + +// parseSafetensorsAllHeaders parses all tensor entries from a safetensors header. +// Returns one safetensorsTensorInfo per main tensor (skipping __metadata__, .scale, .bias). +// For packed blobs this returns multiple entries; for single-tensor blobs, one entry. +// Each tensor's quant type is inferred from its shape and the presence of .scale/.bias entries +// when no global __metadata__ quant_type is present. 
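+// Worked example: with group_size=32, an int4 weight of original shape
+// [2560, 2560] is stored as U32 [2560, 320] with a [2560, 80] scale, so
+// packed_cols/scale_cols = 320/80 = 4; with group_size=64, int8 gives
+// 640/40 = 16 for the same original shape (see inferQuantType below).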
+func parseSafetensorsAllHeaders(r io.Reader) ([]safetensorsTensorInfo, error) { + var headerSize uint64 + if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil { + return nil, fmt.Errorf("failed to read header size: %w", err) + } + + if headerSize > 100*1024*1024 { // 100MB limit for packed blob headers + return nil, fmt.Errorf("header size too large: %d", headerSize) + } + + headerBytes := make([]byte, headerSize) + if _, err := io.ReadFull(r, headerBytes); err != nil { + return nil, fmt.Errorf("failed to read header: %w", err) + } + + var header map[string]json.RawMessage + if err := json.Unmarshal(headerBytes, &header); err != nil { + return nil, fmt.Errorf("failed to parse header: %w", err) + } + + // Parse global metadata if present + var globalQuantType, globalGroupSize string + if metaRaw, ok := header["__metadata__"]; ok { + var meta map[string]string + if json.Unmarshal(metaRaw, &meta) == nil { + globalQuantType = meta["quant_type"] + globalGroupSize = meta["group_size"] + } + } + + // Build a set of all keys for checking .scale/.bias presence + headerKeys := make(map[string]bool, len(header)) + for k := range header { + headerKeys[k] = true + } + + // Collect all main tensor entries (sorted for deterministic output) + var mainNames []string + for name := range header { + if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") { + continue + } + mainNames = append(mainNames, name) + } + sort.Strings(mainNames) + + var results []safetensorsTensorInfo + for _, name := range mainNames { + var info safetensorsTensorInfo + if err := json.Unmarshal(header[name], &info); err != nil { + return nil, fmt.Errorf("failed to parse tensor info for %s: %w", name, err) + } + info.Name = name + + if globalQuantType != "" { + // Use global metadata + info.QuantType = globalQuantType + info.GroupSize = globalGroupSize + } else if headerKeys[name+".scale"] { + // No global metadata, but has .scale - infer quant type from shape + info.QuantType = inferQuantType(header, name) + } + + results = append(results, info) + } + + if len(results) == 0 { + return nil, fmt.Errorf("no tensor found in header") + } + + return results, nil +} + +// inferQuantType infers the quantization type for a tensor from its shape and scale shape. +// Returns "int4", "int8", etc. or "" if not quantized. +func inferQuantType(header map[string]json.RawMessage, name string) string { + // Parse the main tensor shape + var mainInfo struct { + Shape []int64 `json:"shape"` + } + if json.Unmarshal(header[name], &mainInfo) != nil || len(mainInfo.Shape) < 2 { + return "" + } + + // Parse scale shape to determine group size + scaleRaw, ok := header[name+".scale"] + if !ok { + return "" + } + var scaleInfo struct { + Shape []int64 `json:"shape"` + } + if json.Unmarshal(scaleRaw, &scaleInfo) != nil || len(scaleInfo.Shape) < 2 { + return "" + } + + // Calculate group size: main_cols * pack_factor / scale_cols + // Main dtype is U32, so we need to figure out the pack factor + // For int4: pack=8, group=32. scale_cols = original_cols / 32 = main_cols * 8 / 32 = main_cols / 4 + // For int8: pack=4, group=64. 
scale_cols = original_cols / 64 = main_cols * 4 / 64 = main_cols / 16 + mainCols := mainInfo.Shape[len(mainInfo.Shape)-1] + scaleCols := scaleInfo.Shape[len(scaleInfo.Shape)-1] + if scaleCols == 0 { + return "" + } + + ratio := mainCols / scaleCols // main_packed_cols / scale_cols + // int4: ratio = (orig/8) / (orig/32) = 32/8 = 4 + // int8: ratio = (orig/4) / (orig/64) = 64/4 = 16 + switch ratio { + case 4: + return "int4" + case 16: + return "int8" + default: + return "" + } +} diff --git a/x/server/show_test.go b/x/server/show_test.go index be57758f8..5e8ba62fa 100644 --- a/x/server/show_test.go +++ b/x/server/show_test.go @@ -36,7 +36,7 @@ func TestBuildModelInfo(t *testing.T) { VocabSize: 262144, TorchDtype: "bfloat16", }, - totalTensorBytes: 8_600_000_088, // ~4.3B params * 2 bytes + 88 bytes header + totalTensorBytes: 8_600_000_150, // ~4.3B params * 2 bytes + 150 bytes header tensorCount: 1, wantArch: "gemma3", wantContextLen: 131072, @@ -57,7 +57,7 @@ func TestBuildModelInfo(t *testing.T) { VocabSize: 32000, TorchDtype: "float16", }, - totalTensorBytes: 14_000_000_088, // ~7B params * 2 bytes + 88 bytes header + totalTensorBytes: 14_000_000_150, // ~7B params * 2 bytes + 150 bytes header tensorCount: 1, wantArch: "llama", wantContextLen: 4096, @@ -84,7 +84,7 @@ func TestBuildModelInfo(t *testing.T) { VocabSize: 262144, TorchDtype: "bfloat16", }, - totalTensorBytes: 8_600_000_088, + totalTensorBytes: 8_600_000_150, tensorCount: 1, wantArch: "gemma3", wantContextLen: 131072, @@ -101,7 +101,7 @@ func TestBuildModelInfo(t *testing.T) { MaxPositionEmbeddings: 2048, TorchDtype: "float32", }, - totalTensorBytes: 400_000_088, // 100M params * 4 bytes + 88 bytes header + totalTensorBytes: 400_000_150, // 100M params * 4 bytes + 150 bytes header tensorCount: 1, wantArch: "test", wantContextLen: 2048, @@ -118,7 +118,7 @@ func TestBuildModelInfo(t *testing.T) { MaxPositionEmbeddings: 1024, TorchDtype: "bfloat16", }, - totalTensorBytes: 2_000_880, // 1M params * 2 bytes + 10 tensors * 88 bytes + totalTensorBytes: 2_001_500, // 1M params * 2 bytes + 10 tensors * 150 bytes tensorCount: 10, wantArch: "test", wantContextLen: 1024, @@ -230,42 +230,42 @@ func TestBuildModelInfo_BytesPerParam(t *testing.T) { { name: "bfloat16", dtype: "bfloat16", - totalBytes: 2_000_088, // 1M * 2 + 88 + totalBytes: 2_000_150, // 1M * 2 + 150 tensorCount: 1, wantParamCount: 1_000_000, }, { name: "float16", dtype: "float16", - totalBytes: 2_000_088, + totalBytes: 2_000_150, tensorCount: 1, wantParamCount: 1_000_000, }, { name: "float32", dtype: "float32", - totalBytes: 4_000_088, // 1M * 4 + 88 + totalBytes: 4_000_150, // 1M * 4 + 150 tensorCount: 1, wantParamCount: 1_000_000, }, { name: "int8", dtype: "int8", - totalBytes: 1_000_088, // 1M * 1 + 88 + totalBytes: 1_000_150, // 1M * 1 + 150 tensorCount: 1, wantParamCount: 1_000_000, }, { name: "unknown dtype defaults to 2 bytes", dtype: "unknown", - totalBytes: 2_000_088, + totalBytes: 2_000_150, tensorCount: 1, wantParamCount: 1_000_000, }, { name: "empty dtype defaults to 2 bytes", dtype: "", - totalBytes: 2_000_088, + totalBytes: 2_000_150, tensorCount: 1, wantParamCount: 1_000_000, }, @@ -288,11 +288,13 @@ func TestBuildModelInfo_BytesPerParam(t *testing.T) { func TestParseSafetensorsHeader(t *testing.T) { tests := []struct { - name string - header map[string]any - wantDtype string - wantShape []int64 - wantErr bool + name string + header map[string]any + wantDtype string + wantShape []int64 + wantQuantType string + wantGroupSize string + wantErr bool }{ { name: 
"simple tensor", @@ -307,7 +309,70 @@ func TestParseSafetensorsHeader(t *testing.T) { wantShape: []int64{2560, 262144}, }, { - name: "with metadata", + name: "tensor keyed by name", + header: map[string]any{ + "model.layers.0.weight": map[string]any{ + "dtype": "BF16", + "shape": []int64{2560, 2560}, + "data_offsets": []int64{0, 13107200}, + }, + }, + wantDtype: "BF16", + wantShape: []int64{2560, 2560}, + }, + { + name: "with int4 quant metadata", + header: map[string]any{ + "__metadata__": map[string]any{ + "quant_type": "int4", + "group_size": "32", + }, + "model.layers.0.mlp.up_proj.weight": map[string]any{ + "dtype": "U32", + "shape": []int64{2560, 320}, + "data_offsets": []int64{0, 3276800}, + }, + "model.layers.0.mlp.up_proj.weight.scale": map[string]any{ + "dtype": "BF16", + "shape": []int64{2560, 80}, + "data_offsets": []int64{3276800, 3686400}, + }, + "model.layers.0.mlp.up_proj.weight.bias": map[string]any{ + "dtype": "BF16", + "shape": []int64{2560, 80}, + "data_offsets": []int64{3686400, 4096000}, + }, + }, + wantDtype: "U32", + wantShape: []int64{2560, 320}, + wantQuantType: "int4", + wantGroupSize: "32", + }, + { + name: "int8 quant metadata", + header: map[string]any{ + "__metadata__": map[string]any{ + "quant_type": "int8", + "group_size": "64", + }, + "model.layers.0.mlp.down_proj.weight": map[string]any{ + "dtype": "U32", + "shape": []int64{2560, 640}, + "data_offsets": []int64{0, 6553600}, + }, + "model.layers.0.mlp.down_proj.weight.scale": map[string]any{ + "dtype": "BF16", + "shape": []int64{2560, 40}, + "data_offsets": []int64{6553600, 6963200}, + }, + }, + wantDtype: "U32", + wantShape: []int64{2560, 640}, + wantQuantType: "int8", + wantGroupSize: "64", + }, + { + name: "with old-style format metadata", header: map[string]any{ "__metadata__": map[string]any{ "format": "pt", @@ -371,6 +436,13 @@ func TestParseSafetensorsHeader(t *testing.T) { } } } + + if info.QuantType != tt.wantQuantType { + t.Errorf("QuantType = %v, want %v", info.QuantType, tt.wantQuantType) + } + if info.GroupSize != tt.wantGroupSize { + t.Errorf("GroupSize = %v, want %v", info.GroupSize, tt.wantGroupSize) + } }) } } @@ -460,7 +532,7 @@ func TestGetTensorInfoFromManifest(t *testing.T) { t.Fatalf("failed to create blobs dir: %v", err) } - // Create test tensor blobs + // Create test tensor blobs with __metadata__ tensors := []struct { name string digest string @@ -487,10 +559,9 @@ func TestGetTensorInfoFromManifest(t *testing.T) { }, } - // Create blob files + // Create blob files with tensor keyed by name var layers []manifest.Layer for _, tensor := range tensors { - // Create safetensors blob header := map[string]any{ tensor.name: map[string]any{ "dtype": tensor.dtype, @@ -561,6 +632,391 @@ func TestGetTensorInfoFromManifest(t *testing.T) { } } +func TestGetTensorInfoFromManifest_Quantized(t *testing.T) { + // Create a temp directory for blobs and set OLLAMA_MODELS + tempDir := t.TempDir() + t.Setenv("OLLAMA_MODELS", tempDir) + + blobDir := filepath.Join(tempDir, "blobs") + if err := os.MkdirAll(blobDir, 0o755); err != nil { + t.Fatalf("failed to create blobs dir: %v", err) + } + + // Create a combined quantized blob with __metadata__ + header := map[string]any{ + "__metadata__": map[string]string{ + "quant_type": "int4", + "group_size": "32", + }, + "model.layers.0.mlp.up_proj.weight": map[string]any{ + "dtype": "U32", + "shape": []int64{2560, 320}, // packed: 2560 / 8 = 320 + "data_offsets": []int64{0, 3276800}, + }, + "model.layers.0.mlp.up_proj.weight.scale": map[string]any{ + "dtype": 
"BF16", + "shape": []int64{2560, 80}, // 2560 / 32 = 80 + "data_offsets": []int64{3276800, 3686400}, + }, + "model.layers.0.mlp.up_proj.weight.bias": map[string]any{ + "dtype": "BF16", + "shape": []int64{2560, 80}, + "data_offsets": []int64{3686400, 4096000}, + }, + } + headerJSON, _ := json.Marshal(header) + + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))) + buf.Write(headerJSON) + + digest := "sha256:aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb" + blobPath, err := manifest.BlobsPath(digest) + if err != nil { + t.Fatalf("failed to get blob path: %v", err) + } + if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("failed to write blob: %v", err) + } + + mf := &manifest.Manifest{ + SchemaVersion: 2, + MediaType: "application/vnd.docker.distribution.manifest.v2+json", + Layers: []manifest.Layer{ + { + MediaType: manifest.MediaTypeImageTensor, + Digest: digest, + Size: int64(buf.Len() + 4096000), + Name: "model.layers.0.mlp.up_proj.weight", + }, + }, + } + + result, err := getTensorInfoFromManifest(mf) + if err != nil { + t.Fatalf("getTensorInfoFromManifest() error = %v", err) + } + + if len(result) != 1 { + t.Fatalf("got %d tensors, want 1", len(result)) + } + + tensor := result[0] + if tensor.Name != "model.layers.0.mlp.up_proj.weight" { + t.Errorf("Name = %v, want model.layers.0.mlp.up_proj.weight", tensor.Name) + } + if tensor.Type != "INT4" { + t.Errorf("Type = %v, want INT4", tensor.Type) + } + // Shape should be unpacked: 320 * 8 = 2560 + if len(tensor.Shape) != 2 || tensor.Shape[0] != 2560 || tensor.Shape[1] != 2560 { + t.Errorf("Shape = %v, want [2560, 2560]", tensor.Shape) + } +} + +func TestParseSafetensorsAllHeaders(t *testing.T) { + tests := []struct { + name string + header map[string]any + wantCount int + wantNames []string + wantDtypes []string + wantQuants []string + wantErr bool + }{ + { + name: "single tensor blob", + header: map[string]any{ + "model.layers.0.weight": map[string]any{ + "dtype": "BF16", + "shape": []int64{2560, 2560}, + "data_offsets": []int64{0, 13107200}, + }, + }, + wantCount: 1, + wantNames: []string{"model.layers.0.weight"}, + wantDtypes: []string{"BF16"}, + wantQuants: []string{""}, + }, + { + name: "packed unquantized blob", + header: map[string]any{ + "model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{ + "dtype": "BF16", + "shape": []int64{2560, 10240}, + "data_offsets": []int64{0, 52428800}, + }, + "model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{ + "dtype": "BF16", + "shape": []int64{10240, 2560}, + "data_offsets": []int64{52428800, 104857600}, + }, + "model.layers.0.mlp.experts.0.up_proj.weight": map[string]any{ + "dtype": "BF16", + "shape": []int64{10240, 2560}, + "data_offsets": []int64{104857600, 157286400}, + }, + }, + wantCount: 3, + wantNames: []string{ + "model.layers.0.mlp.experts.0.down_proj.weight", + "model.layers.0.mlp.experts.0.gate_proj.weight", + "model.layers.0.mlp.experts.0.up_proj.weight", + }, + wantDtypes: []string{"BF16", "BF16", "BF16"}, + wantQuants: []string{"", "", ""}, + }, + { + name: "packed quantized blob with global metadata", + header: map[string]any{ + "__metadata__": map[string]any{ + "quant_type": "int4", + "group_size": "32", + }, + "model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{ + "dtype": "U32", + "shape": []int64{10240, 320}, + "data_offsets": []int64{0, 13107200}, + }, + "model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{ + "dtype": "BF16", + "shape": 
[]int64{10240, 80}, + "data_offsets": []int64{13107200, 14745600}, + }, + "model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{ + "dtype": "BF16", + "shape": []int64{10240, 80}, + "data_offsets": []int64{14745600, 16384000}, + }, + "model.layers.0.mlp.experts.0.up_proj.weight": map[string]any{ + "dtype": "U32", + "shape": []int64{10240, 320}, + "data_offsets": []int64{16384000, 29491200}, + }, + "model.layers.0.mlp.experts.0.up_proj.weight.scale": map[string]any{ + "dtype": "BF16", + "shape": []int64{10240, 80}, + "data_offsets": []int64{29491200, 31129600}, + }, + "model.layers.0.mlp.experts.0.up_proj.weight.bias": map[string]any{ + "dtype": "BF16", + "shape": []int64{10240, 80}, + "data_offsets": []int64{31129600, 32768000}, + }, + }, + wantCount: 2, + wantNames: []string{ + "model.layers.0.mlp.experts.0.gate_proj.weight", + "model.layers.0.mlp.experts.0.up_proj.weight", + }, + wantDtypes: []string{"U32", "U32"}, + wantQuants: []string{"int4", "int4"}, + }, + { + name: "packed mixed-precision blob (no global metadata)", + header: map[string]any{ + "model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{ + "dtype": "U32", + "shape": []int64{10240, 320}, + "data_offsets": []int64{0, 13107200}, + }, + "model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{ + "dtype": "BF16", + "shape": []int64{10240, 80}, + "data_offsets": []int64{13107200, 14745600}, + }, + "model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{ + "dtype": "BF16", + "shape": []int64{10240, 80}, + "data_offsets": []int64{14745600, 16384000}, + }, + "model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{ + "dtype": "U32", + "shape": []int64{2560, 2560}, + "data_offsets": []int64{16384000, 42598400}, + }, + "model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{ + "dtype": "BF16", + "shape": []int64{2560, 160}, + "data_offsets": []int64{42598400, 43417600}, + }, + }, + wantCount: 2, + wantNames: []string{ + "model.layers.0.mlp.experts.0.down_proj.weight", + "model.layers.0.mlp.experts.0.gate_proj.weight", + }, + wantDtypes: []string{"U32", "U32"}, + wantQuants: []string{"int8", "int4"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + headerJSON, err := json.Marshal(tt.header) + if err != nil { + t.Fatalf("failed to marshal header: %v", err) + } + + var buf bytes.Buffer + if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil { + t.Fatalf("failed to write header size: %v", err) + } + buf.Write(headerJSON) + + results, err := parseSafetensorsAllHeaders(&buf) + if (err != nil) != tt.wantErr { + t.Errorf("parseSafetensorsAllHeaders() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + + if len(results) != tt.wantCount { + t.Fatalf("got %d tensors, want %d", len(results), tt.wantCount) + } + + for i, info := range results { + if info.Name != tt.wantNames[i] { + t.Errorf("tensor[%d].Name = %v, want %v", i, info.Name, tt.wantNames[i]) + } + if info.Dtype != tt.wantDtypes[i] { + t.Errorf("tensor[%d].Dtype = %v, want %v", i, info.Dtype, tt.wantDtypes[i]) + } + if info.QuantType != tt.wantQuants[i] { + t.Errorf("tensor[%d].QuantType = %v, want %v", i, info.QuantType, tt.wantQuants[i]) + } + } + }) + } +} + +func TestGetTensorInfoFromManifest_Packed(t *testing.T) { + // Create a temp directory for blobs and set OLLAMA_MODELS + tempDir := t.TempDir() + t.Setenv("OLLAMA_MODELS", tempDir) + + blobDir := filepath.Join(tempDir, "blobs") + if err := 
os.MkdirAll(blobDir, 0o755); err != nil { + t.Fatalf("failed to create blobs dir: %v", err) + } + + // Create a packed blob with multiple expert tensors (mixed quantization) + header := map[string]any{ + "model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{ + "dtype": "U32", + "shape": []int64{10240, 320}, + "data_offsets": []int64{0, 13107200}, + }, + "model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{ + "dtype": "BF16", + "shape": []int64{10240, 80}, + "data_offsets": []int64{13107200, 14745600}, + }, + "model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{ + "dtype": "BF16", + "shape": []int64{10240, 80}, + "data_offsets": []int64{14745600, 16384000}, + }, + "model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{ + "dtype": "U32", + "shape": []int64{2560, 2560}, + "data_offsets": []int64{16384000, 42598400}, + }, + "model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{ + "dtype": "BF16", + "shape": []int64{2560, 160}, + "data_offsets": []int64{42598400, 43417600}, + }, + } + headerJSON, _ := json.Marshal(header) + + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))) + buf.Write(headerJSON) + + packedDigest := "sha256:aaaa000000000000000000000000000000000000000000000000000000000001" + blobPath, err := manifest.BlobsPath(packedDigest) + if err != nil { + t.Fatalf("failed to get blob path: %v", err) + } + if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("failed to write packed blob: %v", err) + } + + // Also create a regular (single-tensor) blob + singleHeader := map[string]any{ + "model.embed_tokens.weight": map[string]any{ + "dtype": "BF16", + "shape": []int64{262144, 2560}, + "data_offsets": []int64{0, 1342177280}, + }, + } + singleHeaderJSON, _ := json.Marshal(singleHeader) + var singleBuf bytes.Buffer + binary.Write(&singleBuf, binary.LittleEndian, uint64(len(singleHeaderJSON))) + singleBuf.Write(singleHeaderJSON) + + singleDigest := "sha256:bbbb000000000000000000000000000000000000000000000000000000000002" + singleBlobPath, err := manifest.BlobsPath(singleDigest) + if err != nil { + t.Fatalf("failed to get blob path: %v", err) + } + if err := os.WriteFile(singleBlobPath, singleBuf.Bytes(), 0o644); err != nil { + t.Fatalf("failed to write single blob: %v", err) + } + + mf := &manifest.Manifest{ + SchemaVersion: 2, + MediaType: "application/vnd.docker.distribution.manifest.v2+json", + Layers: []manifest.Layer{ + { + MediaType: manifest.MediaTypeImageTensor, + Digest: singleDigest, + Size: int64(singleBuf.Len()), + Name: "model.embed_tokens.weight", + }, + { + MediaType: manifest.MediaTypeImageTensor, + Digest: packedDigest, + Size: int64(buf.Len()), + Name: "model.layers.0.mlp.experts", // group prefix + }, + }, + } + + result, err := getTensorInfoFromManifest(mf) + if err != nil { + t.Fatalf("getTensorInfoFromManifest() error = %v", err) + } + + // Should have 3 tensors: 1 single + 2 packed main tensors + if len(result) != 3 { + t.Fatalf("got %d tensors, want 3. 
Tensors: %v", len(result), result) + } + + // First tensor should be the single blob + if result[0].Name != "model.embed_tokens.weight" { + t.Errorf("tensor[0].Name = %v, want model.embed_tokens.weight", result[0].Name) + } + if result[0].Type != "BF16" { + t.Errorf("tensor[0].Type = %v, want BF16", result[0].Type) + } + + // Packed tensors should have their actual names (sorted) + packedNames := make(map[string]bool) + for _, r := range result[1:] { + packedNames[r.Name] = true + } + if !packedNames["model.layers.0.mlp.experts.0.down_proj.weight"] { + t.Error("missing packed tensor: model.layers.0.mlp.experts.0.down_proj.weight") + } + if !packedNames["model.layers.0.mlp.experts.0.gate_proj.weight"] { + t.Error("missing packed tensor: model.layers.0.mlp.experts.0.gate_proj.weight") + } +} + func TestReadSafetensorsHeader(t *testing.T) { // Create a temp file with a valid safetensors header tempDir := t.TempDir()