diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index eedd0c61d..433b2ab1b 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -1553,7 +1553,7 @@ func TestShowInfoImageGen(t *testing.T) { Details: api.ModelDetails{ Family: "ZImagePipeline", ParameterSize: "10.3B", - QuantizationLevel: "FP8", + QuantizationLevel: "Q8", }, Capabilities: []model.Capability{model.CapabilityImage}, Requires: "0.14.0", @@ -1565,7 +1565,7 @@ func TestShowInfoImageGen(t *testing.T) { expect := " Model\n" + " architecture ZImagePipeline \n" + " parameters 10.3B \n" + - " quantization FP8 \n" + + " quantization Q8 \n" + " requires 0.14.0 \n" + "\n" + " Capabilities\n" + diff --git a/runner/runner.go b/runner/runner.go index 543410798..db50758bb 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -3,7 +3,7 @@ package runner import ( "github.com/ollama/ollama/runner/llamarunner" "github.com/ollama/ollama/runner/ollamarunner" - imagerunner "github.com/ollama/ollama/x/imagegen/runner" + "github.com/ollama/ollama/x/mlxrunner" ) func Execute(args []string) error { @@ -12,18 +12,18 @@ func Execute(args []string) error { } var newRunner bool - var imageRunner bool + var mlxRunner bool if len(args) > 0 && args[0] == "--ollama-engine" { args = args[1:] newRunner = true } - if len(args) > 0 && args[0] == "--image-engine" { + if len(args) > 0 && args[0] == "--mlx-engine" { args = args[1:] - imageRunner = true + mlxRunner = true } - if imageRunner { - return imagerunner.Execute(args) + if mlxRunner { + return mlxrunner.Execute(args) } else if newRunner { return ollamarunner.Execute(args) } else { diff --git a/server/sched.go b/server/sched.go index 5f22e0b87..3aa9969a0 100644 --- a/server/sched.go +++ b/server/sched.go @@ -21,7 +21,7 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/types/model" - "github.com/ollama/ollama/x/imagegen" + "github.com/ollama/ollama/x/mlxrunner" ) type LlmRequest struct { @@ -195,14 +195,25 @@ func (s *Scheduler) 
processPending(ctx context.Context) { slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus)) } - // Check for image generation model before attempting GGML load + // Check for image generation models - all use MLX runner if slices.Contains(pending.model.Config.Capabilities, "image") { - if s.loadImageGen(pending) { + if s.loadMLX(pending) { break } continue } + // Check for experimental safetensors LLM models + if pending.model.Config.ModelFormat == "safetensors" { + if slices.Contains(pending.model.Config.Capabilities, "completion") { + // LLM model with safetensors format - use MLX runner + if s.loadMLX(pending) { + break + } + continue + } + } + // Load model for fitting logutil.Trace("loading model metadata", "model", pending.model.ModelPath) ggml, err := llm.LoadModel(pending.model.ModelPath, 1024) @@ -552,11 +563,20 @@ iGPUScan: return false } -// loadImageGen loads an image generation model. -func (s *Scheduler) loadImageGen(req *LlmRequest) bool { - // Use model name for imagegen (it resolves manifests by name, not file path) +// loadMLX loads an experimental safetensors model using the unified MLX runner. +// This supports both LLM (completion) and image generation models. 
+func (s *Scheduler) loadMLX(req *LlmRequest) bool { + // Determine mode based on capabilities + var mode mlxrunner.ModelMode + if slices.Contains(req.model.Config.Capabilities, "image") { + mode = mlxrunner.ModeImageGen + } else { + mode = mlxrunner.ModeLLM + } + + // Use model name for MLX (it resolves manifests by name, not file path) modelName := req.model.ShortName - server, err := imagegen.NewServer(modelName) + server, err := mlxrunner.NewServer(modelName, mode) if err != nil { req.errCh <- err return true diff --git a/x/create/client/create.go b/x/create/client/create.go index c7e51a525..36e7f164b 100644 --- a/x/create/client/create.go +++ b/x/create/client/create.go @@ -13,6 +13,7 @@ import ( "io" "os" "path/filepath" + "strings" "github.com/ollama/ollama/manifest" "github.com/ollama/ollama/progress" @@ -34,7 +35,7 @@ type ModelfileConfig struct { type CreateOptions struct { ModelName string ModelDir string - Quantize string // "fp8" for quantization + Quantize string // "q4", "q8", "nvfp4", or "mxfp8" for quantization Modelfile *ModelfileConfig // template/system/license from Modelfile } @@ -53,10 +54,20 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error { // Determine model type settings var modelType, spinnerKey string var capabilities []string + var parserName, rendererName string if isSafetensors { modelType = "safetensors model" spinnerKey = "create" capabilities = []string{"completion"} + + // Check if model supports thinking based on architecture + if supportsThinking(opts.ModelDir) { + capabilities = append(capabilities, "thinking") + } + + // Set parser and renderer name based on architecture + parserName = getParserName(opts.ModelDir) + rendererName = getRendererName(opts.ModelDir) } else { modelType = "image generation model" spinnerKey = "imagegen" @@ -81,14 +92,14 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error { err = create.CreateSafetensorsModel( opts.ModelName, opts.ModelDir, opts.Quantize, 
newLayerCreator(), newTensorLayerCreator(), - newManifestWriter(opts, capabilities), + newManifestWriter(opts, capabilities, parserName, rendererName), progressFn, ) } else { err = create.CreateImageGenModel( opts.ModelName, opts.ModelDir, opts.Quantize, newLayerCreator(), newTensorLayerCreator(), - newManifestWriter(opts, capabilities), + newManifestWriter(opts, capabilities, "", ""), progressFn, ) } @@ -204,7 +215,7 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error } // newManifestWriter returns a ManifestWriter callback for writing the model manifest. -func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter { +func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter { return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error { name := model.ParseName(modelName) if !name.IsValid() { @@ -229,6 +240,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes ModelFormat: "safetensors", Capabilities: caps, Requires: MinOllamaVersion, + Parser: parserName, + Renderer: rendererName, } configJSON, err := json.Marshal(configData) if err != nil { @@ -295,3 +308,146 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) { return layers, nil } + +// supportsThinking checks if the model supports thinking mode based on its architecture. +// This reads the config.json from the model directory and checks the architectures field. 
+func supportsThinking(modelDir string) bool { + configPath := filepath.Join(modelDir, "config.json") + data, err := os.ReadFile(configPath) + if err != nil { + return false + } + + var cfg struct { + Architectures []string `json:"architectures"` + ModelType string `json:"model_type"` + } + if err := json.Unmarshal(data, &cfg); err != nil { + return false + } + + // Check architectures that support thinking + thinkingArchitectures := []string{ + "glm4moe", // GLM-4 MoE models + "deepseek", // DeepSeek models + "qwen3", // Qwen3 models + } + + // Check the architecture list + for _, arch := range cfg.Architectures { + archLower := strings.ToLower(arch) + for _, thinkArch := range thinkingArchitectures { + if strings.Contains(archLower, thinkArch) { + return true + } + } + } + + // Also check model_type + if cfg.ModelType != "" { + typeLower := strings.ToLower(cfg.ModelType) + for _, thinkArch := range thinkingArchitectures { + if strings.Contains(typeLower, thinkArch) { + return true + } + } + } + + return false +} + +// getParserName returns the parser name for a model based on its architecture. +// This reads the config.json from the model directory and determines the appropriate parser. 
+func getParserName(modelDir string) string { + configPath := filepath.Join(modelDir, "config.json") + data, err := os.ReadFile(configPath) + if err != nil { + return "" + } + + var cfg struct { + Architectures []string `json:"architectures"` + ModelType string `json:"model_type"` + } + if err := json.Unmarshal(data, &cfg); err != nil { + return "" + } + + // Check architectures for known parsers + for _, arch := range cfg.Architectures { + archLower := strings.ToLower(arch) + if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") { + return "glm-4.7" + } + if strings.Contains(archLower, "deepseek") { + return "deepseek3" + } + if strings.Contains(archLower, "qwen3") { + return "qwen3-coder" + } + } + + // Also check model_type + if cfg.ModelType != "" { + typeLower := strings.ToLower(cfg.ModelType) + if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") { + return "glm-4.7" + } + if strings.Contains(typeLower, "deepseek") { + return "deepseek3" + } + if strings.Contains(typeLower, "qwen3") { + return "qwen3-coder" + } + } + + return "" +} + +// getRendererName returns the renderer name for a model based on its architecture. +// This reads the config.json from the model directory and determines the appropriate renderer. 
+func getRendererName(modelDir string) string { + configPath := filepath.Join(modelDir, "config.json") + data, err := os.ReadFile(configPath) + if err != nil { + return "" + } + + var cfg struct { + Architectures []string `json:"architectures"` + ModelType string `json:"model_type"` + } + if err := json.Unmarshal(data, &cfg); err != nil { + return "" + } + + // Check architectures for known renderers + for _, arch := range cfg.Architectures { + archLower := strings.ToLower(arch) + if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") { + return "glm-4.7" + } + if strings.Contains(archLower, "deepseek") { + return "deepseek3" + } + if strings.Contains(archLower, "qwen3") { + return "qwen3-coder" + } + } + + // Also check model_type + if cfg.ModelType != "" { + typeLower := strings.ToLower(cfg.ModelType) + if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") { + return "glm-4.7" + } + if strings.Contains(typeLower, "deepseek") { + return "deepseek3" + } + if strings.Contains(typeLower, "qwen3") { + return "qwen3-coder" + } + } + + return "" +} diff --git a/x/create/client/quantize.go b/x/create/client/quantize.go index 3a9f37cfc..e69003f73 100644 --- a/x/create/client/quantize.go +++ b/x/create/client/quantize.go @@ -13,7 +13,11 @@ import ( // quantizeTensor loads a tensor from safetensors format, quantizes it, // and returns safetensors data for the quantized weights, scales, and biases. -// Supported quantization types: "fp8" (affine 8-bit) +// Supported quantization types: +// - "q4": affine 4-bit, group_size=32 (with qbiases) +// - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales) +// - "q8": affine 8-bit, group_size=64 (with qbiases) +// - "mxfp8": MX (OCP Microscaling) FP8, group_size=32 (no qbiases, E4M3 scales) +// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights). 
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) { tmpDir := ensureTempDir() @@ -54,12 +58,18 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str // Quantize based on quantization type var qweight, scales, qbiases *mlx.Array switch quantize { - case "fp4": - // affine mode: group_size=32, bits=4 + case "q4": + // affine mode: group_size=32, bits=4 (with qbiases for zero-point offset) qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine") - case "fp8": - // affine mode: group_size=32, bits=8 - qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine") + case "nvfp4": + // NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales) + qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4") + case "q8": + // affine mode: group_size=64, bits=8 (with qbiases for zero-point offset) + qweight, scales, qbiases = mlx.Quantize(arr, 64, 8, "affine") + case "mxfp8": + // MX (OCP Microscaling) FP8: group_size=32, bits=8, E4M3 scales (no qbiases) + qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "mxfp8") default: return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize) } diff --git a/x/create/create.go b/x/create/create.go index 823d0f842..2474c8c66 100644 --- a/x/create/create.go +++ b/x/create/create.go @@ -228,7 +228,7 @@ type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error) type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error) // QuantizingTensorLayerCreator creates tensor layers with optional quantization. -// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases). +// When quantize is non-empty (e.g., "q8"), returns multiple layers (weight + scales + biases). 
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) // ManifestWriter writes the manifest file. @@ -262,36 +262,134 @@ func ShouldQuantize(name, component string) bool { return strings.HasSuffix(name, ".weight") } -// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape. +// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type. // This is a more detailed check that also considers tensor dimensions. -func ShouldQuantizeTensor(name string, shape []int32) bool { +// The quantize parameter specifies the quantization type (e.g., "q4", "nvfp4", "q8", "mxfp8"). +func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool { + return GetTensorQuantization(name, shape, quantize) != "" +} + +// normalizeQuantType converts various quantization type aliases to canonical forms. +// Supports: q4/Q4/int4/INT4/fp4/FP4 -> q4, q8/Q8/int8/INT8/fp8/FP8 -> q8, nvfp4/NVFP4, mxfp8/MXFP8 +func normalizeQuantType(quantize string) string { + switch strings.ToUpper(quantize) { + case "Q4", "INT4", "FP4": + return "q4" + case "Q8", "INT8", "FP8": + return "q8" + case "NVFP4": + return "nvfp4" + case "MXFP8": + return "mxfp8" + default: + return quantize + } +} + +// getQuantGroupSize returns the group size for a given quantization type. +// These must match the values used in quantize.go when creating quantized models. +func getQuantGroupSize(quantize string) int { + switch normalizeQuantType(quantize) { + case "nvfp4": + return 16 + case "q4": + return 32 + case "mxfp8": + return 32 + case "q8": + return 64 + default: + return 32 + } +} + +// GetTensorQuantization returns the appropriate quantization type for a tensor. +// Returns "" if the tensor should not be quantized. 
+// This implements mixed-precision quantization: +// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive) +// - Output projection, gate/up weights: q4 (less sensitive) +// - Down projection weights: q8 (more sensitive, would be Q6 in GGML but no MLX kernel) +// - Norms, embeddings, biases, routing gates: no quantization +func GetTensorQuantization(name string, shape []int32, quantize string) string { // Use basic name-based check first if !ShouldQuantize(name, "") { - return false + return "" } // Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any) if len(shape) != 2 { - return false + return "" } // Skip small tensors (less than 1024 elements) - not worth quantizing if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 { - return false + return "" } - // MLX quantization requires last dimension to be divisible by group size (32) - if shape[len(shape)-1]%32 != 0 { - return false + // Normalize quantization type to canonical form + quantNorm := normalizeQuantType(quantize) + + // MLX quantization requires last dimension to be divisible by group size + // nvfp4: 16, q4/mxfp8: 32, q8: 64 + groupSize := int32(32) + switch quantNorm { + case "nvfp4": + groupSize = 16 + case "q8": + groupSize = 64 + } + if shape[len(shape)-1]%groupSize != 0 { + return "" } - return true + // Skip routing gate weights (should stay high precision) + // In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight) + if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") { + return "" + } + + // For NVFP4 or MXFP8, use the same quantization for all (no mixed precision) + if quantNorm == "nvfp4" || quantNorm == "mxfp8" { + return quantNorm + } + + // Attention MLA weights - keep unquantized (bf16) + // These are highly sensitive: errors accumulate in the KV cache over time + // q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj + if strings.Contains(name, "q_a_proj") || + 
strings.Contains(name, "q_b_proj") || + strings.Contains(name, "kv_a_proj") || + strings.Contains(name, "kv_b_proj") { + return "" // No quantization - keep bf16 + } + + // Down projection weights - use Q8 (would be Q6_K in GGML, but MLX has no Q6 kernel) + // mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj + if strings.Contains(name, "down_proj") { + return "q8" + } + + // Output projection, gate/up weights - use requested quantization (Q4) + // o_proj, gate_proj, up_proj + if strings.Contains(name, "o_proj") || + strings.Contains(name, "gate_proj") || + strings.Contains(name, "up_proj") { + return quantNorm + } + + // LM head - use requested quantization + if strings.Contains(name, "lm_head") { + return quantNorm + } + + // Default to requested quantization for other weights + return quantNorm } // CreateSafetensorsModel imports a standard safetensors model from a directory. // This handles Hugging Face style models with config.json and *.safetensors files. // Stores each tensor as a separate blob for fine-grained deduplication. -// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized. +// If quantize is non-empty (e.g., "q8"), eligible tensors will be quantized. 
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error { var layers []LayerInfo var configLayer LayerInfo @@ -330,9 +428,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La } // Determine quantization type for this tensor (empty string if not quantizing) + // GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN) quantizeType := "" - if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) { - quantizeType = quantize + if quantize != "" { + quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize) } // Store as minimal safetensors format (88 bytes header overhead) @@ -388,6 +487,23 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La return fmt.Errorf("config.json not found in %s", modelDir) } + // Create model_index.json with quantization info if quantizing + if quantize != "" { + modelIndex := map[string]any{ + "quantization": strings.ToUpper(quantize), + "group_size": getQuantGroupSize(quantize), + } + indexData, err := json.MarshalIndent(modelIndex, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal model_index.json: %w", err) + } + indexLayer, err := createLayer(strings.NewReader(string(indexData)), "application/vnd.ollama.image.json", "model_index.json") + if err != nil { + return fmt.Errorf("failed to create model_index.json layer: %w", err) + } + layers = append(layers, indexLayer) + } + fn(fmt.Sprintf("writing manifest for %s", modelName)) if err := writeManifest(modelName, configLayer, layers); err != nil { diff --git a/x/create/create_test.go b/x/create/create_test.go index c69bb10a8..b5d0a7b34 100644 --- a/x/create/create_test.go +++ b/x/create/create_test.go @@ -536,41 +536,51 @@ func TestShouldQuantize(t *testing.T) { func TestShouldQuantizeTensor(t *testing.T) { tests := []struct { - 
name string - tensor string - shape []int32 - want bool + name string + tensor string + shape []int32 + quantize string + want bool }{ // 2D tensors with sufficient size should be quantized - {"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true}, - {"medium 2D weight", "small_proj.weight", []int32{128, 128}, true}, + {"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true}, + {"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true}, + {"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true}, // Small tensors should not be quantized (< 1024 elements) - {"tiny 2D weight", "tiny.weight", []int32{16, 16}, false}, - {"small 2D weight", "small.weight", []int32{31, 31}, false}, + {"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false}, + {"small 2D weight", "small.weight", []int32{31, 31}, "fp8", false}, // 1D tensors should not be quantized - {"1D tensor", "layer_norm.weight", []int32{4096}, false}, + {"1D tensor", "layer_norm.weight", []int32{4096}, "fp8", false}, // 3D+ tensors should not be quantized - {"3D tensor", "conv.weight", []int32{64, 64, 3}, false}, - {"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false}, + {"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false}, + {"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false}, // Embeddings should not be quantized regardless of shape - {"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false}, + {"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false}, // Norms should not be quantized regardless of shape - {"norm 2D", "layer_norm.weight", []int32{4096, 1}, false}, + {"norm 2D", "layer_norm.weight", []int32{4096, 1}, "fp8", false}, // Biases should not be quantized - {"bias 2D", "proj.bias", []int32{4096, 1}, false}, + {"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false}, + + // Group size divisibility tests + // FP8/FP4 require divisible by 32 + {"not divisible by 32 fp8", 
"proj.weight", []int32{128, 48}, "fp8", false}, + {"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true}, + // NVFP4 requires divisible by 16 + {"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false}, + {"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := ShouldQuantizeTensor(tt.tensor, tt.shape) + got := ShouldQuantizeTensor(tt.tensor, tt.shape, tt.quantize) if got != tt.want { - t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want) + t.Errorf("ShouldQuantizeTensor(%q, %v, %q) = %v, want %v", tt.tensor, tt.shape, tt.quantize, got, tt.want) } }) } @@ -741,7 +751,7 @@ func TestCreateImageGenModel_WithQuantize(t *testing.T) { progressFn := func(status string) {} - err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn) + err := CreateImageGenModel("test-imagegen", dir, "q8", createLayer, createTensorLayer, writeManifest, progressFn) if err != nil { t.Fatalf("CreateImageGenModel failed: %v", err) } diff --git a/x/create/imagegen.go b/x/create/imagegen.go index 595a40417..0da0e764a 100644 --- a/x/create/imagegen.go +++ b/x/create/imagegen.go @@ -15,15 +15,15 @@ import ( // CreateImageGenModel imports an image generation model from a directory. // Stores each tensor as a separate blob for fine-grained deduplication. // If quantize is specified, linear weights in transformer/text_encoder are quantized. -// Supported quantization types: fp8 (or empty for no quantization). +// Supported quantization types: q4, q8, nvfp4, mxfp8 (or empty for no quantization). // Layer creation and manifest writing are done via callbacks to avoid import cycles. 
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error { // Validate quantization type switch quantize { - case "", "fp4", "fp8": + case "", "q4", "q8", "nvfp4", "mxfp8": // valid default: - return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize) + return fmt.Errorf("unsupported quantization type %q: supported types are q4, q8, nvfp4, mxfp8", quantize) } var layers []LayerInfo @@ -89,7 +89,7 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer // Determine quantization type for this tensor (empty string if not quantizing) quantizeType := "" - if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) { + if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape, quantize) { quantizeType = quantize } @@ -213,10 +213,18 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer } // canQuantizeShape returns true if a tensor shape is compatible with MLX quantization. -// MLX requires the last dimension to be divisible by the group size (32). -func canQuantizeShape(shape []int32) bool { +// MLX requires the last dimension to be divisible by the group size. 
+// nvfp4: 16, q4/mxfp8: 32, q8: 64 +func canQuantizeShape(shape []int32, quantize string) bool { if len(shape) < 2 { return false } - return shape[len(shape)-1]%32 == 0 + groupSize := int32(32) + switch strings.ToUpper(quantize) { + case "NVFP4": + groupSize = 16 + case "Q8": + groupSize = 64 + } + return shape[len(shape)-1]%groupSize == 0 } diff --git a/x/imagegen/cache/cache.go b/x/imagegen/cache/cache.go index 4faa2412e..8a25193cd 100644 --- a/x/imagegen/cache/cache.go +++ b/x/imagegen/cache/cache.go @@ -9,6 +9,7 @@ type Cache interface { Offset() int Len() int State() []*mlx.Array + Reset() } type KVCache struct { @@ -63,6 +64,13 @@ func (c *KVCache) State() []*mlx.Array { func (c *KVCache) Offset() int { return c.offset } func (c *KVCache) Len() int { return c.offset } +// Reset clears the cache state for a new generation session +func (c *KVCache) Reset() { + c.keys = nil + c.values = nil + c.offset = 0 +} + // RotatingKVCache implements sliding window attention with bounded memory type RotatingKVCache struct { keys, values *mlx.Array @@ -154,3 +162,11 @@ func (c *RotatingKVCache) State() []*mlx.Array { func (c *RotatingKVCache) Offset() int { return c.offset } func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) } + +// Reset clears the cache state for a new generation session +func (c *RotatingKVCache) Reset() { + c.keys = nil + c.values = nil + c.offset = 0 + c.idx = 0 +} diff --git a/x/imagegen/manifest.go b/x/imagegen/manifest.go index 7692d2b09..da55cbe81 100644 --- a/x/imagegen/manifest.go +++ b/x/imagegen/manifest.go @@ -102,14 +102,17 @@ func (m *ModelManifest) BlobPath(digest string) string { return filepath.Join(m.BlobDir, blobName) } -// GetTensorLayers returns all tensor layers for a given component. -// Component should be "text_encoder", "transformer", or "vae". -// Tensor names are path-style: "component/tensor_name" (e.g., "text_encoder/model.embed_tokens.weight"). 
+// GetTensorLayers returns tensor layers, optionally filtered by component. +// If component is empty, returns all tensor layers (for LLM models). +// If component is specified (e.g., "text_encoder", "transformer", "vae"), +// returns only layers with that prefix. func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer { - prefix := component + "/" var layers []ManifestLayer for _, layer := range m.Manifest.Layers { - if layer.MediaType == "application/vnd.ollama.image.tensor" && strings.HasPrefix(layer.Name, prefix) { + if layer.MediaType != "application/vnd.ollama.image.tensor" { + continue + } + if component == "" || strings.HasPrefix(layer.Name, component+"/") { layers = append(layers, layer) } } @@ -206,7 +209,7 @@ func GetModelInfo(modelName string) (*ModelInfo, error) { if info.Quantization == "" { for _, layer := range manifest.Manifest.Layers { if strings.HasSuffix(layer.Name, ".weight_scale") { - info.Quantization = "FP8" + info.Quantization = "Q8" break } } diff --git a/x/imagegen/mlx/mlx.go b/x/imagegen/mlx/mlx.go index 2b31aadfb..2232e482b 100644 --- a/x/imagegen/mlx/mlx.go +++ b/x/imagegen/mlx/mlx.go @@ -991,6 +991,19 @@ func Concat(a, b *Array, axis int) *Array { return Concatenate([]*Array{a, b}, axis) } +// Stack stacks arrays along a new axis (axis 0 by default) +func Stack(arrays []*Array, axis int) *Array { + handles := make([]C.mlx_array, len(arrays)) + for i, arr := range arrays { + handles[i] = arr.c + } + vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles))) + res := C.mlx_array_new() + C.mlx_stack_axis(&res, vec, C.int(axis), C.default_stream()) + C.mlx_vector_array_free(vec) + return newArray(res) +} + // Slice slices the array func Slice(a *Array, start, stop []int32) *Array { n := len(start) diff --git a/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go b/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go new file mode 100644 index 000000000..caebbe361 --- /dev/null +++ 
b/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go @@ -0,0 +1,840 @@ +//go:build mlx + +// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX. +// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE). +package glm4_moe_lite + +import ( + "encoding/json" + "fmt" + "math" + + "github.com/ollama/ollama/x/imagegen" + "github.com/ollama/ollama/x/imagegen/cache" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/nn" + "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/imagegen/tokenizer" +) + +// RopeScaling holds RoPE scaling configuration +type RopeScaling struct { + Factor float32 `json:"factor"` + MscaleAllDim float32 `json:"mscale_all_dim"` +} + +// Config holds GLM4-MoE-Lite model configuration +type Config struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + MoEIntermediateSize int32 `json:"moe_intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + MaxPositionEmbeddings int32 `json:"max_position_embeddings"` + AttentionBias bool `json:"attention_bias"` + + // MLA (Multi-head Latent Attention) parameters + QLoraRank int32 `json:"q_lora_rank"` + KVLoraRank int32 `json:"kv_lora_rank"` + QKRopeHeadDim int32 `json:"qk_rope_head_dim"` + QKNopeHeadDim int32 `json:"qk_nope_head_dim"` + VHeadDim int32 `json:"v_head_dim"` + + // MoE parameters + NRoutedExperts int32 `json:"n_routed_experts"` + NSharedExperts int32 `json:"n_shared_experts"` + NumExpertsPerTok int32 `json:"num_experts_per_tok"` + RoutedScalingFactor float32 `json:"routed_scaling_factor"` + NormTopKProb bool `json:"norm_topk_prob"` + FirstKDenseReplace int32 `json:"first_k_dense_replace"` + NGroup int32 
`json:"n_group"` + TopKGroup int32 `json:"topk_group"` + + // RoPE scaling + RopeScaling *RopeScaling `json:"rope_scaling"` + + // Quantization parameters (set during load based on model quantization) + QuantGroupSize int `json:"-"` // Group size for quantization (default 64) + QuantBits int `json:"-"` // Bits per weight (4 or 8) + QuantMode string `json:"-"` // Quantization mode ("affine", etc.) + + // Computed fields + QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim + Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment +} + +// MLAAttention implements Multi-head Latent Attention with absorption. +// This uses absorbed MLA which operates in latent space for reduced KV cache. +type MLAAttention struct { + // Low-rank query projections + QAProj nn.LinearLayer `weight:"self_attn.q_a_proj"` + QALayerNorm *nn.RMSNorm `weight:"self_attn.q_a_layernorm"` + QBProj nn.LinearLayer `weight:"self_attn.q_b_proj"` + + // Low-rank KV projections (with shared rope component) + KVAProjWithMQA nn.LinearLayer `weight:"self_attn.kv_a_proj_with_mqa"` + KVALayerNorm *nn.RMSNorm `weight:"self_attn.kv_a_layernorm"` + + // Absorbed MLA projections (derived from kv_b_proj) + // EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim] + // UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank] + EmbedQ *nn.MultiLinear `weight:"-"` + UnembedOut *nn.MultiLinear `weight:"-"` + + // Output projection + OProj nn.LinearLayer `weight:"self_attn.o_proj"` +} + +// Forward computes absorbed MLA attention output. +// This operates in latent space for reduced KV cache memory. 
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	// Query path: q_a_proj -> layernorm -> q_b_proj
	q := a.QAProj.Forward(x)
	q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
	q = a.QBProj.Forward(q)

	// Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim]
	q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
	q = mlx.Transpose(q, 0, 2, 1, 3)

	// Split Q into nope and rope parts
	qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
	qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})

	// KV path: get compressed KV and k_pe
	compressedKV := a.KVAProjWithMQA.Forward(x)

	// Split into compressed_kv and k_pe (shared rope component)
	kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
	kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})

	// k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim]
	kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
	kPE = mlx.Transpose(kPE, 0, 2, 1, 3)

	// Apply layernorm to get kv latent representation
	kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
	// kvLatent: [B, L, kv_lora_rank] -> [B, 1, L, kv_lora_rank] for broadcasting
	kvLatent = mlx.ExpandDims(kvLatent, 1)

	// Apply RoPE to the rope parts; offset positions past what is already cached
	offset := 0
	if c != nil {
		offset = c.Offset()
	}
	qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
	kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)

	// ABSORBED MLA: project q_nope to latent space
	// qNope: [B, num_heads, L, qk_nope_head_dim]
	// EmbedQ: [num_heads, kv_lora_rank, qk_nope_head_dim]
	// Result: [B, num_heads, L, kv_lora_rank]
	qLatent := a.EmbedQ.Forward(qNope)

	// Keys = concat(kvLatent, kPE)
	// kvLatent: [B, 1, L, kv_lora_rank]
	// kPE: [B, 1, L, qk_rope_head_dim]
	// keys: [B, 1, L, kv_lora_rank + qk_rope_head_dim]
	keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)

	// Cache the smaller latent representation
	// We cache keys (latent + rope) and use empty values since values are derived from keys
	cachedL := L
	if c != nil {
		// Create placeholder values with 0 dims for cache (we don't actually use cached values)
		placeholderValues := mlx.Zeros([]int32{B, 1, L, 0}, mlx.DtypeFloat32)
		keys, _ = c.Update(keys, placeholderValues, int(L))
		cachedL = int32(keys.Shape()[2])
	}

	// Values are the first kv_lora_rank dims of keys (slice off rope part)
	values := mlx.Slice(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})

	// Queries = concat(qLatent, qPE)
	// qLatent: [B, num_heads, L, kv_lora_rank]
	// qPE: [B, num_heads, L, qk_rope_head_dim]
	// queries: [B, num_heads, L, kv_lora_rank + qk_rope_head_dim]
	queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)

	// Attention in latent space
	// queries: [B, num_heads, L, kv_lora_rank + rope_dim]
	// keys: [B, 1, cachedL, kv_lora_rank + rope_dim]
	// values: [B, 1, cachedL, kv_lora_rank]
	// L > 1 enables causal masking for prefill; single-token decode needs none.
	out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1)

	// ABSORBED MLA: unembed from latent space
	// out: [B, num_heads, L, kv_lora_rank]
	// UnembedOut: [num_heads, v_head_dim, kv_lora_rank]
	// Result: [B, num_heads, L, v_head_dim]
	out = a.UnembedOut.Forward(out)

	// Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim]
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)

	return a.OProj.Forward(out)
}

// DenseMLP implements the standard SwiGLU MLP for dense layers
type DenseMLP struct {
	GateProj nn.LinearLayer `weight:"mlp.gate_proj"`
	UpProj   nn.LinearLayer `weight:"mlp.up_proj"`
	DownProj nn.LinearLayer `weight:"mlp.down_proj"`
}

// Forward applies the SwiGLU MLP: down(silu(gate(x)) * up(x)).
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
	gate := mlx.SiLU(m.GateProj.Forward(x))
	up := m.UpProj.Forward(x)
	return m.DownProj.Forward(mlx.Mul(gate, up))
}

// MoEGate implements the expert gating mechanism
type MoEGate struct {
	Gate                 nn.LinearLayer `weight:"mlp.gate"`
	EScoreCorrectionBias *mlx.Array     `weight:"mlp.gate.e_score_correction_bias,optional"`
}

// Forward computes expert selection indices and scores.
// Returns (indices [B, L, topK], scores [B, L, topK]); scores come from the
// un-biased sigmoid activations (bias is used only for selection).
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
	// Compute gate logits through linear layer (handles both quantized and non-quantized)
	gates := g.Gate.Forward(x)

	// Sigmoid scoring
	scores := mlx.Sigmoid(gates)
	origScores := scores

	// Add correction bias if present
	if g.EScoreCorrectionBias != nil {
		scores = mlx.Add(scores, g.EScoreCorrectionBias)
	}

	// Group-wise expert selection (simplified for n_group=1)
	// NOTE(review): cfg.NGroup/cfg.TopKGroup are not consulted here — confirm
	// GLM4-MoE-Lite checkpoints always use n_group=1.
	// Select top-k experts (negate so Argpartition picks the largest scores)
	topK := cfg.NumExpertsPerTok
	negScores := mlx.Neg(scores)
	inds := mlx.Argpartition(negScores, int(topK)-1, -1)

	shape := inds.Shape()
	inds = mlx.Slice(inds, []int32{0, 0, 0}, []int32{shape[0], shape[1], topK})

	// Get scores for selected experts
	scores = mlx.TakeAlongAxis(origScores, inds, -1)

	// Normalize if configured
	if topK > 1 && cfg.NormTopKProb {
		sumScores := mlx.Sum(scores, -1, true)
		scores = mlx.Div(scores, sumScores)
	}

	// Apply routing scaling factor
	scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)

	return inds, scores
}

// SwitchMLP implements the MoE expert computation using stacked weights
// Note: No weight tags - these are populated manually by stacking expert weights
type SwitchMLP struct {
	// Dequantized weights (used when GatherQMM not available)
	GateWeight *mlx.Array
	UpWeight   *mlx.Array
	DownWeight *mlx.Array

	// Quantized weights (used with GatherQMM for 4/8-bit affine)
	GateWeightQ, GateScales, GateBiases *mlx.Array
	UpWeightQ, UpScales, UpBiases       *mlx.Array
	DownWeightQ, DownScales, DownBiases *mlx.Array

	// Quantization bits per projection (supports mixed precision Q4/Q8)
	GateBits int
	UpBits   int
	DownBits int

	// Quantization group size per projection (detected from tensor shapes)
	GateGroupSize int
	UpGroupSize   int
	DownGroupSize int

	// If true, use GatherQMM with quantized weights
	UseQuantized bool
}

// Forward applies the switched expert MLP.
// x: [B, L, D]; indices: [B, L, topK]. Returns [B, L, topK, D] — one output
// per selected expert; the caller (MoE.Forward) weights and sums over topK.
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
	shape := x.Shape()
	B, L := shape[0], shape[1]
	topK := cfg.NumExpertsPerTok

	// Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D]
	xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)

	// Flatten for gather_mm: [B*L, 1, 1, D]
	xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)

	// Flatten indices: [B, L, topK] -> [B*L, topK]
	idxFlat := mlx.Reshape(indices, B*L, topK)

	// Sort for efficient gather (when we have many tokens)
	doSort := B*L >= 64
	var invOrder *mlx.Array
	n := B * L * topK

	if doSort {
		idxAll := mlx.Flatten(idxFlat)
		order := mlx.Argsort(idxAll, 0)
		invOrder = mlx.Argsort(order, 0)
		// Reorder x based on sorted indices (order/topK maps a sorted
		// (token, expert) slot back to its source token row)
		xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
		idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
	}

	var gate, up, hidden, down *mlx.Array

	if s.UseQuantized {
		// Use GatherQMM for quantized weights (faster, keeps weights quantized)
		// Each projection may have different bits and group sizes (mixed precision: Q4 for gate/up, Q8 for down)
		gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
			nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
		up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
			nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)

		hidden = mlx.Mul(mlx.SiLU(gate), up)

		down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
			nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
	} else {
		// Use GatherMM for dequantized/non-quantized weights
		gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
		up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)

		hidden = mlx.Mul(mlx.SiLU(gate), up)

		down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
	}

	// Unsort if we sorted
	if doSort {
		down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
	} else {
		down = mlx.Squeeze(down, 2)
	}

	return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
}

// SharedExperts implements the shared expert MLP
type SharedExperts struct {
	GateProj nn.LinearLayer `weight:"mlp.shared_experts.gate_proj"`
	UpProj   nn.LinearLayer `weight:"mlp.shared_experts.up_proj"`
	DownProj nn.LinearLayer `weight:"mlp.shared_experts.down_proj"`
}

// Forward applies the shared expert MLP (SwiGLU, applied to every token).
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
	gate := mlx.SiLU(s.GateProj.Forward(x))
	up := s.UpProj.Forward(x)
	return s.DownProj.Forward(mlx.Mul(gate, up))
}

// MoE implements the full Mixture of Experts layer
type MoE struct {
	Gate          *MoEGate
	SwitchMLP     *SwitchMLP
	SharedExperts *SharedExperts
}

// Forward applies the MoE layer
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
	shape := x.Shape()
	B, L := shape[0], shape[1]

	// Get expert indices and scores
	inds, scores := m.Gate.Forward(x, cfg)

	// Apply routed experts
	expertOut := m.SwitchMLP.Forward(x, inds, cfg)

	// Weight by scores: [B, L, topK, D] * [B, L, topK, 1] -> sum over topK
	scoresExpanded := mlx.ExpandDims(scores, -1)
	y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)

	// Add shared experts if present
	if m.SharedExperts != nil {
		y = mlx.Add(y, m.SharedExperts.Forward(x))
	}

	// y is already [B, L, D] after the Sum above; this reshape is a no-op kept
	// as a shape assertion.
	return mlx.Reshape(y, B, L, cfg.HiddenSize)
}

// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
type DenseBlock struct {
	Attention              *MLAAttention
	MLP                    *DenseMLP
	InputLayerNorm         *nn.RMSNorm `weight:"input_layernorm"`
	PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}

// Forward applies the dense block
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	// Pre-norm attention with residual
	r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
	h := mlx.Add(x, r)

	// Pre-norm MLP with residual
	r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
	return mlx.Add(h, r)
}

// MoEBlock represents a MoE transformer block
type MoEBlock struct {
	Attention              *MLAAttention
	MoE                    *MoE
	InputLayerNorm         *nn.RMSNorm `weight:"input_layernorm"`
	PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}

// Forward applies the MoE block
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	// Pre-norm attention with residual
	r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
	h := mlx.Add(x, r)

	// Pre-norm MoE with residual
	r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
	return mlx.Add(h, r)
}

// Block interface for both dense and MoE blocks
type Block interface {
	Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
}

// Model represents the complete GLM4-MoE-Lite model
type Model struct {
	EmbedTokens *nn.Embedding  `weight:"model.embed_tokens"`
	Layers      []Block        `weight:"-"` // Loaded manually due to different block types
	Norm        *nn.RMSNorm    `weight:"model.norm"`
	LMHead      nn.LinearLayer `weight:"lm_head"`

	tok *tokenizer.Tokenizer
	*Config
}

// computeScale computes the attention scale.
+// Uses the full key head dimension (qkNopeHeadDim + qkRopeHeadDim) to match the Ollama runner. +func computeScale(cfg *Config) float32 { + keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim + scale := float32(1.0 / math.Sqrt(float64(keyLength))) + if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 { + s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0 + scale *= s * s + } + return scale +} + +// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support. +// Currently only 4-bit and 8-bit affine quantization are supported. +func supportsGatherQMM(mode string, bits int) bool { + return mode == "affine" && (bits == 4 || bits == 8) +} + +// ExpertWeight holds a single expert's weight with optional quantization components. +type ExpertWeight struct { + Weight *mlx.Array // Quantized weight (if quantized) or dequantized weight + Scales *mlx.Array // Quantization scales (nil if not quantized) + Biases *mlx.Array // Quantization biases (nil if not quantized or mode doesn't use biases) + Bits int // Quantization bits (4 or 8), 0 if not quantized + GroupSize int // Quantization group size, 0 if not quantized +} + +// getQuantParams returns quantization parameters from model metadata. +// Returns groupSize, bits, and mode for the model's quantization type. +func getQuantParams(weights safetensors.WeightSource) (groupSize, bits int, mode string) { + groupSize, bits, mode = safetensors.QuantizationParams(weights.Quantization()) + // Use metadata group_size if available (overrides default) + if gs := weights.GroupSize(); gs > 0 { + groupSize = gs + } + return groupSize, bits, mode +} + +// loadExpertWeight loads an expert weight. +// If useQuantized is true and the weight is quantized with a supported mode, returns quantized components. +// Otherwise dequantizes and returns only the weight. 
func loadExpertWeight(weights safetensors.WeightSource, path string, useQuantized bool, cfg *Config) *ExpertWeight {
	// A missing tensor (nil) means this expert/projection is absent; error is
	// intentionally ignored since nil-checking covers the lookup failure.
	w, _ := weights.GetTensor(path + ".weight")
	if w == nil {
		return nil
	}

	// Check if this is a quantized weight by looking for scales
	scalePath := path + ".weight_scale"
	if weights.HasTensor(scalePath) {
		scales, _ := weights.GetTensor(scalePath)
		var qbiases *mlx.Array
		qbiasPath := path + ".weight_qbias"
		if weights.HasTensor(qbiasPath) {
			qbiases, _ = weights.GetTensor(qbiasPath)
		}

		// Get quantization params from metadata
		groupSize, bits, mode := getQuantParams(weights)

		// Update config with group size (for GatherQMM calls); only the first
		// detected value is recorded.
		if cfg.QuantGroupSize == 0 {
			cfg.QuantGroupSize = groupSize
		}

		// If GatherQMM is supported and requested, return quantized components
		if useQuantized && supportsGatherQMM(mode, bits) {
			return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
		}

		// Otherwise dequantize
		return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
	}

	// Not quantized: return the raw weight with zero quantization fields.
	return &ExpertWeight{Weight: w}
}

// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
// Returns embed_q and unembed_out weights for per-head projections.
//
// kv_b_proj.weight shape: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
// Output:
//   - embed_q: [num_heads, kv_lora_rank, qk_nope_head_dim] - projects q_nope to latent
//   - unembed_out: [num_heads, v_head_dim, kv_lora_rank] - projects latent to output
//
// Returns (nil, nil) when the tensor is absent.
func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
	path := prefix + ".self_attn.kv_b_proj"
	w, err := weights.GetTensor(path + ".weight")
	if err != nil || w == nil {
		return nil, nil
	}

	// Check if quantized and dequantize (absorbed projections are always
	// computed from the dequantized matrix)
	scalePath := path + ".weight_scale"
	if weights.HasTensor(scalePath) {
		scales, _ := weights.GetTensor(scalePath)
		var qbiases *mlx.Array
		qbiasPath := path + ".weight_qbias"
		if weights.HasTensor(qbiasPath) {
			qbiases, _ = weights.GetTensor(qbiasPath)
		}

		groupSize, bits, mode := getQuantParams(weights)
		w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
	}

	// w: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
	// Reshape to [num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
	headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
	w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)

	// Split into wk and wv
	// wk: [num_heads, qk_nope_head_dim, kv_lora_rank]
	// wv: [num_heads, v_head_dim, kv_lora_rank]
	wk := mlx.Slice(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
	wv := mlx.Slice(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})

	// Transform for absorbed MLA:
	// embed_q: transpose(wk) -> [num_heads, kv_lora_rank, qk_nope_head_dim]
	// This allows: q_nope @ embed_q.T = q_nope @ wk (absorbed key projection)
	embedQ := mlx.Transpose(wk, 0, 2, 1)

	// unembed_out: wv stays [num_heads, v_head_dim, kv_lora_rank]
	// This allows: latent_out @ unembed_out.T = latent_out @ wv.T (absorbed value projection)
	unembedOut := wv

	return embedQ, unembedOut
+} + +// StackedExpertWeights holds stacked weights for all experts. +type StackedExpertWeights struct { + Weight *mlx.Array // Stacked weights [num_experts, out, in] or [num_experts, out, in_packed] + Scales *mlx.Array // Stacked scales (nil if not quantized) + Biases *mlx.Array // Stacked biases (nil if not quantized) + Bits int // Quantization bits (4 or 8), 0 if not quantized + GroupSize int // Quantization group size, 0 if not quantized +} + +// collectAndStackExpertWeights loads and stacks expert weights for one projection type. +func collectAndStackExpertWeights( + weights safetensors.WeightSource, + prefix string, + projName string, + numExperts int32, + useQuantized bool, + cfg *Config, +) *StackedExpertWeights { + var w, s, b []*mlx.Array + var bits, groupSize int + + for e := int32(0); e < numExperts; e++ { + path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName) + ew := loadExpertWeight(weights, path, useQuantized, cfg) + if ew == nil { + continue + } + w = append(w, ew.Weight) + if ew.Scales != nil { + s = append(s, ew.Scales) + } + if ew.Biases != nil { + b = append(b, ew.Biases) + } + if e == 0 { + bits = ew.Bits + groupSize = ew.GroupSize + } + } + + result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize} + if len(w) > 0 { + result.Weight = mlx.Stack(w, 0) + if len(s) > 0 { + result.Scales = mlx.Stack(s, 0) + } + if len(b) > 0 { + result.Biases = mlx.Stack(b, 0) + } + } + return result +} + +// sanitizeExpertWeights stacks individual expert weights into tensors. +// If useQuantized is true and weights support GatherQMM, returns quantized components. +// Otherwise returns dequantized weights with nil scales/biases. +// Bits and GroupSize are detected per-weight to support mixed-precision (Q4 for gate/up, Q8 for down). 
func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
	gate = collectAndStackExpertWeights(weights, prefix, "gate_proj", numExperts, useQuantized, cfg)
	up = collectAndStackExpertWeights(weights, prefix, "up_proj", numExperts, useQuantized, cfg)
	down = collectAndStackExpertWeights(weights, prefix, "down_proj", numExperts, useQuantized, cfg)
	return gate, up, down
}

// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
// It reads config/tokenizer files, loads all weights, pre-computes the absorbed
// MLA projections, and stacks per-expert MoE weights.
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
	// Read config from manifest
	configData, err := manifest.ReadConfig("config.json")
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}

	var cfg Config
	if err := json.Unmarshal(configData, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}

	// Compute derived fields
	cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
	cfg.Scale = computeScale(&cfg)

	// Load weights from manifest blobs
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "")
	if err != nil {
		return nil, fmt.Errorf("load weights: %w", err)
	}

	if err := weights.Load(0); err != nil {
		return nil, fmt.Errorf("load weight data: %w", err)
	}

	// Set up quantization parameters (only if model is actually quantized)
	// Note: QuantGroupSize will be detected dynamically from tensor shapes during weight loading
	quantization := weights.Quantization()
	useQuantized := false
	if quantization != "" {
		_, cfg.QuantBits, cfg.QuantMode = safetensors.QuantizationParams(quantization)
		useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
	}

	// Load tokenizer from manifest with config files for EOS token detection
	tokData, err := manifest.ReadConfig("tokenizer.json")
	if err != nil {
		return nil, fmt.Errorf("load tokenizer config: %w", err)
	}

	// Build tokenizer config with companion files for EOS/BOS token loading
	tokConfig := &tokenizer.TokenizerConfig{
		ConfigJSON: configData, // Already loaded above, contains eos_token_id
	}

	// Try to load generation_config.json if available (preferred source for EOS)
	if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
		tokConfig.GenerationConfigJSON = genConfigData
	}

	// Try to load tokenizer_config.json if available
	if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
		tokConfig.TokenizerConfigJSON = tokConfigData
	}

	tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
	if err != nil {
		return nil, fmt.Errorf("parse tokenizer: %w", err)
	}

	m := &Model{
		Layers: make([]Block, cfg.NumHiddenLayers),
		Config: &cfg,
		tok:    tok,
	}

	// Load embedding, norm, and lm_head
	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return nil, err
	}

	// Load layers manually due to different block types
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		prefix := fmt.Sprintf("model.layers.%d", i)

		// Load attention (same for both block types)
		attn := &MLAAttention{}
		if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
			return nil, fmt.Errorf("layer %d attention: %w", i, err)
		}

		// Sanitize MLA weights for absorbed attention
		embedQ, unembedOut := sanitizeMLAWeights(weights, prefix, &cfg)
		attn.EmbedQ = nn.NewMultiLinear(embedQ)
		attn.UnembedOut = nn.NewMultiLinear(unembedOut)

		// The first FirstKDenseReplace layers use a dense MLP; the rest are MoE.
		if i < cfg.FirstKDenseReplace {
			// Dense block
			block := &DenseBlock{Attention: attn}
			if err := safetensors.LoadModule(block, weights, prefix); err != nil {
				return nil, fmt.Errorf("layer %d dense: %w", i, err)
			}
			m.Layers[i] = block
		} else {
			// MoE block
			block := &MoEBlock{Attention: attn}
			if err := safetensors.LoadModule(block, weights, prefix); err != nil {
				return nil, fmt.Errorf("layer %d moe block: %w", i, err)
			}

			// Stack expert weights (pass cfg so group sizes can be detected)
			gate, up, down := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts, useQuantized, &cfg)

			switchMLP := &SwitchMLP{UseQuantized: useQuantized}
			if useQuantized {
				switchMLP.GateWeightQ = gate.Weight
				switchMLP.GateScales = gate.Scales
				switchMLP.GateBiases = gate.Biases
				switchMLP.GateBits = gate.Bits
				switchMLP.GateGroupSize = gate.GroupSize
				switchMLP.UpWeightQ = up.Weight
				switchMLP.UpScales = up.Scales
				switchMLP.UpBiases = up.Biases
				switchMLP.UpBits = up.Bits
				switchMLP.UpGroupSize = up.GroupSize
				switchMLP.DownWeightQ = down.Weight
				switchMLP.DownScales = down.Scales
				switchMLP.DownBiases = down.Biases
				switchMLP.DownBits = down.Bits
				switchMLP.DownGroupSize = down.GroupSize
			} else {
				switchMLP.GateWeight = gate.Weight
				switchMLP.UpWeight = up.Weight
				switchMLP.DownWeight = down.Weight
			}

			block.MoE = &MoE{
				Gate:      &MoEGate{},
				SwitchMLP: switchMLP,
			}

			// Load gate weights
			if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
				return nil, fmt.Errorf("layer %d gate: %w", i, err)
			}

			// Load shared experts if present
			if cfg.NSharedExperts > 0 {
				block.MoE.SharedExperts = &SharedExperts{}
				if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
					return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
				}
			}

			m.Layers[i] = block
		}
	}

	// Materialize all lazily-constructed arrays before releasing source buffers.
	mlx.Eval(mlx.Collect(m)...)
+ weights.ReleaseAll() + + return m, nil +} + +// Forward computes the forward pass of the model +func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + B, L := tokens.Shape()[0], tokens.Shape()[1] + + h := m.EmbedTokens.Forward(tokens) + + for i, layer := range m.Layers { + var c cache.Cache + if caches != nil { + c = caches[i] + } + h = layer.Forward(h, c, B, L, m.Config) + } + + h = m.Norm.Forward(h, m.RMSNormEps) + return m.LMHead.Forward(h) +} + +// Interface methods + +// NumLayers returns the number of transformer layers +func (m *Model) NumLayers() int { return len(m.Layers) } + +// MaxContextLength returns the maximum context length +func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings } + +// VocabSize returns the vocabulary size +func (m *Model) VocabSize() int32 { return m.Config.VocabSize } + +// Tokenizer returns the model's tokenizer +func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok } + +// NewCache creates a new KV cache for the model +func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { + caches := make([]cache.Cache, len(m.Layers)) + for i := range caches { + caches[i] = cache.NewKVCache() + } + return caches +} + +// FormatPrompt applies the GLM-4 chat template with thinking enabled by default. +// This follows the GLM-4.7 format with tag for reasoning mode. +func (m *Model) FormatPrompt(prompt string) string { + return "[gMASK]<|user|>" + prompt + "<|assistant|>" +} + +// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control. +// When think is true, the prompt ends with to enable reasoning mode. +// When think is false, the prompt ends with to skip reasoning. +func (m *Model) FormatPromptWithThinking(prompt string, think bool) string { + if think { + return "[gMASK]<|user|>" + prompt + "<|assistant|>" + } + return "[gMASK]<|user|>" + prompt + "<|assistant|>" +} + +// NewRenderer returns a new Renderer for formatting multi-turn conversations. 
+func (m *Model) NewRenderer() *Renderer { + return &Renderer{} +} + +// NewParser returns a new Parser for extracting thinking and tool calls from output. +func (m *Model) NewParser() *Parser { + return &Parser{} +} diff --git a/x/imagegen/models/glm4_moe_lite/parser.go b/x/imagegen/models/glm4_moe_lite/parser.go new file mode 100644 index 000000000..c81ec5a40 --- /dev/null +++ b/x/imagegen/models/glm4_moe_lite/parser.go @@ -0,0 +1,479 @@ +//go:build mlx + +package glm4_moe_lite + +import ( + "context" + "encoding/json" + "encoding/xml" + "fmt" + "log/slog" + "strings" + "unicode" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/logutil" +) + +type parserState int + +const ( + parserState_LookingForThinkingOpen parserState = iota + parserState_ThinkingStartedEatingWhitespace + parserState_CollectingThinking + parserState_ThinkingDoneEatingWhitespace + parserState_CollectingContent + parserState_ToolStartedEatingWhitespace + parserState_CollectingToolContent +) + +const ( + thinkingOpenTag = "" + thinkingCloseTag = "" + toolOpenTag = "" + toolCloseTag = "" +) + +// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls. +// GLM-4's prompt ends with when thinking is enabled, so the parser +// must start in CollectingThinking state (the model outputs thinking content directly). +type Parser struct { + state parserState + buffer strings.Builder + tools []api.Tool +} + +// HasToolSupport returns true as GLM4 supports tool calling. +func (p *Parser) HasToolSupport() bool { + return true +} + +// HasThinkingSupport returns true as GLM4 supports thinking mode. +func (p *Parser) HasThinkingSupport() bool { + return true +} + +// Init initializes the parser with tools and thinking configuration. 
+func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { + p.tools = tools + // When thinking is enabled (nil or true), the prompt ends with , + // so model output starts directly with thinking content (no opening tag). + if thinkValue == nil || thinkValue.Bool() { + p.state = parserState_CollectingThinking + } + return tools +} + +type parserEvent interface { + isParserEvent() +} + +type eventContent struct { + content string +} + +func (eventContent) isParserEvent() {} + +type eventRawToolCall struct { + raw string +} + +func (eventRawToolCall) isParserEvent() {} + +type eventThinkingContent struct { + content string +} + +func (eventThinkingContent) isParserEvent() {} + +// Add processes new output text and returns parsed content, thinking, and tool calls. +func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.buffer.WriteString(s) + events := p.parseEvents() + + var toolCalls []api.ToolCall + var contentSb strings.Builder + var thinkingSb strings.Builder + + for _, event := range events { + switch event := event.(type) { + case eventRawToolCall: + toolCall, err := parseToolCall(event, p.tools) + if err != nil { + slog.Warn("glm-4 tool call parsing failed", "error", err) + return "", "", nil, err + } + toolCalls = append(toolCalls, toolCall) + case eventThinkingContent: + thinkingSb.WriteString(event.content) + case eventContent: + contentSb.WriteString(event.content) + } + } + + return contentSb.String(), thinkingSb.String(), toolCalls, nil +} + +func (p *Parser) parseEvents() []parserEvent { + var all []parserEvent + + keepLooping := true + for keepLooping { + var events []parserEvent + events, keepLooping = p.eat() + if len(events) > 0 { + all = append(all, events...) 
+ } + } + + if len(all) > 0 { + slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String()) + } + + return all +} + +// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer +// and transitions to the next state. Returns (nil, false) if only whitespace remains +// in the buffer (needs more input), or (nil, true) if we successfully transitioned. +func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) { + trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace) + p.buffer.Reset() + if trimmed == "" { + return nil, false // Still only whitespace, keep waiting for more input + } + p.state = nextState + p.buffer.WriteString(trimmed) + return nil, true // Successfully transitioned +} + +// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace), +// the content after (optionally trimmed of leading whitespace), and updates the buffer +func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) { + split := strings.SplitN(p.buffer.String(), tag, 2) + before := split[0] + before = strings.TrimRightFunc(before, unicode.IsSpace) + after := split[1] + if trimAfter { + after = strings.TrimLeftFunc(after, unicode.IsSpace) + } + p.buffer.Reset() + p.buffer.WriteString(after) + return before, after +} + +func (p *Parser) eat() ([]parserEvent, bool) { + var events []parserEvent + + switch p.state { + case parserState_LookingForThinkingOpen: + trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace) + if strings.HasPrefix(trimmed, thinkingOpenTag) { + // Found opening tag + after := strings.TrimPrefix(trimmed, thinkingOpenTag) + after = strings.TrimLeftFunc(after, unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(after) + if after == "" { + p.state = parserState_ThinkingStartedEatingWhitespace + } else { + p.state = parserState_CollectingThinking + } + 
return events, true + } else if strings.HasPrefix(thinkingOpenTag, trimmed) { + // Partial opening tag seen, keep accumulating + return events, false + } else if trimmed == "" { + // Only whitespace, keep accumulating + return events, false + } else { + // No thinking tag found, skip to content collection + p.state = parserState_CollectingContent + // Don't trim - we want to keep the original content + return events, true + } + + case parserState_ThinkingStartedEatingWhitespace: + return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking) + + case parserState_CollectingThinking: + acc := p.buffer.String() + if strings.Contains(acc, thinkingCloseTag) { + thinking, remaining := p.splitAtTag(thinkingCloseTag, true) + if len(thinking) > 0 { + events = append(events, eventThinkingContent{content: thinking}) + } + if remaining == "" { + p.state = parserState_ThinkingDoneEatingWhitespace + } else { + p.state = parserState_CollectingContent + } + return events, true + } else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 { + // Partial closing tag - withhold it along with any trailing whitespace before it + beforePartialTag := acc[:len(acc)-overlapLen] + trailingWsLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWsLen + + unambiguous := acc[:ambiguousStart] + ambiguous := acc[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, eventThinkingContent{content: unambiguous}) + } + return events, false + } else { + // Pure thinking content - withhold trailing whitespace (might precede closing tag) + whitespaceLen := trailingWhitespaceLen(acc) + ambiguousStart := len(acc) - whitespaceLen + + unambiguous := acc[:ambiguousStart] + ambiguous := acc[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, eventThinkingContent{content: unambiguous}) + } + return 
events, false + } + + case parserState_ThinkingDoneEatingWhitespace: + return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent) + + case parserState_CollectingContent: + if strings.Contains(p.buffer.String(), toolOpenTag) { + before, after := p.splitAtTag(toolOpenTag, true) + if len(before) > 0 { + events = append(events, eventContent{content: before}) + } + if after == "" { + p.state = parserState_ToolStartedEatingWhitespace + } else { + p.state = parserState_CollectingToolContent + } + return events, true + } else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 { + beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen] + trailingWsLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWsLen + + unambiguous := p.buffer.String()[:ambiguousStart] + ambiguous := p.buffer.String()[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, eventContent{content: unambiguous}) + } + return events, false + } else { + whitespaceLen := trailingWhitespaceLen(p.buffer.String()) + ambiguousStart := len(p.buffer.String()) - whitespaceLen + + unambiguous := p.buffer.String()[:ambiguousStart] + ambiguous := p.buffer.String()[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, eventContent{content: unambiguous}) + } + return events, false + } + + case parserState_ToolStartedEatingWhitespace: + return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent) + + case parserState_CollectingToolContent: + acc := p.buffer.String() + if strings.Contains(acc, toolCloseTag) { + toolContent, _ := p.splitAtTag(toolCloseTag, true) + if len(toolContent) == 0 { + slog.Warn("glm4 tool call closing tag found but no content before it") + } + events = append(events, eventRawToolCall{raw: toolContent}) + p.state = parserState_CollectingContent 
+ return events, true + } else { + // Keep accumulating - tool calls are not streamed + // We just wait for the closing tag + return events, false + } + + default: + panic("unreachable") + } +} + +// overlap returns the length of the overlap between the end of s and the start of tag. +func overlap(s, tag string) int { + for i := 1; i <= len(tag) && i <= len(s); i++ { + if strings.HasSuffix(s, tag[:i]) { + return i + } + } + return 0 +} + +// trailingWhitespaceLen returns the length of trailing whitespace in s. +func trailingWhitespaceLen(s string) int { + trimmed := strings.TrimRightFunc(s, unicode.IsSpace) + return len(s) - len(trimmed) +} + +// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing +type ToolCallXML struct { + XMLName xml.Name `xml:"tool_call"` + Content string `xml:",chardata"` // Function name (text nodes between tags) + Keys []string `xml:"arg_key"` // All arg_key elements in document order + Values []string `xml:"arg_value"` // All arg_value elements in document order +} + +// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags +func escapeContent(s string) string { + var result strings.Builder + inTag := false + + for i := range len(s) { + ch := s[i] + + if ch == '<' { + // Check if this is a known tag + if strings.HasPrefix(s[i:], "") || + strings.HasPrefix(s[i:], "") || + strings.HasPrefix(s[i:], "") || + strings.HasPrefix(s[i:], "") { + inTag = true + } + } + + if inTag { + result.WriteByte(ch) + if ch == '>' { + inTag = false + } + } else { + // Escape special characters in text content + switch ch { + case '&': + result.WriteString("&") + case '<': + result.WriteString("<") + case '>': + result.WriteString(">") + default: + result.WriteByte(ch) + } + } + } + + return result.String() +} + +func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) { + // Escape any unescaped entities in text content + escaped := escapeContent(raw.raw) + + // Wrap the content 
in a root element to make it valid XML + xmlString := "" + escaped + "" + + // Parse XML into struct + var parsed ToolCallXML + if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil { + return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err) + } + + // Extract and trim function name + functionName := strings.TrimSpace(parsed.Content) + if functionName == "" { + return api.ToolCall{}, fmt.Errorf("empty function name") + } + + // Verify keys and values are paired correctly + if len(parsed.Keys) != len(parsed.Values) { + return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values)) + } + + // Find the matching tool to get parameter types + var matchedTool *api.Tool + for i := range tools { + if tools[i].Function.Name == functionName { + matchedTool = &tools[i] + break + } + } + + // Build arguments map by pairing keys and values + toolCall := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: functionName, + Arguments: api.NewToolCallFunctionArguments(), + }, + } + + for i := range parsed.Keys { + key := strings.TrimSpace(parsed.Keys[i]) + value := parsed.Values[i] // Don't trim here - parseValue handles it + + // Look up parameter type + var paramType api.PropertyType + if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil { + if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok { + // Handle anyOf by collecting all types from the union + if len(prop.AnyOf) > 0 { + for _, anyOfProp := range prop.AnyOf { + paramType = append(paramType, anyOfProp.Type...) + } + } else { + paramType = prop.Type + } + } + } + + // Parse value with type coercion + toolCall.Function.Arguments.Set(key, parseValue(value, paramType)) + } + + return toolCall, nil +} + +// parseValue parses a string value and coerces it to the appropriate type based on paramType. 
+func parseValue(value string, paramType api.PropertyType) any { + value = strings.TrimSpace(value) + + // If no type specified, return as string + if len(paramType) == 0 { + return value + } + + // Try to parse based on specified types + for _, t := range paramType { + switch t { + case "boolean": + if value == "true" { + return true + } + if value == "false" { + return false + } + case "integer": + var i int64 + if _, err := fmt.Sscanf(value, "%d", &i); err == nil { + return i + } + case "number": + var f float64 + if _, err := fmt.Sscanf(value, "%f", &f); err == nil { + return f + } + case "array", "object": + // Try to parse as JSON + var result any + if err := json.Unmarshal([]byte(value), &result); err == nil { + return result + } + } + } + + // Default to string + return value +} diff --git a/x/imagegen/models/glm4_moe_lite/parser_test.go b/x/imagegen/models/glm4_moe_lite/parser_test.go new file mode 100644 index 000000000..0ce382709 --- /dev/null +++ b/x/imagegen/models/glm4_moe_lite/parser_test.go @@ -0,0 +1,192 @@ +//go:build mlx + +package glm4_moe_lite + +import ( + "testing" + + "github.com/ollama/ollama/api" +) + +func TestParserThinking(t *testing.T) { + tests := []struct { + name string + input string + thinkEnabled bool + wantContent string + wantThinking string + wantToolCalls int + }{ + { + name: "thinking enabled - simple thinking then content", + input: "Let me think about this...Here is my answer.", + thinkEnabled: true, + wantThinking: "Let me think about this...", + wantContent: "Here is my answer.", + }, + { + name: "thinking enabled - only thinking", + input: "I need to consider multiple factors...", + thinkEnabled: true, + wantThinking: "I need to consider multiple factors...", + wantContent: "", + }, + { + name: "thinking disabled - direct content", + input: "Here is my direct answer.", + thinkEnabled: false, + wantThinking: "", + wantContent: "Here is my direct answer.", + }, + { + name: "thinking with tool call", + input: "Let me 
search for that...I'll use a tool.searchquerytest", + thinkEnabled: true, + wantThinking: "Let me search for that...", + wantContent: "I'll use a tool.", + wantToolCalls: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Parser{} + + var thinkValue *api.ThinkValue + if tt.thinkEnabled { + thinkValue = &api.ThinkValue{Value: true} + } else { + thinkValue = &api.ThinkValue{Value: false} + } + + // Define tools for tool call tests + props := api.NewToolPropertiesMap() + props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}}) + tools := []api.Tool{ + { + Function: api.ToolFunction{ + Name: "search", + Parameters: api.ToolFunctionParameters{ + Properties: props, + }, + }, + }, + } + + p.Init(tools, nil, thinkValue) + + content, thinking, calls, err := p.Add(tt.input, true) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if thinking != tt.wantThinking { + t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking) + } + if content != tt.wantContent { + t.Errorf("content = %q, want %q", content, tt.wantContent) + } + if len(calls) != tt.wantToolCalls { + t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls) + } + }) + } +} + +func TestParserToolCall(t *testing.T) { + p := &Parser{} + + props := api.NewToolPropertiesMap() + props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}}) + props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}}) + tools := []api.Tool{ + { + Function: api.ToolFunction{ + Name: "get_weather", + Parameters: api.ToolFunctionParameters{ + Properties: props, + }, + }, + }, + } + + // Initialize with thinking disabled + tv := &api.ThinkValue{Value: false} + p.Init(tools, nil, tv) + + input := "get_weatherlocationSan Franciscounitcelsius" + + _, _, calls, err := p.Add(input, true) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(calls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(calls)) + } + + call := 
calls[0] + if call.Function.Name != "get_weather" { + t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather") + } + + location, ok := call.Function.Arguments.Get("location") + if !ok || location != "San Francisco" { + t.Errorf("location = %v, want %q", location, "San Francisco") + } + + unit, ok := call.Function.Arguments.Get("unit") + if !ok || unit != "celsius" { + t.Errorf("unit = %v, want %q", unit, "celsius") + } +} + +func TestOverlap(t *testing.T) { + tests := []struct { + s string + tag string + want int + }{ + {"hello<", "", 1}, + {"hello", 2}, + {"hello", 3}, + {"hello", 4}, + {"hello", 5}, + {"hello", 6}, + {"hello", 7}, + {"hello", "", 8}, // Complete tag at end returns full length + {"hello", "", 0}, + {"", "", 0}, + } + + for _, tt := range tests { + t.Run(tt.s+"_"+tt.tag, func(t *testing.T) { + got := overlap(tt.s, tt.tag) + if got != tt.want { + t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want) + } + }) + } +} + +func TestTrailingWhitespaceLen(t *testing.T) { + tests := []struct { + s string + want int + }{ + {"hello ", 3}, + {"hello\n\t ", 3}, + {"hello", 0}, + {"", 0}, + {" ", 3}, + } + + for _, tt := range tests { + t.Run(tt.s, func(t *testing.T) { + got := trailingWhitespaceLen(tt.s) + if got != tt.want { + t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want) + } + }) + } +} diff --git a/x/imagegen/models/glm4_moe_lite/render.go b/x/imagegen/models/glm4_moe_lite/render.go new file mode 100644 index 000000000..4998604bf --- /dev/null +++ b/x/imagegen/models/glm4_moe_lite/render.go @@ -0,0 +1,175 @@ +//go:build mlx + +package glm4_moe_lite + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/ollama/ollama/api" +) + +// Renderer renders messages for GLM4-MoE-Lite models. +// +// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode): +// +// 1. INTERLEAVED THINKING +// The model thinks between tool calls and after receiving tool results. 
+// This enables complex step-by-step reasoning: interpreting each tool output +// before deciding what to do next. Thinking blocks are preserved and returned +// with tool results to maintain reasoning continuity. +// +// 2. PRESERVED THINKING +// The model retains reasoning content from previous assistant turns in context. +// This preserves reasoning continuity across multi-turn conversations. The +// upstream API has a "clear_thinking" parameter to control this: +// - clear_thinking=true: clears reasoning from previous turns (outputs ) +// - clear_thinking=false: preserves ... blocks from previous turns +// +// 3. TURN-LEVEL THINKING +// Controls whether the model should reason on each turn. The upstream API +// uses "enable_thinking" parameter: +// - enable_thinking=true: outputs to start reasoning +// - enable_thinking=false: outputs to skip reasoning +// +// OLLAMA DEFAULTS: +// - Thinking is ENABLED by default (thinkValue=nil or true outputs ) +// - Thinking is PRESERVED by default (reasoning content from previous turns is always +// included in ... blocks, equivalent to clear_thinking=false) +// - Users can disable thinking per-turn via thinkValue=false +type Renderer struct{} + +// Render renders messages into the GLM4 chat format. 
+func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) { + var sb strings.Builder + + sb.WriteString("[gMASK]") + + if len(tools) > 0 { + sb.WriteString("<|system|>\n") + sb.WriteString("# Tools\n\n") + sb.WriteString("You may call one or more functions to assist with the user query.\n\n") + sb.WriteString("You are provided with function signatures within XML tags:\n") + sb.WriteString("\n") + for _, tool := range tools { + d, _ := json.Marshal(tool) + sb.WriteString(formatToolJSON(d)) + sb.WriteString("\n") + } + sb.WriteString("\n\n") + sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n") + sb.WriteString("{function-name}{arg-key-1}{arg-value-1}{arg-key-2}{arg-value-2}...") + } + + think := true + if thinkValue != nil && !thinkValue.Bool() { + think = false + } + + for i, message := range messages { + switch message.Role { + case "user": + sb.WriteString("<|user|>") + sb.WriteString(message.Content) + case "assistant": + sb.WriteString("<|assistant|>") + if message.Thinking != "" { + sb.WriteString("" + message.Thinking + "") + } else { + sb.WriteString("") + } + if message.Content != "" { + sb.WriteString(message.Content) + } + if len(message.ToolCalls) > 0 { + for _, toolCall := range message.ToolCalls { + sb.WriteString("" + toolCall.Function.Name) + sb.WriteString(renderToolArguments(toolCall.Function.Arguments)) + sb.WriteString("") + } + } + case "tool": + if i == 0 || messages[i-1].Role != "tool" { + sb.WriteString("<|observation|>") + } + sb.WriteString("") + sb.WriteString(message.Content) + sb.WriteString("") + case "system": + sb.WriteString("<|system|>") + sb.WriteString(message.Content) + } + } + + sb.WriteString("<|assistant|>") + if think { + sb.WriteString("") + } else { + sb.WriteString("") + } + + return sb.String(), nil +} + +// renderToolArguments converts tool call arguments to GLM4 XML format. 
+func renderToolArguments(args api.ToolCallFunctionArguments) string { + var sb strings.Builder + for key, value := range args.All() { + sb.WriteString("" + key + "") + var valueStr string + if str, ok := value.(string); ok { + valueStr = str + } else { + jsonBytes, err := json.Marshal(value) + if err != nil { + valueStr = fmt.Sprintf("%v", value) + } else { + valueStr = string(jsonBytes) + } + } + + sb.WriteString("" + valueStr + "") + } + + return sb.String() +} + +// formatToolJSON formats JSON for GLM4 tool definitions by adding spaces after : and , +func formatToolJSON(raw []byte) string { + var sb strings.Builder + sb.Grow(len(raw) + len(raw)/10) + + inString := false + escaped := false + for i := range raw { + ch := raw[i] + sb.WriteByte(ch) + + if inString { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + inString = false + } + continue + } + + if ch == '"' { + inString = true + continue + } + + if ch == ':' || ch == ',' { + sb.WriteByte(' ') + } + } + + return sb.String() +} diff --git a/x/imagegen/models/glm4_moe_lite/render_test.go b/x/imagegen/models/glm4_moe_lite/render_test.go new file mode 100644 index 000000000..f0d576bec --- /dev/null +++ b/x/imagegen/models/glm4_moe_lite/render_test.go @@ -0,0 +1,205 @@ +//go:build mlx + +package glm4_moe_lite + +import ( + "strings" + "testing" + + "github.com/ollama/ollama/api" +) + +func TestRendererSimple(t *testing.T) { + r := &Renderer{} + + messages := []api.Message{ + {Role: "user", Content: "Hello"}, + } + + // Thinking enabled (default) + result, err := r.Render(messages, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := "[gMASK]<|user|>Hello<|assistant|>" + if result != expected { + t.Errorf("result = %q, want %q", result, expected) + } +} + +func TestRendererThinkingDisabled(t *testing.T) { + r := &Renderer{} + + messages := []api.Message{ + {Role: "user", Content: "Hello"}, + } + + tv := 
&api.ThinkValue{Value: false} + + result, err := r.Render(messages, nil, tv) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := "[gMASK]<|user|>Hello<|assistant|>" + if result != expected { + t.Errorf("result = %q, want %q", result, expected) + } +} + +func TestRendererMultiTurn(t *testing.T) { + r := &Renderer{} + + messages := []api.Message{ + {Role: "user", Content: "What is 2+2?"}, + {Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"}, + {Role: "user", Content: "And 3+3?"}, + } + + result, err := r.Render(messages, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Check key parts + if !strings.Contains(result, "[gMASK]") { + t.Error("missing [gMASK] prefix") + } + if !strings.Contains(result, "<|user|>What is 2+2?") { + t.Error("missing first user message") + } + if !strings.Contains(result, "<|assistant|>Let me calculate: 2+2=44") { + t.Error("missing assistant message with thinking") + } + if !strings.Contains(result, "<|user|>And 3+3?") { + t.Error("missing second user message") + } + if !strings.HasSuffix(result, "<|assistant|>") { + t.Errorf("should end with <|assistant|>, got suffix: %q", result[len(result)-30:]) + } +} + +func TestRendererWithSystem(t *testing.T) { + r := &Renderer{} + + messages := []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello"}, + } + + result, err := r.Render(messages, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !strings.Contains(result, "<|system|>You are a helpful assistant.") { + t.Error("missing system message") + } +} + +func TestRendererWithTools(t *testing.T) { + r := &Renderer{} + + messages := []api.Message{ + {Role: "user", Content: "What's the weather?"}, + } + + props := api.NewToolPropertiesMap() + props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"}) + tools := []api.Tool{ + { + Function: api.ToolFunction{ 
+ Name: "get_weather", + Description: "Get the weather for a location", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: props, + Required: []string{"location"}, + }, + }, + }, + } + + result, err := r.Render(messages, tools, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Check for tool system prompt + if !strings.Contains(result, "<|system|>") { + t.Error("missing system tag for tools") + } + if !strings.Contains(result, "# Tools") { + t.Error("missing tools header") + } + if !strings.Contains(result, "") { + t.Error("missing tools tag") + } + if !strings.Contains(result, "get_weather") { + t.Error("missing tool name") + } + if !strings.Contains(result, "") { + t.Error("missing closing tools tag") + } +} + +func TestRendererWithToolCalls(t *testing.T) { + r := &Renderer{} + + args := api.NewToolCallFunctionArguments() + args.Set("location", "San Francisco") + + messages := []api.Message{ + {Role: "user", Content: "What's the weather in SF?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: args, + }, + }, + }, + }, + {Role: "tool", Content: "Sunny, 72F"}, + } + + result, err := r.Render(messages, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !strings.Contains(result, "get_weather") { + t.Error("missing tool call") + } + if !strings.Contains(result, "location") { + t.Error("missing arg_key") + } + if !strings.Contains(result, "San Francisco") { + t.Error("missing arg_value") + } + if !strings.Contains(result, "") { + t.Error("missing tool call closing tag") + } + if !strings.Contains(result, "<|observation|>") { + t.Error("missing observation tag") + } + if !strings.Contains(result, "Sunny, 72F") { + t.Error("missing tool response") + } +} + +func TestFormatToolJSON(t *testing.T) { + input := []byte(`{"name":"test","value":123}`) + result := formatToolJSON(input) + + // Should add spaces after : and , + 
if !strings.Contains(result, ": ") { + t.Error("should add space after colon") + } + if !strings.Contains(result, ", ") { + t.Error("should add space after comma") + } +} diff --git a/x/imagegen/nn/nn.go b/x/imagegen/nn/nn.go index 65bf7fa22..d72474358 100644 --- a/x/imagegen/nn/nn.go +++ b/x/imagegen/nn/nn.go @@ -32,10 +32,16 @@ func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear { // NewQuantizedLinear creates a quantized linear layer directly from bf16 weights. // Quantizes the weight immediately and evaluates to break lazy dependencies. +// Note: For modes like "nvfp4", qbiases will be nil. func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear { qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode) // Eval immediately so bf16 weight can be freed - mlx.Eval(qw, scales, qbiases) + // Handle modes that don't return qbiases (e.g., nvfp4) + if qbiases != nil { + mlx.Eval(qw, scales, qbiases) + } else { + mlx.Eval(qw, scales) + } return &QuantizedLinear{ Weight: qw, Scales: scales, @@ -77,10 +83,13 @@ func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear // QuantizedLinear applies an affine transformation using quantized weights. // Equivalent to mlx.nn.QuantizedLinear. +// Supports multiple quantization modes: +// - "affine": scale + zero-point bias (QBiases required) +// - "nvfp4": NVIDIA FP4 with E4M3 scales (QBiases nil) type QuantizedLinear struct { Weight *mlx.Array // Quantized weight data Scales *mlx.Array // Scale factors for dequantization - QBiases *mlx.Array // Quantization biases (NOT layer bias) + QBiases *mlx.Array // Quantization biases (NOT layer bias), nil for nvfp4 Bias *mlx.Array // Layer bias [output_dims] or nil GroupSize int Bits int @@ -220,3 +229,32 @@ func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array { } return out } + +// MultiLinearLayer is an interface for per-head linear layers. 
+// This allows swapping between MultiLinear (bf16) and pre-dequantized weights. +type MultiLinearLayer interface { + Forward(x *mlx.Array) *mlx.Array +} + +// MultiLinear performs per-head linear projections. +// Weight shape: [num_heads, output_dims, input_dims] +// Input shape: [B, num_heads, L, input_dims] +// Output shape: [B, num_heads, L, output_dims] +type MultiLinear struct { + Weight *mlx.Array `weight:"weight"` +} + +// NewMultiLinear creates a MultiLinear layer with the given weight. +func NewMultiLinear(weight *mlx.Array) *MultiLinear { + return &MultiLinear{Weight: weight} +} + +// Forward applies per-head linear transformation: x @ weight.T per head via broadcasting. +func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array { + // Weight: [num_heads, output_dims, input_dims] + // x: [B, num_heads, L, input_dims] + // wT: [num_heads, input_dims, output_dims] + // Result: [B, num_heads, L, output_dims] + wT := mlx.Transpose(ml.Weight, 0, 2, 1) + return mlx.Matmul(x, wT) +} diff --git a/x/imagegen/runner/runner.go b/x/imagegen/runner/runner.go deleted file mode 100644 index f43276468..000000000 --- a/x/imagegen/runner/runner.go +++ /dev/null @@ -1,284 +0,0 @@ -//go:build mlx - -// Package runner provides a subprocess server for image generation. -// It listens on a port and handles HTTP requests for image generation. 
-package runner - -import ( - "context" - "encoding/json" - "flag" - "fmt" - "image" - "log/slog" - "net/http" - "os" - "os/signal" - "sync" - "syscall" - "time" - - "github.com/ollama/ollama/x/imagegen" - "github.com/ollama/ollama/x/imagegen/mlx" - "github.com/ollama/ollama/x/imagegen/models/flux2" - "github.com/ollama/ollama/x/imagegen/models/zimage" -) - -// Request is the image generation request format -type Request struct { - Prompt string `json:"prompt"` - Width int32 `json:"width,omitempty"` - Height int32 `json:"height,omitempty"` - Steps int `json:"steps,omitempty"` - Seed int64 `json:"seed,omitempty"` - Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning -} - -// Response is streamed back for each progress update -type Response struct { - Content string `json:"content,omitempty"` - Image string `json:"image,omitempty"` // Base64-encoded PNG - Done bool `json:"done"` - Step int `json:"step,omitempty"` - Total int `json:"total,omitempty"` -} - -// ImageModel is the interface for image generation models -type ImageModel interface { - GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) -} - -// ImageEditModel extends ImageModel with image editing/conditioning capability. -// Models that support input images for editing should implement this interface. 
-type ImageEditModel interface { - ImageModel - GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) -} - -// Server holds the model and handles requests -type Server struct { - mu sync.Mutex - model ImageModel - modelName string -} - -// Execute is the entry point for the image runner subprocess -func Execute(args []string) error { - fs := flag.NewFlagSet("image-runner", flag.ExitOnError) - modelName := fs.String("model", "", "path to image model") - port := fs.Int("port", 0, "port to listen on") - - if err := fs.Parse(args); err != nil { - return err - } - - if *modelName == "" { - return fmt.Errorf("--model is required") - } - if *port == 0 { - return fmt.Errorf("--port is required") - } - - err := mlx.InitMLX() - if err != nil { - slog.Error("unable to initialize MLX", "error", err) - return err - } - slog.Info("MLX library initialized") - slog.Info("starting image runner", "model", *modelName, "port", *port) - - // Detect model type and load appropriate model - modelType := imagegen.DetectModelType(*modelName) - slog.Info("detected model type", "type", modelType) - - var model ImageModel - switch modelType { - case "Flux2KleinPipeline": - m := &flux2.Model{} - if err := m.Load(*modelName); err != nil { - return fmt.Errorf("failed to load model: %w", err) - } - model = m - default: - // Default to Z-Image for ZImagePipeline, FluxPipeline, etc. 
- m := &zimage.Model{} - if err := m.Load(*modelName); err != nil { - return fmt.Errorf("failed to load model: %w", err) - } - model = m - } - - server := &Server{ - model: model, - modelName: *modelName, - } - - // Set up HTTP handlers - mux := http.NewServeMux() - mux.HandleFunc("/health", server.healthHandler) - mux.HandleFunc("/completion", server.completionHandler) - - httpServer := &http.Server{ - Addr: fmt.Sprintf("127.0.0.1:%d", *port), - Handler: mux, - } - - // Handle shutdown - done := make(chan struct{}) - go func() { - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - <-sigCh - slog.Info("shutting down image runner") - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - httpServer.Shutdown(ctx) - close(done) - }() - - slog.Info("image runner listening", "addr", httpServer.Addr) - if err := httpServer.ListenAndServe(); err != http.ErrServerClosed { - return err - } - - <-done - return nil -} - -func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) -} - -func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - - var req Request - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - // Validate and decode input images - const maxInputImages = 2 - if len(req.Images) > maxInputImages { - http.Error(w, fmt.Sprintf("too many input images, maximum is %d", maxInputImages), http.StatusBadRequest) - return - } - - var inputImages []image.Image - if len(req.Images) > 0 { - // TODO: add memory check for input images - - inputImages = make([]image.Image, len(req.Images)) - for i, imgBytes := range req.Images { - img, err := imagegen.DecodeImage(imgBytes) - if err != 
nil { - http.Error(w, fmt.Sprintf("invalid image %d: %v", i, err), http.StatusBadRequest) - return - } - inputImages[i] = img - } - slog.Info("decoded input images", "count", len(inputImages)) - - // Default width/height to first input image dimensions, scaled to max 1024 - bounds := inputImages[0].Bounds() - w, h := bounds.Dx(), bounds.Dy() - if w > 1024 || h > 1024 { - if w > h { - h = h * 1024 / w - w = 1024 - } else { - w = w * 1024 / h - h = 1024 - } - } - req.Width = int32(w) - req.Height = int32(h) - } - - // Serialize generation requests - MLX model may not handle concurrent generation - s.mu.Lock() - defer s.mu.Unlock() - - // Model applies its own defaults for width/height/steps - // Only seed needs to be set here if not provided - if req.Seed <= 0 { - req.Seed = time.Now().UnixNano() - } - - // Set up streaming response - w.Header().Set("Content-Type", "application/x-ndjson") - w.Header().Set("Transfer-Encoding", "chunked") - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "streaming not supported", http.StatusInternalServerError) - return - } - - // Generate image using the common interface - ctx := r.Context() - enc := json.NewEncoder(w) - - // Progress callback streams step updates - progress := func(step, total int) { - resp := Response{Step: step, Total: total} - enc.Encode(resp) - w.Write([]byte("\n")) - flusher.Flush() - } - - // Use ImageEditModel if available and images provided, otherwise use basic ImageModel - var img *mlx.Array - var err error - if len(inputImages) > 0 { - editModel, ok := s.model.(ImageEditModel) - if !ok { - http.Error(w, "model does not support image editing", http.StatusBadRequest) - return - } - img, err = editModel.GenerateImageWithInputs(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, inputImages, progress) - } else { - img, err = s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress) - } - - if err != nil { - // Don't send error for cancellation - if 
ctx.Err() != nil { - return - } - resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true} - data, _ := json.Marshal(resp) - w.Write(data) - w.Write([]byte("\n")) - return - } - - // Encode image as base64 PNG - imageData, err := imagegen.EncodeImageBase64(img) - if err != nil { - resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true} - data, _ := json.Marshal(resp) - w.Write(data) - w.Write([]byte("\n")) - return - } - - // Free the generated image array and clean up MLX state - img.Free() - mlx.ClearCache() - mlx.MetalResetPeakMemory() - - // Send final response with image data - resp := Response{ - Image: imageData, - Done: true, - } - data, _ := json.Marshal(resp) - w.Write(data) - w.Write([]byte("\n")) - flusher.Flush() -} diff --git a/x/imagegen/safetensors/loader.go b/x/imagegen/safetensors/loader.go index 7f8860b06..4c1d0a9af 100644 --- a/x/imagegen/safetensors/loader.go +++ b/x/imagegen/safetensors/loader.go @@ -17,17 +17,31 @@ type WeightSource interface { GetTensor(name string) (*mlx.Array, error) ListTensors() []string HasTensor(name string) bool - Quantization() string // Returns "FP4", "FP8", or "" + Quantization() string // Returns "NVFP4", "Q4", "Q8", or "" + GroupSize() int // Returns quantization group size, or 0 if not specified } -// quantizationParams returns groupSize, bits, mode for a quantization type. -// Returns defaults (32, 8, "affine") for unknown types (backward compatibility). -func quantizationParams(quantization string) (groupSize, bits int, mode string) { +// QuantizationParams returns groupSize, bits, mode for a quantization type. 
+// MLX quantization modes: +// - "affine": scale + zero-point bias, group_size=32/64/128 +// - "nvfp4": NVIDIA FP4 with E4M3 scales, group_size=16 (no bias) +// - "mxfp8": Microsoft MX FP8 with E4M3 scales, group_size=32 (no bias) +func QuantizationParams(quantization string) (groupSize, bits int, mode string) { switch strings.ToUpper(quantization) { - case "FP4": + case "NVFP4": + // NVIDIA FP4: group_size=16, bits=4, E4M3 scales (no qbias) + return 16, 4, "nvfp4" + case "FP4", "Q4", "INT4": + // 4-bit quantization with affine mode (scale + qbias) return 32, 4, "affine" + case "MXFP8": + // Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbias) + return 32, 8, "mxfp8" + case "FP8", "Q8", "INT8", "": + // 8-bit quantization with affine mode (default for quantized models) + return 64, 8, "affine" default: - return 32, 8, "affine" // FP8 or unknown + return 32, 8, "affine" // Default to affine } } @@ -122,7 +136,8 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st } // Handle nn.LinearLayer interface fields specially - if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() { + linearLayerType := reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() + if field.Type == linearLayerType { if !hasTag { continue // no tag = skip } @@ -137,6 +152,23 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st continue } + // Handle nn.MultiLinearLayer interface fields specially + multiLinearLayerType := reflect.TypeOf((*nn.MultiLinearLayer)(nil)).Elem() + if field.Type == multiLinearLayerType { + if !hasTag { + continue // no tag = skip + } + layer, err := LoadMultiLinearLayer(weights, fullPath) + if err != nil { + if !optional { + *errs = append(*errs, fullPath+": "+err.Error()) + } + continue + } + fieldVal.Set(reflect.ValueOf(layer)) + continue + } + // Handle by kind switch fieldVal.Kind() { case reflect.Ptr: @@ -216,12 +248,86 @@ func joinPath(prefix, suffix string) string { return prefix + "." 
+ suffix } +// LoadMultiLinearLayer loads a per-head linear layer from weights. +// Weight shape should be [num_heads, output_dims, input_dims]. +// If quantized, always dequantizes since batched quantized matmul isn't supported. +func LoadMultiLinearLayer(weights WeightSource, path string) (nn.MultiLinearLayer, error) { + // Check if this is a quantized layer by looking for scale tensor + scalePath := path + ".weight_scale" + hasScale := weights.HasTensor(scalePath) + + weight, err := weights.GetTensor(path + ".weight") + if err != nil { + return nil, fmt.Errorf("failed to load weight %s: %w", path, err) + } + + if hasScale { + scales, err := weights.GetTensor(scalePath) + if err != nil { + return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err) + } + + var qbiases *mlx.Array + qbiasPath := path + ".weight_qbias" + if weights.HasTensor(qbiasPath) { + qbiases, _ = weights.GetTensor(qbiasPath) + } + + // Always dequantize for MultiLinear - no batched quantized matmul support + // Detect bits from tensor shapes (supports mixed-precision Q4/Q8) + weightShape := weight.Shape() + scalesShape := scales.Shape() + weightCols := int(weightShape[len(weightShape)-1]) + scalesCols := int(scalesShape[len(scalesShape)-1]) + + // Detect quantization from tensor shapes + // groupSize = weightCols * packFactor / scalesCols + // Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata + groupSize4 := weightCols * 8 / scalesCols + groupSize8 := weightCols * 4 / scalesCols + + var bits, groupSize int + // Use metadata to help disambiguate when shapes are ambiguous + // (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32) + quantType := strings.ToUpper(weights.Quantization()) + isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8" + + if groupSize4 == 32 { + // Unambiguous: Q4 with group_size=32 + bits = 4 + groupSize = 32 + } else if groupSize8 == 64 { + // Unambiguous: Q8 with group_size=64 + bits = 8 + groupSize 
= 64 + } else if groupSize4 == 64 && groupSize8 == 32 { + // Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata + if isQ8Type { + bits = 8 + groupSize = 32 + } else { + bits = 4 + groupSize = 64 + } + } else { + // Fallback: use global quantization params + _, bits, _ = QuantizationParams(weights.Quantization()) + packFactor := 32 / bits + groupSize = weightCols * packFactor / scalesCols + } + weight = mlx.Dequantize(weight, scales, qbiases, groupSize, bits, "affine") + } + + return nn.NewMultiLinear(weight), nil +} + // LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized. -// If {path}.weight_scale exists, dequantizes the weights. +// If {path}.weight_scale exists, creates a QuantizedLinear layer (or dequantizes if no kernel support). func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) { // Check if this is a quantized layer by looking for scale tensor scalePath := path + ".weight_scale" - if weights.HasTensor(scalePath) { + hasScale := weights.HasTensor(scalePath) + if hasScale { weight, err := weights.GetTensor(path + ".weight") if err != nil { return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err) @@ -245,9 +351,52 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) qbiases, _ = weights.GetTensor(qbiasPath) } - groupSize, bits, mode := quantizationParams(weights.Quantization()) + // Detect bits from tensor shapes (supports mixed-precision Q4/Q8) + weightShape := weight.Shape() + scalesShape := scales.Shape() + weightCols := int(weightShape[len(weightShape)-1]) + scalesCols := int(scalesShape[len(scalesShape)-1]) - if mlx.MetalIsAvailable() { + // Detect quantization from tensor shapes + // groupSize = weightCols * packFactor / scalesCols + // Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata + groupSize4 := weightCols * 8 / scalesCols + groupSize8 := weightCols * 4 / scalesCols + + var bits, groupSize int + mode 
:= "affine" + // Use metadata to help disambiguate when shapes are ambiguous + // (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32) + quantType := strings.ToUpper(weights.Quantization()) + isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8" + + if groupSize4 == 32 { + // Unambiguous: Q4 with group_size=32 + bits = 4 + groupSize = 32 + } else if groupSize8 == 64 { + // Unambiguous: Q8 with group_size=64 + bits = 8 + groupSize = 64 + } else if groupSize4 == 64 && groupSize8 == 32 { + // Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata + if isQ8Type { + bits = 8 + groupSize = 32 + } else { + bits = 4 + groupSize = 64 + } + } else { + // Fallback: use global quantization params + _, bits, mode = QuantizationParams(weights.Quantization()) + packFactor := 32 / bits + groupSize = weightCols * packFactor / scalesCols + } + + // NVFP4 and MXFP8 don't have native quantized matmul kernels in MLX, + // so we always dequantize at load time. Affine modes (FP4, FP8) have kernel support. + if mlx.MetalIsAvailable() && mode != "nvfp4" && mode != "mxfp8" { return &nn.QuantizedLinear{ Weight: weight, Scales: scales, diff --git a/x/imagegen/safetensors/safetensors.go b/x/imagegen/safetensors/safetensors.go index a36052fce..4dbcf59a3 100644 --- a/x/imagegen/safetensors/safetensors.go +++ b/x/imagegen/safetensors/safetensors.go @@ -303,6 +303,11 @@ func (mw *ModelWeights) Quantization() string { return "" } +// GroupSize returns 0 for directory-based weights (use default). +func (mw *ModelWeights) GroupSize() int { + return 0 +} + // ReleaseAll releases all cached native file handles. 
func (mw *ModelWeights) ReleaseAll() { for path, native := range mw.nativeCache { diff --git a/x/imagegen/server_test.go b/x/imagegen/server_test.go deleted file mode 100644 index 396aa140b..000000000 --- a/x/imagegen/server_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package imagegen - -import ( - "runtime" - "testing" -) - -// TestPlatformSupport verifies platform validation works correctly. -func TestPlatformSupport(t *testing.T) { - err := CheckPlatformSupport() - - switch runtime.GOOS { - case "darwin": - if runtime.GOARCH == "arm64" { - // Apple Silicon should be supported - if err != nil { - t.Errorf("Expected nil error on darwin/arm64, got: %v", err) - } - } else { - // Intel Mac should fail - if err == nil { - t.Error("Expected error on darwin/amd64 (Intel), got nil") - } - if err != nil && err.Error() == "" { - t.Error("Expected meaningful error message for unsupported platform") - } - } - case "linux", "windows": - // Linux/Windows are allowed (CUDA support checked at runtime) - if err != nil { - t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err) - } - default: - // Other platforms should fail - if err == nil { - t.Errorf("Expected error on unsupported platform %s, got nil", runtime.GOOS) - } - } -} - -// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer. -// This is a compile-time check but we document it as a test. -func TestServerInterfaceCompliance(t *testing.T) { - // The var _ llm.LlamaServer = (*Server)(nil) line in server.go - // ensures compile-time interface compliance. - // This test documents that requirement. 
- t.Log("Server implements llm.LlamaServer interface (compile-time checked)") -} diff --git a/x/imagegen/weights.go b/x/imagegen/weights.go index f49c7e77e..eb60c9895 100644 --- a/x/imagegen/weights.go +++ b/x/imagegen/weights.go @@ -20,20 +20,28 @@ type ManifestWeights struct { nativeCache []*mlx.SafetensorsFile // keep native handles alive } -// LoadWeightsFromManifest creates a weight loader for a component from manifest storage. +// LoadWeightsFromManifest creates a weight loader from manifest storage. +// If component is empty, loads all tensors (for LLM models). +// If component is specified, loads only tensors for that component and strips the prefix. func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*ManifestWeights, error) { layers := manifest.GetTensorLayers(component) if len(layers) == 0 { + if component == "" { + return nil, fmt.Errorf("no tensor layers found in manifest") + } return nil, fmt.Errorf("no tensor layers found for component %q", component) } // Strip component prefix from tensor names for model loading // e.g., "text_encoder/model.embed_tokens.weight" -> "model.embed_tokens.weight" - prefix := component + "/" tensors := make(map[string]ManifestLayer, len(layers)) for _, layer := range layers { - tensorName := strings.TrimPrefix(layer.Name, prefix) - tensors[tensorName] = layer + if component == "" { + tensors[layer.Name] = layer + } else { + tensorName := strings.TrimPrefix(layer.Name, component+"/") + tensors[tensorName] = layer + } } return &ManifestWeights{ @@ -48,19 +56,30 @@ func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*Manife // Blobs are stored in safetensors format for native mlx_load_safetensors mmap. // If dtype is non-zero, tensors are converted to the specified dtype. 
func (mw *ManifestWeights) Load(dtype mlx.Dtype) error { + // Track native handles to free after batch eval + nativeHandles := make([]*mlx.SafetensorsFile, 0, len(mw.tensors)) + arrays := make([]*mlx.Array, 0, len(mw.tensors)) + for name, layer := range mw.tensors { path := mw.manifest.BlobPath(layer.Digest) // Load blob as safetensors (native mmap, zero-copy) sf, err := mlx.LoadSafetensorsNative(path) if err != nil { + // Free any handles we've accumulated + for _, h := range nativeHandles { + h.Free() + } return fmt.Errorf("load %s: %w", name, err) } + nativeHandles = append(nativeHandles, sf) // Blob contains single tensor named "data" arr := sf.Get("data") if arr == nil { - sf.Free() + for _, h := range nativeHandles { + h.Free() + } return fmt.Errorf("tensor 'data' not found in blob for %s", name) } @@ -68,11 +87,18 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error { if dtype != 0 && arr.Dtype() != dtype { arr = mlx.AsType(arr, dtype) } - // ALWAYS make a contiguous copy to ensure independence from mmap + // Make contiguous copy to ensure independence from mmap arr = mlx.Contiguous(arr) - mlx.Eval(arr) mw.cache[name] = arr - sf.Free() // Safe to free - arr is now an independent copy + arrays = append(arrays, arr) + } + + // Batch evaluate all tensors at once (much faster than one at a time) + mlx.Eval(arrays...) + + // Now safe to free all native handles + for _, sf := range nativeHandles { + sf.Free() } return nil @@ -107,18 +133,112 @@ func (mw *ManifestWeights) HasTensor(name string) bool { } // Quantization returns the model's quantization type from model_index.json. -// Returns empty string if not quantized or unknown. +// Returns empty string if not quantized. +// Falls back to detecting from tensor names and shapes if not in config. 
func (mw *ManifestWeights) Quantization() string { if mw.manifest == nil { return "" } + + // Try to read from model_index.json first var index struct { Quantization string `json:"quantization"` } - if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil { + if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.Quantization != "" { + return index.Quantization + } + + // Fallback: detect from tensor names + // Check if any tensors have _scale suffix (indicates quantization) + hasScales := false + hasQBias := false + for name := range mw.tensors { + if strings.HasSuffix(name, ".weight_scale") { + hasScales = true + } + if strings.HasSuffix(name, ".weight_qbias") { + hasQBias = true + } + } + + if !hasScales { + // No scales = not quantized return "" } - return index.Quantization + + // Has scales but no qbias = NVFP4 (or other non-affine mode) + if !hasQBias { + return "NVFP4" + } + + // Has both scales and qbias = affine mode + // Need to determine FP4 vs FP8 from tensor shapes + // FP4: weight last dim is 1/8 of scales last dim * group_size + // FP8: weight last dim is 1/4 of scales last dim * group_size + // + // For affine mode with group_size=32: + // - FP4 (4 bits): 8 elements packed per uint32, so weight_dim = orig_dim / 8 + // - FP8 (8 bits): 4 elements packed per uint32, so weight_dim = orig_dim / 4 + // scales_dim = orig_dim / group_size + // So: weight_dim / scales_dim = group_size / pack_factor + // FP4: ratio = 32/8 = 4 + // FP8: ratio = 32/4 = 8 + + // Find a weight/scale pair to check the ratio + for name := range mw.tensors { + if !strings.HasSuffix(name, ".weight") || strings.Contains(name, "_scale") || strings.Contains(name, "_qbias") { + continue + } + scaleName := name + "_scale" + if _, ok := mw.tensors[scaleName]; !ok { + continue + } + + // Load both tensors to check shapes + weightLayer := mw.tensors[name] + scaleLayer := mw.tensors[scaleName] + + // Get shapes from manifest layer metadata 
if available + // For now, default to FP4 since it's more common + // The actual shape check would require loading the tensor + + // Simple heuristic: check if scale tensor is ~4x smaller than weight + // FP4: weight is packed 8 per uint32, scales are 1 per group (32) + // So scale size should be ~weight_size * 8 / 32 = weight_size / 4 + // FP8: weight is packed 4 per uint32, scales are 1 per group (32) + // So scale size should be ~weight_size * 4 / 32 = weight_size / 8 + + // Rough size heuristic (assuming float16 scales) + // Q4: scale_bytes ≈ weight_bytes / 4 * 2 / 4 = weight_bytes / 8 + // Q8: scale_bytes ≈ weight_bytes / 8 * 2 / 4 = weight_bytes / 16 + ratio := float64(weightLayer.Size) / float64(scaleLayer.Size) + if ratio < 12 { + // Closer to 8 = Q4 + return "Q4" + } + // Closer to 16 = Q8 + return "Q8" + } + + // Default to Q4 for affine mode (most common) + return "Q4" +} + +// GroupSize returns the quantization group size from model_index.json. +// Returns 0 if not specified (caller should use default based on quantization type). +func (mw *ManifestWeights) GroupSize() int { + if mw.manifest == nil { + return 0 + } + + var index struct { + GroupSize int `json:"group_size"` + } + if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.GroupSize > 0 { + return index.GroupSize + } + + return 0 } // ReleaseAll frees all native handles and clears the tensor cache. diff --git a/x/kvcache/causal.go b/x/kvcache/causal.go index 967fed674..eb5ef6f96 100644 --- a/x/kvcache/causal.go +++ b/x/kvcache/causal.go @@ -1,797 +1,144 @@ +//go:build mlx + package kvcache -// import ( -// "errors" -// "fmt" -// "log/slog" -// "math" -// "slices" - -// "github.com/ollama/ollama/ml" -// "github.com/ollama/ollama/model/input" -// ) - -// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) - -// // Causal cache stores K and V tensors according to their position in the -// // sequence. 
Returns the history and a mask for attending to past tokens -// // -// // The tensors are of shape embed dim, kv heads, batch size -// // The mask is of shape history size, batch size -// type Causal struct { -// DType ml.DType - -// // swaWindowSize is the number of tokens that will be included in the mask -// // during attention operations. swaMemorySize is the number of tokens that -// // will be retained in memory for partial prefix caching. Set to math.MaxInt32 -// // for unlimited or if sliding window attention is not being used. -// swaWindowSize int32 -// swaMemorySize int32 - -// chunkSize int32 - -// opts CausalOptions - -// // maxBatch is the largest batch that we might receive -// maxBatch int - -// // config controls mostly backend-specific optimizations -// config *ml.CacheConfig - -// // ** current forward pass ** - -// // size of the current batch -// curBatchSize int - -// // locations for data storage for this batch -// curLoc ml.Tensor - -// // mask of the cache as used by this batch -// curMask ml.Tensor - -// // the active layer for Get and Put -// curLayer int - -// // locations in the cache that are needed for this batch -// curCellRange cellRange - -// // curSequences is the sequences corresponding to this pass's entries in the cache -// curSequences []int - -// // curPositions is the positions corresponding to this pass's entries in the cache -// curPositions []int32 - -// // ** cache metadata ** - -// // for each possible location in the cache, stores the position and set of sequences -// // that reference the data there -// cells []cacheCell - -// // maps from sequence to the range of locations where it is stored in the cache -// cellRanges map[int]cellRange - -// // ** cache data storage ** - -// shiftFn shiftFn -// backend ml.Backend -// ctxs map[int]ml.Context -// keys, values map[int]ml.Tensor - -// kHeadDims, vHeadDims, numKVHeads map[int]int -// } - -// type cacheCell struct { -// pos int32 -// sequences []int -// } - -// type 
cellRange struct { -// min int -// max int -// } - -// func NewCausalCache(shift shiftFn) *Causal { -// return &Causal{ -// shiftFn: shift, -// ctxs: make(map[int]ml.Context), -// keys: make(map[int]ml.Tensor), -// values: make(map[int]ml.Tensor), -// kHeadDims: make(map[int]int), -// vHeadDims: make(map[int]int), -// numKVHeads: make(map[int]int), -// } -// } - -// func NewSWACache(windowSize int32, shift shiftFn) *Causal { -// return &Causal{ -// swaWindowSize: windowSize, -// shiftFn: shift, -// ctxs: make(map[int]ml.Context), -// keys: make(map[int]ml.Tensor), -// values: make(map[int]ml.Tensor), -// kHeadDims: make(map[int]int), -// vHeadDims: make(map[int]int), -// numKVHeads: make(map[int]int), -// } -// } - -// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal { -// return &Causal{ -// swaWindowSize: windowSize, -// swaMemorySize: memorySize, -// shiftFn: shift, -// ctxs: make(map[int]ml.Context), -// keys: make(map[int]ml.Tensor), -// values: make(map[int]ml.Tensor), -// kHeadDims: make(map[int]int), -// vHeadDims: make(map[int]int), -// numKVHeads: make(map[int]int), -// } -// } - -// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal { -// return &Causal{ -// chunkSize: chunkSize, -// shiftFn: shift, -// ctxs: make(map[int]ml.Context), -// keys: make(map[int]ml.Tensor), -// values: make(map[int]ml.Tensor), -// kHeadDims: make(map[int]int), -// vHeadDims: make(map[int]int), -// numKVHeads: make(map[int]int), -// } -// } - -// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { -// if c.config == nil { -// var config ml.CacheConfig -// if cc, ok := backend.(ml.BackendCacheConfig); ok { -// config = cc.CacheConfig() -// } -// c.config = &config -// } - -// if c.config.CachePadding == 0 { -// c.config.CachePadding = 1 -// } - -// if c.config.MaskBatchPadding == 0 { -// c.config.MaskBatchPadding = 1 -// } - -// // TODO what types do we handle here? 
-// // if c.config.MaskDType == ml.DTypeOther { -// // c.config.MaskDType = ml.DTypeFloat32 -// // } - -// if c.swaWindowSize == 0 { -// c.swaWindowSize = math.MaxInt32 -// } -// if c.swaMemorySize == 0 { -// c.swaMemorySize = c.swaWindowSize -// } -// // We will allocate space in the cache for the stop token, which won't be part of a follow on -// // sequence, so allocate an extra token of storage to ensure that we can jump back without -// // causing a cache break. As an optimization, only do this when we have parallel sequences -// // because the extra token will live in the batch buffer and won't get overwritten if we -// // only have a single sequence. -// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 { -// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1) -// } -// if int(c.swaMemorySize) >= capacity { -// c.swaMemorySize = math.MaxInt32 -// } - -// if c.swaMemorySize < c.swaWindowSize { -// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize)) -// } - -// var cacheSize int -// if c.swaMemorySize == math.MaxInt32 { -// cacheSize = maxSequences * capacity -// } else { -// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch -// } -// cacheSize = roundUp(cacheSize, c.config.CachePadding) -// c.cells = make([]cacheCell, cacheSize) - -// c.DType = dtype -// c.cellRanges = make(map[int]cellRange) -// c.backend = backend -// c.maxBatch = maxBatch -// } - -// func (c *Causal) SetConfig(config ml.CacheConfig) { -// if c.config != nil { -// panic("config cannot be changed after being previously set, either by the model or backend") -// } - -// c.config = &config -// } - -// func (c *Causal) Close() { -// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs)) -// for _, ctx := range c.ctxs { -// ctx.Close() -// } -// } - -// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { -// slog.Info("XXX Causal.StartForward", 
"cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch) -// // panic("XXX Causal.StartForward") -// c.curBatchSize = len(batch.Positions) -// c.curSequences = batch.Sequences -// c.curPositions = batch.Positions -// c.opts.Except = nil - -// var locs []int32 -// if !reserve { -// c.updateSlidingWindow() - -// var err error -// locs, err = c.findLocs() -// if err != nil { -// return err -// } -// slog.Info("XXX Causal.StartForward", "findLocs len", len(locs)) - -// for i, pos := range batch.Positions { -// seq := batch.Sequences[i] -// loc := int(locs[i]) - -// c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}} - -// seqRange, ok := c.cellRanges[seq] -// if !ok { -// seqRange = newRange() -// } - -// seqRange.min = min(seqRange.min, loc) -// c.curCellRange.min = min(c.curCellRange.min, loc) - -// seqRange.max = max(seqRange.max, loc) -// c.curCellRange.max = max(c.curCellRange.max, loc) - -// c.cellRanges[seq] = seqRange -// } -// } else { -// // If we are reserving memory, don't update any of the cache metadata but set the size -// // to the worst case. 
-// locs = make([]int32, c.curBatchSize) -// for i := range locs { -// locs[i] = int32(i) -// } -// c.curCellRange.min = 0 -// c.curCellRange.max = len(c.cells) - 1 -// } - -// // XXX Building up the locs for what's already processed (if any) -// dummyLocs := []int{} -// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding) -// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 - -// for i := range c.curBatchSize { -// enabled := !slices.Contains(c.opts.Except, i) -// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { -// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || -// (enabled && c.cells[j].pos > c.curPositions[i]) || -// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize || -// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize { -// // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) -// } else { -// if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i { -// dummyLocs = append(dummyLocs, i) -// } -// } -// } -// } -// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs) - -// slog.Info("XXX Causal.StartForward", "locs", locs) -// c.curLoc = ctx.Input().FromInts(locs, len(locs)) -// c.curMask = c.buildMask(ctx) - -// return nil -// } - -// func newRange() cellRange { -// return cellRange{ -// min: math.MaxInt, -// max: 0, -// } -// } - -// // Returns a slice of locations where each token in the batch should be stored -// func (c *Causal) findLocs() ([]int32, error) { -// loc := make([]int32, 0, c.curBatchSize) - -// for i := range c.cells { -// if len(c.cells[i].sequences) == 0 { -// loc = append(loc, int32(i)) -// if len(loc) >= c.curBatchSize { -// return loc, nil -// } -// } -// } - -// return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize) -// } - -// func (c *Causal) updateSlidingWindow() { -// c.curCellRange = newRange() - -// if c.swaMemorySize == math.MaxInt32 { -// 
for _, seq := range c.curSequences { -// if seqRange, ok := c.cellRanges[seq]; ok { -// c.curCellRange.min = min(c.curCellRange.min, seqRange.min) -// c.curCellRange.max = max(c.curCellRange.max, seqRange.max) -// } -// } - -// return -// } - -// type lowestPosition struct { -// pos int32 -// curBatch bool -// } - -// // create a map of unique sequences to the lowest position in that sequence -// lowestPos := make(map[int]lowestPosition) -// for i := range c.curPositions { -// seq := c.curSequences[i] - -// lowest, ok := lowestPos[seq] -// if !ok { -// lowest = lowestPosition{pos: c.curPositions[i], curBatch: true} -// } else if c.curPositions[i] < lowest.pos { -// lowest.pos = c.curPositions[i] -// } - -// lowestPos[seq] = lowest -// } - -// // for any sequences are not part of this batch, clean up any tokens -// // that are no longer needed after the processing of the previous -// // batch -// for seq, seqRange := range c.cellRanges { -// if _, ok := lowestPos[seq]; !ok { -// var last int32 -// for i := seqRange.min; i <= seqRange.max; i++ { -// if slices.Contains(c.cells[i].sequences, seq) { -// last = max(last, c.cells[i].pos) -// } -// } - -// lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false} -// } -// } - -// // delete any entries that are beyond the window of the oldest position in the sequence -// for seq, lowest := range lowestPos { -// oldRange, ok := c.cellRanges[seq] -// if !ok { -// continue -// } - -// newRange := newRange() - -// for i := oldRange.min; i <= oldRange.max; i++ { -// if slices.Contains(c.cells[i].sequences, seq) { -// if c.cells[i].pos < lowest.pos-c.swaMemorySize { -// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq }) -// } else { -// newRange.min = min(newRange.min, i) -// newRange.max = max(newRange.max, i) -// } -// if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize { -// c.curCellRange.min = min(c.curCellRange.min, i) -// c.curCellRange.max = 
max(c.curCellRange.max, i) -// } -// } -// } - -// c.cellRanges[seq] = newRange -// } -// } - -// func roundDown(length, pad int) int { -// return (length / pad) * pad -// } - -// func roundUp(length, pad int) int { -// return ((length + pad - 1) / pad) * pad -// } - -// // Builds a mask of history x batch indicating whether for each token in the batch the -// // token in the history should apply. This is based on both the sequence and causality (the -// // position of the history is not ahead of the token in the batch). -// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { -// // Align and pad the two dimensions as required by the backend -// batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) - -// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding) -// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 - -// length := c.curCellRange.max - c.curCellRange.min + 1 - -// mask := make([]float32, batchSize*length) - -// for i := range c.curBatchSize { -// enabled := !slices.Contains(c.opts.Except, i) -// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { -// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || -// (enabled && c.cells[j].pos > c.curPositions[i]) || -// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize || -// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize { -// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) -// } -// } -// } - -// // Mask out any padding tokens we added. For padding that we added to the cache history, this -// // has already been masked out because the sequence doesn't match. 
-// for i := c.curBatchSize * length; i < len(mask); i++ { -// mask[i] = float32(math.Inf(-1)) -// } - -// maskTensor := ctx.Input().FromFloats(mask, batchSize, length) - -// // if c.config.MaskDType != ml.DTypeFloat32 { -// // maskTensor = maskTensor.Cast(ctx, c.config.MaskDType) -// // } - -// slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length}) - -// return maskTensor -// } - -// func (c *Causal) SetLayer(layer int) { -// c.curLayer = layer -// } - -// type CausalOptions struct { -// // Enabled controls whether the causal mask is generated for a particular index in a batch -// Except []int -// } - -// // SetCausal disables causal mask generation for a particular range of indicies in -// // the current batch for subsequent calls to Get. The state resets for the next forward pass. 
-// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) { -// if !slices.Equal(c.opts.Except, opts.Except) { -// c.opts = opts -// if ctx != nil { -// c.curMask = c.buildMask(ctx) -// } -// } -// } - -// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { -// key := c.keys[c.curLayer] -// value := c.values[c.curLayer] - -// kHeadDim := c.kHeadDims[c.curLayer] -// vHeadDim := c.vHeadDims[c.curLayer] -// numKVHeads := c.numKVHeads[c.curLayer] -// // rowSize := numKVHeads * c.curBatchSize -// // cachedSize := c.curMask.Dim(1) -// cachedSize := c.curLoc.Dim(0) -// // kCellSize := kHeadDim * numKVHeads -// // vCellSize := vHeadDim * numKVHeads - -// slog.Info("XXX Causal.Get full cache", "key", key) -// slog.Info("XXX Causal.Get full cache", "value", value) -// slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc) -// slog.Info("XXX Causal.Get", "curMask", c.curMask) -// slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim) -// // panic("XXX") - -// // fmt.Fprintln(os.Stderr, key.ToString()) -// // panic("full cache value") - -// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask -// key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim) -// // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min) - -// // slog.Info("XXX Causal.Get after AsStrided", "key", key) -// // panic("XXX") - -// // if c.config.PermutedV { -// // panic("permuted") -// // // TODO not converted -// // vHeadDim := value.Dim(1) -// // elemSize := value.Stride(2) - -// // value = value.AsStrided(ctx, -// // []int{numKVHeads, vHeadDim, cachedSize}, -// // []int{value.Stride(0), value.Stride(1)}, -// // elemSize*c.curCellRange.min, -// // ) -// // } else { -// // vHeadDim := c.vHeadDims[c.curLayer] -// // rowSize := 
value.Stride(2) -// // slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize) -// // panic("XXX") - -// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask -// value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim) -// // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min) - -// // slog.Info("XXX Causal.Get after AsStrided", "value", value) -// // panic("XXX") - -// // } - -// // // TODO The mask changes from X,X to 1,X, and with the Row-order change -// // // the 1 becomes trailing and messes up later operations -// // // This isn't the right solution, but works around it... -// // if c.curMask.Dim(1) == 1 { -// // return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3) -// // } -// // fmt.Fprintln(os.Stderr, key.ToString()) -// // fmt.Fprintln(os.Stderr, value.ToString()) -// // panic("XXX") -// slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape()) - -// return key, value, c.curMask -// } - -// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { -// kHeadDim := key.Dim(3) -// vHeadDim := value.Dim(3) -// numKVHeads := key.Dim(1) -// batchSize := key.Dim(2) -// kCellSize := kHeadDim * numKVHeads -// vCellSize := vHeadDim * numKVHeads - -// // slog.Info("XXX Causal.Put", "key", key, "value", value) -// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize) -// // panic("XXX") - -// if c.curBatchSize != batchSize { -// panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize)) -// } - -// // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend) -// if _, ok := c.ctxs[c.curLayer]; !ok { -// slog.Info("XXX Causal.Put creating new context", "c.curLayer", 
c.curLayer) -// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer) -// } - -// if _, ok := c.keys[c.curLayer]; !ok { -// slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize}) - -// c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize) -// c.kHeadDims[c.curLayer] = kHeadDim -// c.vHeadDims[c.curLayer] = vHeadDim -// c.numKVHeads[c.curLayer] = numKVHeads -// } - -// if _, ok := c.values[c.curLayer]; !ok { -// // if c.config.PermutedV { -// // c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells)) -// // } else { -// c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize) -// // } -// } - -// key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed - -// // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache) -// // panic("XXX") -// // curLoc := 0 // TODO c.curLoc is now a tensor -// // kSize := numKVHeads * kHeadDim -// // vSize := numKVHeads * vHeadDim -// // start := []int{int(curLoc), 0} -// // kStop := []int{int(curLoc + batchSize), int(kSize)} -// // vStop := []int{int(curLoc + batchSize), int(vSize)} -// // strides := []int{1, 1} - -// // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache) -// // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key) - -// // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides) - -// // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides)) -// ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0})) -// // fmt.Fprintln(os.Stderr, keyCache.ToString()) -// // panic("input value") - -// // fmt.Fprintln(os.Stderr, t.ToString()) -// // panic("XXX") - -// // if c.config.PermutedV { -// // panic("permuted") -// // // TODO not adjusted -// // value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize) -// // 
value = value.Transpose(ctx, 2, 0, 1, 3) - -// // valueCache := c.values[c.curLayer] -// // valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads) - -// // ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides)) -// // } else { -// value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed -// // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache) -// // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value) -// // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides) - -// ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0})) -// // } -// // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString()) -// // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString()) -// // panic("XXX") - -// } - -// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { -// seqRange := newRange() - -// for i := range c.cells { -// // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end -// if slices.Contains(c.cells[i].sequences, dstSeq) { -// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq }) -// } - -// if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len { -// c.cells[i].sequences = append(c.cells[i].sequences, dstSeq) -// if i < seqRange.min { -// seqRange.min = i -// } -// if i > seqRange.max { -// seqRange.max = i -// } -// } -// } - -// c.cellRanges[dstSeq] = seqRange -// } - -// func (c *Causal) CanResume(seq int, pos int32) bool { -// if c.swaMemorySize == math.MaxInt32 { -// return true -// } - -// seqRange, ok := c.cellRanges[seq] -// if !ok { -// return false -// } - -// // for sliding window, check that the window of the new sequence is contained in -// // the window of what we are storing -// var first int32 = math.MaxInt32 -// var last int32 = -1 -// for 
i := seqRange.min; i <= seqRange.max; i++ { -// if slices.Contains(c.cells[i].sequences, seq) { -// first = min(first, c.cells[i].pos) -// last = max(last, c.cells[i].pos) -// } -// } - -// if last == -1 { -// return false -// } - -// posWindowStart := max(0, pos-c.swaWindowSize) -// return posWindowStart >= first && pos <= last+1 -// } - -// func (c *Causal) shift(seq int, beginIndex, offset int32) error { -// if c.shiftFn == nil { -// return ErrNotSupported -// } - -// seqRange := c.cellRanges[seq] - -// for start := seqRange.min; start <= seqRange.max; start += c.maxBatch { -// size := min(seqRange.max-start+1, c.maxBatch) -// offsets := make([]int32, size) - -// var batchFirst, batchLast int - -// batchFirst = -1 -// for i := range offsets { -// cell := c.cells[start+i] - -// if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex { -// offsets[i] = offset -// if batchFirst < 0 { -// batchFirst = i -// } -// batchLast = i -// } -// } - -// if batchFirst < 0 { -// continue -// } - -// offsets = offsets[batchFirst : batchLast+1] - -// slog.Info("XXX Causal.shift creating new temporary context") -// ctx := c.backend.NewContext() -// kShift := ctx.Input().FromInts(offsets, len(offsets)) - -// for i, key := range c.keys { -// if key == nil { -// continue -// } - -// kHeadDim := key.Dim(2) -// numKVHeads := key.Dim(1) -// rowSize := key.Stride(0) - -// key = key.AsStrided(ctx, -// []int{len(offsets), numKVHeads, kHeadDim}, -// []int{key.Stride(0), key.Stride(1)}, -// rowSize*(start+batchFirst), -// ) - -// roped, err := c.shiftFn(ctx, i, key, kShift) -// if err != nil { -// ctx.Close() -// return err -// } - -// ctx.Forward(roped.Copy(ctx, key)) -// } - -// ctx.Compute() -// ctx.Close() -// } - -// return nil -// } - -// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { -// // TODO(jessegross): We should check to see if removing the middle of the sequence will -// // cause the sliding window to encompass tokens that we no longer have. 
If so, then we -// // should return an error, which will trigger the runner to evaluate the full history and -// // rebuild the window. However, if we have multimodal inputs in our history, this reuse -// // results in use after free, so we don't do it for now. - -// var offset int32 -// if endIndex != math.MaxInt32 { -// offset = beginIndex - endIndex -// } - -// seqRange := newRange() - -// for i := range c.cells { -// if slices.Contains(c.cells[i].sequences, seq) { -// if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex { -// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq }) -// } else { -// if c.cells[i].pos >= endIndex { -// if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) { -// return errors.New("shifting cells shared by multiple sequences not supported") -// } - -// c.cells[i].pos += offset -// } -// if i < seqRange.min { -// seqRange.min = i -// } -// if i > seqRange.max { -// seqRange.max = i -// } -// } -// } -// } - -// if seqRange == newRange() { -// delete(c.cellRanges, seq) -// return nil -// } - -// c.cellRanges[seq] = seqRange - -// if endIndex != math.MaxInt32 { -// err := c.shift(seq, endIndex+offset, offset) -// if err != nil { -// return err -// } -// } - -// return nil -// } +import ( + "github.com/ollama/ollama/x/ml" + "github.com/ollama/ollama/x/model/input" +) + +// Causal cache stores K and V tensors according to their position in the +// sequence. Returns the history and a mask for attending to past tokens +type Causal struct { + DType ml.DType + + // locations for data storage for this batch + curLocPut ml.Tensor + + // locations for data storage for this batch + curLocGet ml.Tensor + + // the active layer for Get and Put + curLayer int + + capacity int + + offset int + + backend ml.Backend + ctxs map[int]ml.Context + keys, values map[int]ml.Tensor + + // TODO is this needed per layer, or will it always be consistent? 
+ kHeadDims, vHeadDims, numKVHeads map[int]int +} + +func NewCausalCache() *Causal { + return &Causal{ + ctxs: make(map[int]ml.Context), + keys: make(map[int]ml.Tensor), + values: make(map[int]ml.Tensor), + kHeadDims: make(map[int]int), + vHeadDims: make(map[int]int), + numKVHeads: make(map[int]int), + } +} + +func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { + c.DType = dtype + c.capacity = capacity + c.backend = backend +} + +func (c *Causal) SetConfig(config ml.CacheConfig) {} + +func (c *Causal) SetLayer(layer int) { + c.curLayer = layer +} + +func (c *Causal) Close() { + // slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs)) + for _, ctx := range c.ctxs { + ctx.Close() + } +} + +func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { + locsPut := make([]int32, len(batch.Positions)) + for i := c.offset; i < len(batch.Positions); i++ { + locsPut[i-c.offset] = int32(i) + } + c.offset += len(batch.Positions) + locsGet := make([]int32, c.offset) + for i := range c.offset { + locsGet[i] = int32(i) + } + c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet)) + c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut)) + // slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet) + + return nil +} +func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { + kHeadDim := key.Dim(3) + vHeadDim := value.Dim(3) + numKVHeads := key.Dim(1) + batchSize := key.Dim(2) + kCellSize := kHeadDim * numKVHeads + vCellSize := vHeadDim * numKVHeads + // slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize) + + if _, ok := c.ctxs[c.curLayer]; !ok { + // slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer) + c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer) + } + + if _, ok := c.keys[c.curLayer]; !ok { + // 
slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize}) + c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize) + c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize) + c.kHeadDims[c.curLayer] = kHeadDim + c.vHeadDims[c.curLayer] = vHeadDim + c.numKVHeads[c.curLayer] = numKVHeads + } + key = key.Reshape(ctx, batchSize, 1, kCellSize) + + // slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer]) + // slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut) + // slog.Info("XXX Causal.Put ", "key", key) + ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0})) + value = value.Reshape(ctx, batchSize, 1, vCellSize) + ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0})) + +} + +func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { + key := c.keys[c.curLayer] + value := c.values[c.curLayer] + + kHeadDim := c.kHeadDims[c.curLayer] + vHeadDim := c.vHeadDims[c.curLayer] + numKVHeads := c.numKVHeads[c.curLayer] + // rowSize := numKVHeads * c.curBatchSize + // cachedSize := c.curMask.Dim(1) + cachedSize := c.curLocGet.Dim(0) + // kCellSize := kHeadDim * numKVHeads + // vCellSize := vHeadDim * numKVHeads + // slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim}) + + key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim) + value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim) + return key, value, nil +} + +func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { + panic("not implemented") +} + +func (c *Causal) CanResume(seq int, pos int32) bool { + panic("not implemented") +} + +func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { + panic("not implemented") +} diff --git a/x/kvcache/causal_test.go b/x/kvcache/causal_test.go deleted file mode 
100644 index d7ac430b1..000000000 --- a/x/kvcache/causal_test.go +++ /dev/null @@ -1,973 +0,0 @@ -package kvcache - -// import ( -// "fmt" -// "math" -// "slices" -// "testing" - -// "github.com/ollama/ollama/ml" -// "github.com/ollama/ollama/model/input" -// ) - -// type testCase struct { -// name string -// in []float32 -// inShape []int -// seqs []int -// pos []int32 -// expected []float32 -// expectedShape []int -// expectedMask []float32 -// } - -// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) { -// t.Helper() -// for _, permuted := range []bool{false, true} { -// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) { -// fn(t, &testBackend{permutedV: permuted}) -// }) -// } -// } - -// func TestStore(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewCausalCache(nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// tests := []testCase{ -// { -// name: "FirstBatch", -// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, -// inShape: []int{2, 3, 4}, -// seqs: []int{0, 0, 0, 0}, -// pos: []int32{0, 1, 2, 3}, -// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, -// expectedShape: []int{2, 3, 4}, -// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, -// }, -// { -// name: "SecondBatch", -// in: []float32{115, 215, 125, 225, 135, 235}, -// inShape: []int{2, 3, 1}, -// seqs: []int{0}, -// pos: []int32{4}, -// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235}, -// expectedShape: []int{2, 3, 5}, -// 
expectedMask: []float32{0, 0, 0, 0, 0}, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func TestSWA(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewSWACache(1, nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// x := float32(math.Inf(-1)) - -// tests := []testCase{ -// { -// name: "FirstBatch", -// in: []float32{1, 2, 3, 4}, -// inShape: []int{1, 1, 4}, -// seqs: []int{0, 0, 0, 0}, -// pos: []int32{0, 1, 2, 3}, -// expected: []float32{1, 2, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{ -// 0, x, x, x, -// 0, 0, x, x, -// x, 0, 0, x, -// x, x, 0, 0, -// }, -// }, -// { -// name: "SecondBatch", -// in: []float32{5, 6}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{4, 5}, -// expected: []float32{5, 6, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{ -// 0, x, x, 0, -// 0, 0, x, x, -// }, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func TestSWASeparateBatches(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewSWACache(1, nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 2, 16, 2) - -// x := float32(math.Inf(-1)) - -// tests := []testCase{ -// { -// name: "First seq 0", -// in: []float32{1, 2}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{0, 1}, -// expected: []float32{1, 2}, -// expectedShape: []int{1, 1, 2}, -// expectedMask: []float32{ -// 0, x, -// 0, 0, -// }, -// }, -// { -// name: "Second seq 0", -// in: []float32{3, 4}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{2, 3}, -// expected: []float32{2, 3, 4}, -// expectedShape: []int{1, 1, 3}, -// expectedMask: []float32{ -// 0, 0, x, -// x, 0, 0, -// }, -// }, -// { -// name: "First seq 1", -// in: []float32{5, 6}, -// inShape: []int{1, 1, 2}, -// seqs: []int{1, 1}, -// pos: []int32{0, 1}, 
-// expected: []float32{5, 6}, -// expectedShape: []int{1, 1, 2}, -// expectedMask: []float32{ -// 0, x, -// 0, 0, -// }, -// }, -// { -// name: "Second seq 1", -// in: []float32{7, 8}, -// inShape: []int{1, 1, 2}, -// seqs: []int{1, 1}, -// pos: []int32{2, 3}, -// expected: []float32{6, 3, 4, 7, 8}, -// expectedShape: []int{1, 1, 5}, -// expectedMask: []float32{ -// 0, x, x, 0, x, -// x, x, x, 0, 0, -// }, -// }, -// { -// name: "Third seq 0", -// in: []float32{9, 10}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{4, 5}, -// expected: []float32{9, 10, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{ -// 0, x, x, 0, -// 0, 0, x, x, -// }, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func TestSWAMem(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewSWAMemCache(1, 3, nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// x := float32(math.Inf(-1)) - -// tests := []testCase{ -// { -// name: "FirstBatch", -// in: []float32{1, 2, 3, 4}, -// inShape: []int{1, 1, 4}, -// seqs: []int{0, 0, 0, 0}, -// pos: []int32{0, 1, 2, 3}, -// expected: []float32{1, 2, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{ -// 0, x, x, x, -// 0, 0, x, x, -// x, 0, 0, x, -// x, x, 0, 0, -// }, -// }, -// { -// name: "SecondBatch", -// in: []float32{5, 6}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{4, 5}, -// expected: []float32{5, 2, 3, 4, 6}, -// expectedShape: []int{1, 1, 5}, -// expectedMask: []float32{ -// 0, x, x, 0, x, -// 0, x, x, x, 0, -// }, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func TestChunkedAttention(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewChunkedAttentionCache(2, nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// x := float32(math.Inf(-1)) - -// 
testCache( -// t, backend, cache, -// []testCase{ -// { -// name: "FirstBatch", -// in: []float32{1, 2, 3, 4}, -// inShape: []int{1, 1, 4}, -// seqs: []int{0, 0, 0, 0}, -// pos: []int32{0, 1, 2, 3}, -// expected: []float32{1, 2, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{ -// 0, x, x, x, -// 0, 0, x, x, -// x, x, 0, x, -// x, x, 0, 0, -// }, -// }, -// { -// name: "SecondBatch", -// in: []float32{5, 6, 7}, -// inShape: []int{1, 1, 3}, -// seqs: []int{0, 0, 0}, -// pos: []int32{4, 5, 6}, -// expected: []float32{1, 2, 3, 4, 5, 6, 7}, -// expectedShape: []int{1, 1, 7}, -// expectedMask: []float32{ -// x, x, x, x, 0, x, x, -// x, x, x, x, 0, 0, x, -// x, x, x, x, x, x, 0, -// }, -// }, -// { -// name: "ThirdBatch", -// in: []float32{8, 9}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{7, 8}, -// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, -// expectedShape: []int{1, 1, 9}, -// expectedMask: []float32{ -// x, x, x, x, x, x, 0, 0, x, -// x, x, x, x, x, x, x, x, 0, -// }, -// }, -// }, -// ) -// }) -// } - -// func TestSequences(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewCausalCache(nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// tests := []testCase{ -// { -// name: "FirstBatch", -// in: []float32{1, 2, 3, 4}, -// inShape: []int{1, 1, 4}, -// seqs: []int{0, 0, 1, 1}, -// pos: []int32{0, 1, 0, 1}, -// expected: []float32{1, 2, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, -// }, -// { -// name: "SecondBatch", -// in: []float32{5, 6}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 1}, -// pos: []int32{2, 2}, -// expected: []float32{1, 2, 3, 4, 5, 6}, -// 
expectedShape: []int{1, 1, 6}, -// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0}, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func TestRemove(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { -// return key.Add(ctx, shift), nil -// }) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// x := float32(math.Inf(-1)) - -// tests := []testCase{ -// { -// name: "FirstBatch", -// in: []float32{1, 2, 3, 4}, -// inShape: []int{1, 1, 4}, -// seqs: []int{0, 0, 1, 1}, -// pos: []int32{0, 1, 0, 1}, -// expected: []float32{1, 2, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{ -// 0, x, x, x, -// 0, 0, x, x, -// x, x, 0, x, -// x, x, 0, 0, -// }, -// }, -// } - -// testCache(t, backend, cache, tests) - -// err := cache.Remove(0, 1, math.MaxInt32) -// if err != nil { -// panic(err) -// } - -// tests = []testCase{ -// { -// name: "RemoveEnd", -// in: []float32{5, 6}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 1}, -// pos: []int32{1, 2}, -// expected: []float32{1, 5, 3, 4, 6}, -// expectedShape: []int{1, 1, 5}, -// expectedMask: []float32{ -// 0, 0, x, x, x, -// x, x, 0, 0, 0, -// }, -// }, -// } - -// testCache(t, backend, cache, tests) - -// err = cache.Remove(0, 0, 1) -// if err != nil { -// panic(err) -// } - -// tests = []testCase{ -// { -// name: "RemoveMiddle", -// in: []float32{7, 8}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{1, 2}, -// expected: []float32{7, 4, 3, 4, 6, 8}, -// expectedShape: []int{1, 1, 6}, -// expectedMask: []float32{ -// 0, 0, x, x, x, x, -// 0, 0, x, x, x, 0, -// }, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func TestCopy(t *testing.T) { -// 
runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil }) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// tests := []testCase{ -// { -// name: "FirstBatch", -// in: []float32{1, 2, 3, 4}, -// inShape: []int{1, 1, 4}, -// seqs: []int{0, 0, 0, 0}, -// pos: []int32{0, 1, 2, 3}, -// expected: []float32{1, 2, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, -// }, -// } - -// testCache(t, backend, cache, tests) - -// cache.CopyPrefix(0, 1, 2) - -// tests = []testCase{ -// { -// name: "Copy", -// in: []float32{5, 6}, -// inShape: []int{1, 1, 2}, -// seqs: []int{1, 1}, -// pos: []int32{3, 4}, -// expected: []float32{1, 2, 3, 4, 5, 6}, -// expectedShape: []int{1, 1, 6}, -// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) { -// for _, test := range tests { -// t.Run(test.name, func(t *testing.T) { -// context := backend.NewContext() -// defer context.Close() - -// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false) -// if err != nil { -// panic(err) -// } - -// cache.SetLayer(0) -// tensor := context.FromFloats(test.in, test.inShape...) 
-// cache.Put(context, tensor, tensor) - -// out, _, mask := cache.Get(context) - -// context.Forward(out, mask).Compute(out, mask) - -// if !slices.Equal(out.Floats(), test.expected) { -// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected) -// } - -// if !slices.Equal(out.Shape(), test.expectedShape) { -// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape) -// } - -// if !slices.Equal(mask.Floats(), test.expectedMask) { -// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask) -// } -// }) -// } -// } - -// func TestCanResume(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// windowSize := int32(4) -// cache := NewSWACache(windowSize, nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// context := backend.NewContext() -// defer context.Close() - -// err := cache.StartForward(context, input.Batch{ -// Positions: []int32{0, 1, 2, 3, 4}, -// Sequences: []int{0, 0, 0, 0, 0}, -// }, false) -// if err != nil { -// t.Fatalf("StartForward failed: %v", err) -// } - -// cache.SetLayer(0) -// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5) -// cache.Put(context, tensor, tensor) - -// // with window size 4, nothing has slid out of the window yet -// if !cache.CanResume(0, 0) { -// t.Errorf("CanResume(0, 0) = false, want true (within window)") -// } -// if !cache.CanResume(0, 1) { -// t.Errorf("CanResume(0, 1) = false, want true (within window)") -// } -// if !cache.CanResume(0, 2) { -// t.Errorf("CanResume(0, 2) = false, want true (within window)") -// } -// if !cache.CanResume(0, 3) { -// t.Errorf("CanResume(0, 3) = false, want true (latest position)") -// } -// if !cache.CanResume(0, 4) { -// t.Errorf("CanResume(0, 4) = false, want true (latest position)") -// } - -// // shift window by adding position 5 -// err = cache.StartForward(context, input.Batch{ -// Positions: []int32{5}, -// Sequences: []int{0}, -// }, 
false) -// if err != nil { -// t.Fatalf("StartForward failed: %v", err) -// } - -// cache.SetLayer(0) -// tensor = context.FromFloats([]float32{6}, 1, 1, 1) -// cache.Put(context, tensor, tensor) - -// // only the latest position has overlapping windows -// if cache.CanResume(0, 0) { -// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") -// } -// if cache.CanResume(0, 1) { -// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") -// } -// if cache.CanResume(0, 2) { -// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") -// } -// if cache.CanResume(0, 3) { -// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") -// } -// if cache.CanResume(0, 4) { -// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") -// } -// if !cache.CanResume(0, 5) { -// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)") -// } -// }) -// } - -// func TestCanResumeSWAMem(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// windowSize := int32(4) -// memSize := int32(5) -// cache := NewSWAMemCache(windowSize, memSize, nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// context := backend.NewContext() -// defer context.Close() - -// err := cache.StartForward(context, input.Batch{ -// Positions: []int32{0, 1, 2, 3, 4, 5, 6}, -// Sequences: []int{0, 0, 0, 0, 0, 0, 0}, -// }, false) -// if err != nil { -// t.Fatalf("StartForward failed: %v", err) -// } - -// cache.SetLayer(0) -// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7) -// cache.Put(context, tensor, tensor) - -// // shift window by adding position 7 -// err = cache.StartForward(context, input.Batch{ -// Positions: []int32{7}, -// Sequences: []int{0}, -// }, false) -// if err != nil { -// t.Fatalf("StartForward failed: %v", err) -// } - -// cache.SetLayer(0) -// tensor = 
context.FromFloats([]float32{8}, 1, 1, 1) -// cache.Put(context, tensor, tensor) - -// // only the latest position has overlapping windows -// if cache.CanResume(0, 0) { -// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") -// } -// if cache.CanResume(0, 1) { -// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") -// } -// if cache.CanResume(0, 2) { -// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") -// } -// if cache.CanResume(0, 3) { -// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") -// } -// if cache.CanResume(0, 4) { -// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") -// } -// if cache.CanResume(0, 5) { -// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)") -// } -// if !cache.CanResume(0, 6) { -// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)") -// } -// if !cache.CanResume(0, 7) { -// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)") -// } -// }) -// } - -// type testBackend struct { -// ml.Backend -// permutedV bool -// } - -// func (b *testBackend) NewContext() ml.Context { -// return &testContext{} -// } - -// func (b *testBackend) NewContextSize(int) ml.Context { -// return &testContext{} -// } - -// func (b *testBackend) CacheConfig() ml.CacheConfig { -// return ml.CacheConfig{PermutedV: b.permutedV} -// } - -// type testContext struct { -// ml.Context -// } - -// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor { -// total := 0 - -// if len(shape) > 0 { -// total = 1 -// for _, s := range shape { -// total *= s -// } -// } - -// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape} -// } - -// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { -// return c.Empty(dtype, shape...) 
-// } - -// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor { -// t := c.Empty(ml.DTypeF32, shape...).(*testTensor) - -// copy(t.data, s) - -// return t -// } - -// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor { -// f := make([]float32, len(s)) -// for i := range f { -// f[i] = float32(s[i]) -// } - -// out := c.FromFloats(f, shape...) -// out.(*testTensor).dtype = ml.DTypeI32 - -// return out -// } - -// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { -// s := make([]float32, 0, int((stop-start)/step)) -// for i := start; i < stop; i += step { -// s = append(s, i) -// } - -// out := c.FromFloats(s, len(s)) -// out.(*testTensor).dtype = dtype -// return out -// } - -// func (c *testContext) Input() ml.Context { return c } -// func (c *testContext) Layer(int) ml.Context { return c } - -// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c } - -// func (c *testContext) Compute(...ml.Tensor) {} - -// func (c *testContext) Reserve() {} - -// func (c *testContext) MaxGraphNodes() int { -// return 10 -// } - -// func (c *testContext) Close() {} - -// type testTensor struct { -// ml.Tensor - -// dtype ml.DType -// elementSize int -// data []float32 -// shape []int -// } - -// func (t *testTensor) Dim(n int) int { -// return t.shape[n] -// } - -// func (t *testTensor) Stride(n int) int { -// stride := t.elementSize -// for i := range n { -// stride *= t.shape[i] -// } - -// return stride -// } - -// func (t *testTensor) Shape() []int { -// return t.shape -// } - -// func (t *testTensor) DType() ml.DType { -// return t.dtype -// } - -// func (t *testTensor) Floats() []float32 { -// out := make([]float32, len(t.data)) -// copy(out, t.data) -// return out -// } - -// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor { -// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) -// for i := range out.data { -// out.data[i] = -t.data[i] -// } -// return out -// } - -// func (t 
*testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { -// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) - -// for i := range out.data { -// out.data[i] = t.data[i] + t2.(*testTensor).data[i] -// } - -// return out -// } - -// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { -// return &testTensor{ -// dtype: t.dtype, -// elementSize: t.elementSize, -// data: t.data, -// shape: shape, -// } -// } - -// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { -// offset /= t.elementSize - -// var s []int - -// switch len(shape) { -// case 1: -// s = []int{shape[0]} -// case 3: -// s = []int{shape[0], shape[2]} -// case 5: -// s = []int{shape[0], shape[2], shape[4]} -// default: -// panic("unsupported number of dimensions") -// } - -// context := &testContext{} - -// view := context.Empty(t.dtype, s...).(*testTensor) -// view.data = t.data[offset : offset+len(view.data)] - -// return view -// } - -// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor { -// if len(t.shape) > 4 || len(order) > 4 { -// panic("permute only supports up to 4 dimensions") -// } - -// if len(order) != len(t.shape) && len(order) != 4 { -// panic("invalid number of dimensions for permute") -// } - -// // ggml_permute expects 4 axes, so fill in any missing dimensions. -// orderFull := append(make([]int, 0, 4), order...) 
-// for len(orderFull) < 4 { -// orderFull = append(orderFull, len(orderFull)) -// } - -// seen := [4]bool{} - -// shape4 := [4]int{1, 1, 1, 1} -// for i := 0; i < len(t.shape) && i < 4; i++ { -// shape4[i] = t.shape[i] -// } - -// newShape4 := [4]int{1, 1, 1, 1} -// for axis := range 4 { -// dst := orderFull[axis] -// if dst < 0 || dst >= 4 { -// panic("invalid axis for permute") -// } -// if seen[dst] { -// panic("duplicate axis for permute") -// } -// seen[dst] = true -// newShape4[dst] = shape4[axis] -// } - -// total := len(t.data) -// newData := make([]float32, total) - -// if total > 0 { -// oldDims := shape4 -// newDims := newShape4 - -// oldStride := [4]int{1, 1, 1, 1} -// newStride := [4]int{1, 1, 1, 1} -// for i := 1; i < 4; i++ { -// oldStride[i] = oldStride[i-1] * oldDims[i-1] -// newStride[i] = newStride[i-1] * newDims[i-1] -// } - -// var coords [4]int -// var newCoords [4]int - -// for idx := range total { -// remainder := idx -// for axis := range 4 { -// dim := oldDims[axis] -// if dim == 0 { -// coords[axis] = 0 -// continue -// } -// coords[axis] = remainder % dim -// remainder /= dim -// } - -// for axis := range 4 { -// newCoords[orderFull[axis]] = coords[axis] -// } - -// newIndex := 0 -// for axis := range 4 { -// if newDims[axis] == 0 { -// continue -// } -// newIndex += newCoords[axis] * newStride[axis] -// } - -// newData[newIndex] = t.data[idx] -// } -// } - -// numDims := 4 -// for numDims > 1 && newShape4[numDims-1] <= 1 { -// numDims-- -// } - -// newShape := make([]int, numDims) -// copy(newShape, newShape4[:numDims]) - -// return &testTensor{ -// dtype: t.dtype, -// elementSize: t.elementSize, -// data: newData, -// shape: newShape, -// } -// } - -// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor { -// dst := t -// srcTensor := src.(*testTensor) -// idxTensor := idxs.(*testTensor) - -// shapeTo4D := func(shape []int) [4]int { -// out := [4]int{1, 1, 1, 1} -// for i := 0; i < len(shape) && i < 
4; i++ { -// out[i] = shape[i] -// } -// return out -// } - -// computeStrides := func(shape [4]int) [4]int { -// out := [4]int{1, 1, 1, 1} -// for i := 1; i < 4; i++ { -// out[i] = out[i-1] * shape[i-1] -// } -// return out -// } - -// dstShape4D := shapeTo4D(dst.shape) -// srcShape4D := shapeTo4D(srcTensor.shape) -// idxShape4D := shapeTo4D(idxTensor.shape) - -// if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] { -// panic("SetRows requires matching tensor shapes") -// } - -// if srcShape4D[1] != idxShape4D[0] { -// panic("SetRows rows/index mismatch") -// } - -// if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 { -// panic("SetRows cannot broadcast indices") -// } - -// if idxShape4D[3] != 1 { -// panic("SetRows expects 1D or 2D index tensors") -// } - -// dstStride := computeStrides(dstShape4D) -// srcStride := computeStrides(srcShape4D) -// idxStride := computeStrides(idxShape4D) - -// numColumns := srcShape4D[0] -// numRows := srcShape4D[1] - -// for dim3Index := range dstShape4D[3] { -// for dim2Index := range dstShape4D[2] { -// idxDim2 := 0 -// idxDim3 := 0 -// if idxShape4D[1] > 0 { -// idxDim2 = dim2Index % idxShape4D[1] -// } -// if idxShape4D[2] > 0 { -// idxDim3 = dim3Index % idxShape4D[2] -// } - -// idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1] -// srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2] -// dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2] - -// for row := range numRows { -// idx := int(idxTensor.data[idxBase+row*idxStride[0]]) -// if idx < 0 || idx >= dstShape4D[1] { -// panic("SetRows index out of range") -// } - -// srcOffset := srcBase + row*srcStride[1] -// dstOffset := dstBase + idx*dstStride[1] - -// copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns]) -// } -// } -// } - -// return dst -// } - -// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor { -// 
copy(t2.(*testTensor).data, t.data) -// return nil -// } diff --git a/x/kvcache/mlx.go b/x/kvcache/mlx.go deleted file mode 100644 index fa3865104..000000000 --- a/x/kvcache/mlx.go +++ /dev/null @@ -1,144 +0,0 @@ -//go:build mlx - -package kvcache - -import ( - "github.com/ollama/ollama/x/ml" - "github.com/ollama/ollama/x/model/input" -) - -// Causal cache stores K and V tensors according to their position in the -// sequence. Returns the history and a mask for attending to past tokens -type MLXCausal struct { - DType ml.DType - - // locations for data storage for this batch - curLocPut ml.Tensor - - // locations for data storage for this batch - curLocGet ml.Tensor - - // the active layer for Get and Put - curLayer int - - capacity int - - offset int - - backend ml.Backend - ctxs map[int]ml.Context - keys, values map[int]ml.Tensor - - // TODO is this needed per layer, or will it always be consistent? - kHeadDims, vHeadDims, numKVHeads map[int]int -} - -func NewMLXCausalCache() *MLXCausal { - return &MLXCausal{ - ctxs: make(map[int]ml.Context), - keys: make(map[int]ml.Tensor), - values: make(map[int]ml.Tensor), - kHeadDims: make(map[int]int), - vHeadDims: make(map[int]int), - numKVHeads: make(map[int]int), - } -} - -func (c *MLXCausal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { - c.DType = dtype - c.capacity = capacity - c.backend = backend -} - -func (c *MLXCausal) SetConfig(config ml.CacheConfig) {} - -func (c *MLXCausal) SetLayer(layer int) { - c.curLayer = layer -} - -func (c *MLXCausal) Close() { - // slog.Info("XXX MLXCausal.Close called", "number of contexts", len(c.ctxs)) - for _, ctx := range c.ctxs { - ctx.Close() - } -} - -func (c *MLXCausal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { - locsPut := make([]int32, len(batch.Positions)) - for i := c.offset; i < len(batch.Positions); i++ { - locsPut[i-c.offset] = int32(i) - } - c.offset += len(batch.Positions) - locsGet := make([]int32, 
c.offset) - for i := range c.offset { - locsGet[i] = int32(i) - } - c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet)) - c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut)) - // slog.Info("XXX MLXCausal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet) - - return nil -} -func (c *MLXCausal) Put(ctx ml.Context, key, value ml.Tensor) { - kHeadDim := key.Dim(3) - vHeadDim := value.Dim(3) - numKVHeads := key.Dim(1) - batchSize := key.Dim(2) - kCellSize := kHeadDim * numKVHeads - vCellSize := vHeadDim * numKVHeads - // slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize) - - if _, ok := c.ctxs[c.curLayer]; !ok { - // slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer) - c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer) - } - - if _, ok := c.keys[c.curLayer]; !ok { - // slog.Info("XXX MLXCausal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize}) - c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize) - c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize) - c.kHeadDims[c.curLayer] = kHeadDim - c.vHeadDims[c.curLayer] = vHeadDim - c.numKVHeads[c.curLayer] = numKVHeads - } - key = key.Reshape(ctx, batchSize, 1, kCellSize) - - // slog.Info("XXX MLXCausal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer]) - // slog.Info("XXX MLXCausal.Put ", "c.curLocPut", c.curLocPut) - // slog.Info("XXX MLXCausal.Put ", "key", key) - ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0})) - value = value.Reshape(ctx, batchSize, 1, vCellSize) - ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0})) - -} - -func (c *MLXCausal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { - key := c.keys[c.curLayer] - value := c.values[c.curLayer] - - kHeadDim := 
c.kHeadDims[c.curLayer] - vHeadDim := c.vHeadDims[c.curLayer] - numKVHeads := c.numKVHeads[c.curLayer] - // rowSize := numKVHeads * c.curBatchSize - // cachedSize := c.curMask.Dim(1) - cachedSize := c.curLocGet.Dim(0) - // kCellSize := kHeadDim * numKVHeads - // vCellSize := vHeadDim * numKVHeads - // slog.Info("XXX MLXCausal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim}) - - key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim) - value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim) - return key, value, nil -} - -func (c *MLXCausal) CopyPrefix(srcSeq, dstSeq int, len int32) { - panic("not implemented") -} - -func (c *MLXCausal) CanResume(seq int, pos int32) bool { - panic("not implemented") -} - -func (c *MLXCausal) Remove(seq int, beginIndex, endIndex int32) error { - panic("not implemented") -} diff --git a/x/mlxrunner/imagegen.go b/x/mlxrunner/imagegen.go new file mode 100644 index 000000000..b1cdd91df --- /dev/null +++ b/x/mlxrunner/imagegen.go @@ -0,0 +1,134 @@ +//go:build mlx + +package mlxrunner + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "sync" + "time" + + "github.com/ollama/ollama/x/imagegen" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/models/flux2" + "github.com/ollama/ollama/x/imagegen/models/zimage" +) + +// ImageModel is the interface for image generation models. +type ImageModel interface { + GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) +} + +var imageGenMu sync.Mutex + +// loadImageModel loads an image generation model. 
+func (s *server) loadImageModel() error { + // Check memory requirements before loading + var requiredMemory uint64 + if manifest, err := imagegen.LoadManifest(s.modelName); err == nil { + requiredMemory = uint64(manifest.TotalTensorSize()) + } + availableMemory := mlx.GetMemoryLimit() + if availableMemory > 0 && requiredMemory > 0 && availableMemory < requiredMemory { + return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB", + requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024)) + } + + // Detect model type and load appropriate model + modelType := imagegen.DetectModelType(s.modelName) + slog.Info("detected image model type", "type", modelType) + + var model ImageModel + switch modelType { + case "Flux2KleinPipeline": + m := &flux2.Model{} + if err := m.Load(s.modelName); err != nil { + return fmt.Errorf("failed to load flux2 model: %w", err) + } + model = m + default: + // Default to Z-Image for ZImagePipeline, FluxPipeline, etc. + m := &zimage.Model{} + if err := m.Load(s.modelName); err != nil { + return fmt.Errorf("failed to load zimage model: %w", err) + } + model = m + } + + s.imageModel = model + return nil +} + +// handleImageCompletion handles image generation requests. 
+func (s *server) handleImageCompletion(w http.ResponseWriter, r *http.Request, req Request) { + // Serialize generation requests - MLX model may not handle concurrent generation + imageGenMu.Lock() + defer imageGenMu.Unlock() + + // Set seed if not provided + if req.Seed <= 0 { + req.Seed = time.Now().UnixNano() + } + + // Set up streaming response + w.Header().Set("Content-Type", "application/x-ndjson") + w.Header().Set("Transfer-Encoding", "chunked") + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming not supported", http.StatusInternalServerError) + return + } + + ctx := r.Context() + enc := json.NewEncoder(w) + + // Progress callback streams step updates + progress := func(step, total int) { + resp := Response{Step: step, Total: total} + enc.Encode(resp) + w.Write([]byte("\n")) + flusher.Flush() + } + + // Generate image + img, err := s.imageModel.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress) + if err != nil { + // Don't send error for cancellation + if ctx.Err() != nil { + return + } + resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true} + data, _ := json.Marshal(resp) + w.Write(data) + w.Write([]byte("\n")) + return + } + + // Encode image as base64 PNG + imageData, err := imagegen.EncodeImageBase64(img) + if err != nil { + resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true} + data, _ := json.Marshal(resp) + w.Write(data) + w.Write([]byte("\n")) + return + } + + // Free the generated image array and clean up MLX state + img.Free() + mlx.ClearCache() + mlx.MetalResetPeakMemory() + + // Send final response with image data + resp := Response{ + Image: imageData, + Done: true, + } + data, _ := json.Marshal(resp) + w.Write(data) + w.Write([]byte("\n")) + flusher.Flush() +} diff --git a/x/mlxrunner/llm.go b/x/mlxrunner/llm.go new file mode 100644 index 000000000..865750573 --- /dev/null +++ b/x/mlxrunner/llm.go @@ -0,0 +1,420 @@ +//go:build mlx + +package mlxrunner 
+ +import ( + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "strings" + "sync" + "time" + + "github.com/ollama/ollama/x/imagegen" + "github.com/ollama/ollama/x/imagegen/cache" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite" + "github.com/ollama/ollama/x/imagegen/tokenizer" +) + +// TextModel is the interface for LLM text generation models. +type TextModel interface { + Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array + NewCache(maxSeqLen int32) []cache.Cache + Tokenizer() *tokenizer.Tokenizer + VocabSize() int32 + MaxContextLength() int32 + NumLayers() int +} + +// llmState holds the state for LLM generation +type llmState struct { + model TextModel +} + +var llmMu sync.Mutex + +// Dedicated stream for generation (like mlx-lm's generation_stream) +var generationStream *mlx.Stream + +// withStream runs fn with the generation stream as default +func withStream(fn func()) { + // Lazy initialization of generationStream + if generationStream == nil { + generationStream = mlx.NewStream() + } + orig := mlx.GetDefaultStream() + mlx.SetDefaultStream(generationStream) + fn() + mlx.SetDefaultStream(orig) +} + +// Decoder wraps model + cache for autoregressive generation. 
+// This matches the pattern from cmd/engine/generate.go +type Decoder struct { + model TextModel + caches []cache.Cache + vocabSize int32 + temp float32 + token *mlx.Array // Current token (kept across iterations) + oldCacheState []*mlx.Array // Preallocated slice for old cache state +} + +func NewDecoder(m TextModel, temp float32) *Decoder { + caches := m.NewCache(0) + return &Decoder{ + model: m, + caches: caches, + vocabSize: m.VocabSize(), + temp: temp, + oldCacheState: make([]*mlx.Array, 0, len(caches)*2), + } +} + +func (d *Decoder) prefill(inputIDs []int32) int { + processed := 0 + + // Track old cache state to free after each chunk + var oldCacheState []*mlx.Array + + // Process all-but-1 tokens in chunks, eval cache state for memory management + for len(inputIDs) > 1 { + chunkSize := min(2048, len(inputIDs)-1) + if chunkSize <= 0 { + break + } + chunk := inputIDs[:chunkSize] + + // Save old cache state before forward + oldCacheState = oldCacheState[:0] + for _, c := range d.caches { + oldCacheState = append(oldCacheState, c.State()...) + } + + var cacheState []*mlx.Array + withStream(func() { + x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))}) + d.model.Forward(x, d.caches) + for _, c := range d.caches { + cacheState = append(cacheState, c.State()...) + } + }) + mlx.Eval(cacheState...) + + // Free old cache state + for _, arr := range oldCacheState { + if arr != nil { + arr.Free() + } + } + + inputIDs = inputIDs[chunkSize:] + processed += chunkSize + } + + // Save old cache state before final step + oldCacheState = oldCacheState[:0] + for _, c := range d.caches { + oldCacheState = append(oldCacheState, c.State()...) 
+ } + + // Final token + sampling + withStream(func() { + x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))}) + mlx.Eval(x) // Materialize before any other evals + logits := d.model.Forward(x, d.caches) + d.token = sample(logits, d.temp, d.vocabSize) + }) + // Keep cache state (token auto-kept by AsyncEval) + for _, c := range d.caches { + mlx.Keep(c.State()...) + } + mlx.AsyncEval(d.token) + + // Free old cache state from before final step + for _, arr := range oldCacheState { + if arr != nil { + arr.Free() + } + } + + mlx.ClearCache() + + return processed + len(inputIDs) +} + +func (d *Decoder) step() int32 { + prevToken := d.token + + // Save old cache state (reuse preallocated slice) + d.oldCacheState = d.oldCacheState[:0] + for _, c := range d.caches { + d.oldCacheState = append(d.oldCacheState, c.State()...) + } + + withStream(func() { + logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches) + d.token = sample(logits, d.temp, d.vocabSize) + }) + // Keep token and new cache state so they survive cleanup + mlx.Keep(d.token) + for _, c := range d.caches { + mlx.Keep(c.State()...) 
+ } + mlx.AsyncEval(d.token) + + // Sync on previous token (GPU already working on next step) + val := prevToken.ItemInt32() + + // Free old token and old cache state + prevToken.Free() + for _, arr := range d.oldCacheState { + arr.Free() + } + return val +} + +// sample samples from logits using temperature scaling +func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array { + // Get last position logits: [1, L, vocab] -> [vocab] + shape := logits.Shape() + seqLen := shape[1] + lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocabSize}) + lastLogits = mlx.Reshape(lastLogits, vocabSize) + + if temp <= 0 || temp < 0.01 { + // Greedy decoding + return mlx.Argmax(lastLogits, -1, false) + } + + // Apply temperature scaling + scaled := mlx.DivScalar(lastLogits, temp) + return mlx.RandomCategorical(scaled, -1, 1) +} + +// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage. +func (s *server) loadLLMModel() error { + // Load the manifest to get model information + manifest, err := imagegen.LoadManifest(s.modelName) + if err != nil { + return fmt.Errorf("failed to load manifest: %w", err) + } + + // Detect model architecture from config.json + configData, err := manifest.ReadConfig("config.json") + if err != nil { + return fmt.Errorf("failed to read config.json: %w", err) + } + + var modelConfig struct { + Architectures []string `json:"architectures"` + ModelType string `json:"model_type"` + } + if err := json.Unmarshal(configData, &modelConfig); err != nil { + return fmt.Errorf("failed to parse config.json: %w", err) + } + + arch := "" + if len(modelConfig.Architectures) > 0 { + arch = modelConfig.Architectures[0] + } + if arch == "" { + arch = modelConfig.ModelType + } + + slog.Info("detected LLM architecture", "architecture", arch, "model_type", modelConfig.ModelType) + + // Load the appropriate model based on architecture + var model TextModel + archLower := strings.ToLower(arch) + + switch { + 
case strings.Contains(archLower, "glm4moelite"): + m, err := glm4_moe_lite.LoadFromManifest(manifest) + if err != nil { + return fmt.Errorf("failed to load glm4-moe-lite model: %w", err) + } + model = m + slog.Info("loaded glm4-moe-lite model", "vocab_size", m.VocabSize(), "layers", m.NumLayers()) + + default: + return fmt.Errorf("LLM architecture %q is not yet supported. "+ + "Supported architectures: glm4-moe-lite. "+ + "Please convert your model to GGUF format or use a supported architecture", arch) + } + + s.llmModel = &llmState{ + model: model, + } + + return nil +} + +// handleLLMCompletion handles LLM text generation requests. +func (s *server) handleLLMCompletion(w http.ResponseWriter, r *http.Request, req Request) { + if s.llmModel == nil { + http.Error(w, "LLM model not loaded", http.StatusInternalServerError) + return + } + + // Serialize generation requests + llmMu.Lock() + defer llmMu.Unlock() + + if err := s.llmGenerate(w, r, req); err != nil { + slog.Error("LLM generation failed", "error", err) + // Don't send error if we've already started streaming + } +} + +// llmGenerate runs the generation loop using the Decoder pattern from cmd/engine +func (s *server) llmGenerate(w http.ResponseWriter, r *http.Request, req Request) error { + state := s.llmModel + + // Set up streaming response + w.Header().Set("Content-Type", "application/x-ndjson") + w.Header().Set("Transfer-Encoding", "chunked") + flusher, ok := w.(http.Flusher) + if !ok { + return errors.New("streaming not supported") + } + + tok := state.model.Tokenizer() + + // The prompt is already formatted by the server using the model's renderer + // (see server/prompt.go renderPrompt), so we don't apply FormatPrompt here. 
+ prompt := req.Prompt + + // Tokenize the prompt + inputIDs := tok.Encode(prompt, true) + slog.Debug("tokenized prompt", "num_tokens", len(inputIDs)) + + // Generation parameters + maxTokens := int(state.model.MaxContextLength()) + if maxTokens <= 0 { + maxTokens = 4096 + } + if req.Options != nil && req.Options.NumPredict > 0 { + maxTokens = req.Options.NumPredict + } + + temperature := float32(0.7) + if req.Options != nil && req.Options.Temperature > 0 { + temperature = float32(req.Options.Temperature) + } + + // Enable MLX compilation for better performance + mlx.EnableCompile() + + // Create decoder with fresh caches + dec := NewDecoder(state.model, temperature) + + prefillStart := time.Now() + prefillTokens := dec.prefill(inputIDs) + // Prefill measurement includes time to first token + firstToken := dec.step() + prefillDuration := time.Since(prefillStart) + promptEvalDuration := prefillDuration + + enc := json.NewEncoder(w) + ctx := r.Context() + generated := 0 + stopReason := "max_tokens" + + // Handle first token + generated++ + if tok.IsEOS(firstToken) { + resp := Response{ + Done: true, + StopReason: fmt.Sprintf("first_token_eos:%d", firstToken), + PromptEvalCount: prefillTokens, + PromptEvalDuration: int(promptEvalDuration.Nanoseconds()), + } + enc.Encode(resp) + flusher.Flush() + return nil + } + + text := tok.Decode([]int32{firstToken}) + resp := Response{Content: text} + enc.Encode(resp) + flusher.Flush() + + genStart := time.Now() + + // Generation loop + for n := 1; n < maxTokens; n++ { + // Check for cancellation + select { + case <-ctx.Done(): + stopReason = fmt.Sprintf("context_cancelled:%d", generated) + break + default: + } + if stopReason != "max_tokens" { + break + } + + token := dec.step() + generated++ + + if tok.IsEOS(token) { + stopReason = fmt.Sprintf("eos_token:%d", token) + break + } + + text := tok.Decode([]int32{token}) + + // Check for stop sequences + if req.Options != nil && len(req.Options.Stop) > 0 { + shouldStop := false + var 
matchedStop string + for _, stop := range req.Options.Stop { + if strings.Contains(text, stop) { + text = strings.Split(text, stop)[0] + shouldStop = true + matchedStop = stop + break + } + } + if shouldStop { + if text != "" { + resp := Response{Content: text} + enc.Encode(resp) + flusher.Flush() + } + stopReason = fmt.Sprintf("stop_sequence:%s", matchedStop) + break + } + } + + resp := Response{Content: text} + enc.Encode(resp) + flusher.Flush() + + // Periodically clear MLX cache + if n%256 == 0 { + mlx.ClearCache() + } + } + + // Clean up + mlx.ClearCache() + + // Send final response with stats + evalDuration := time.Since(genStart) + resp = Response{ + Done: true, + StopReason: fmt.Sprintf("%s:generated=%d", stopReason, generated), + PromptEvalCount: prefillTokens, + PromptEvalDuration: int(promptEvalDuration.Nanoseconds()), + EvalCount: generated, + EvalDuration: int(evalDuration.Nanoseconds()), + } + enc.Encode(resp) + flusher.Flush() + + return nil +} diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go new file mode 100644 index 000000000..df0b7eafa --- /dev/null +++ b/x/mlxrunner/runner.go @@ -0,0 +1,204 @@ +//go:build mlx + +// Package mlxrunner provides a unified MLX runner for both LLM and image generation models. +package mlxrunner + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log/slog" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/x/imagegen" + "github.com/ollama/ollama/x/imagegen/mlx" +) + +// Execute is the entry point for the unified MLX runner subprocess. 
+func Execute(args []string) error { + // Set up logging with appropriate level from environment + slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: envconfig.LogLevel()}))) + + fs := flag.NewFlagSet("mlx-runner", flag.ExitOnError) + modelName := fs.String("model", "", "path to model") + port := fs.Int("port", 0, "port to listen on") + + if err := fs.Parse(args); err != nil { + return err + } + + if *modelName == "" { + return fmt.Errorf("--model is required") + } + if *port == 0 { + return fmt.Errorf("--port is required") + } + + // Initialize MLX + if err := mlx.InitMLX(); err != nil { + slog.Error("unable to initialize MLX", "error", err) + return err + } + slog.Info("MLX library initialized") + + // Detect model type from capabilities + mode := detectModelMode(*modelName) + slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode) + + // Create and start server + server, err := newServer(*modelName, *port, mode) + if err != nil { + return fmt.Errorf("failed to create server: %w", err) + } + + // Set up HTTP handlers + mux := http.NewServeMux() + mux.HandleFunc("/health", server.healthHandler) + mux.HandleFunc("/completion", server.completionHandler) + + // LLM-specific endpoints + if mode == ModeLLM { + mux.HandleFunc("/tokenize", server.tokenizeHandler) + mux.HandleFunc("/embedding", server.embeddingHandler) + } + + httpServer := &http.Server{ + Addr: fmt.Sprintf("127.0.0.1:%d", *port), + Handler: mux, + } + + // Handle shutdown + done := make(chan struct{}) + go func() { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + <-sigCh + slog.Info("shutting down mlx runner") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + httpServer.Shutdown(ctx) + close(done) + }() + + slog.Info("mlx runner listening", "addr", httpServer.Addr) + if err := httpServer.ListenAndServe(); err != http.ErrServerClosed { + return err + } + + 
<-done + return nil +} + +// detectModelMode determines whether a model is an LLM or image generation model. +func detectModelMode(modelName string) ModelMode { + // Check for image generation model by looking at model_index.json + modelType := imagegen.DetectModelType(modelName) + if modelType != "" { + // Known image generation model types + switch modelType { + case "ZImagePipeline", "FluxPipeline", "Flux2KleinPipeline": + return ModeImageGen + } + } + + // Default to LLM mode for safetensors models without known image gen types + return ModeLLM +} + +// server holds the model and handles HTTP requests. +type server struct { + mode ModelMode + modelName string + port int + + // Image generation model (when mode == ModeImageGen) + imageModel ImageModel + + // LLM model (when mode == ModeLLM) + llmModel *llmState +} + +// newServer creates a new server instance and loads the appropriate model. +func newServer(modelName string, port int, mode ModelMode) (*server, error) { + s := &server{ + mode: mode, + modelName: modelName, + port: port, + } + + switch mode { + case ModeImageGen: + if err := s.loadImageModel(); err != nil { + return nil, fmt.Errorf("failed to load image model: %w", err) + } + case ModeLLM: + if err := s.loadLLMModel(); err != nil { + return nil, fmt.Errorf("failed to load LLM model: %w", err) + } + } + + return s, nil +} + +func (s *server) healthHandler(w http.ResponseWriter, r *http.Request) { + resp := HealthResponse{Status: "ok"} + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req Request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + switch s.mode { + case ModeImageGen: + s.handleImageCompletion(w, r, req) + case 
ModeLLM: + s.handleLLMCompletion(w, r, req) + } +} + +func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) { + if s.llmModel == nil { + http.Error(w, "LLM model not loaded", http.StatusInternalServerError) + return + } + + var req struct { + Content string `json:"content"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + tok := s.llmModel.model.Tokenizer() + tokens := tok.Encode(req.Content, false) + + // Convert int32 to int for JSON response + intTokens := make([]int, len(tokens)) + for i, t := range tokens { + intTokens[i] = int(t) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string][]int{"tokens": intTokens}) +} + +func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) { + http.Error(w, "embeddings not yet implemented for MLX models", http.StatusNotImplemented) +} diff --git a/x/imagegen/runner/runner_stub.go b/x/mlxrunner/runner_stub.go similarity index 60% rename from x/imagegen/runner/runner_stub.go rename to x/mlxrunner/runner_stub.go index eafad7bca..3b0f35500 100644 --- a/x/imagegen/runner/runner_stub.go +++ b/x/mlxrunner/runner_stub.go @@ -1,10 +1,10 @@ //go:build !mlx -package runner +package mlxrunner import "errors" // Execute returns an error when not built with MLX support. 
func Execute(args []string) error { - return errors.New("image generation not available: build with mlx tag") + return errors.New("MLX runner not available: build with mlx tag") } diff --git a/x/imagegen/server.go b/x/mlxrunner/server.go similarity index 58% rename from x/imagegen/server.go rename to x/mlxrunner/server.go index ca9367694..89dd0bf04 100644 --- a/x/imagegen/server.go +++ b/x/mlxrunner/server.go @@ -1,4 +1,4 @@ -package imagegen +package mlxrunner import ( "bufio" @@ -23,19 +23,19 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/x/imagegen" ) -// Server wraps an image generation subprocess to implement llm.LlamaServer. +// Server wraps an MLX runner subprocess to implement llm.LlamaServer. // // This implementation is compatible with Ollama's scheduler and can be loaded/unloaded -// like any other model. The plan is to eventually bring this into the llm/ package -// and evolve llm/ to support MLX and multimodal models. For now, keeping the code -// separate allows for independent iteration on image generation support. +// like any other model. It supports both LLM (safetensors) and image generation models. type Server struct { mu sync.Mutex cmd *exec.Cmd port int modelName string + mode ModelMode vramSize uint64 done chan error client *http.Client @@ -43,10 +43,10 @@ type Server struct { lastErrLock sync.Mutex } -// NewServer spawns a new image generation subprocess and waits until it's ready. -func NewServer(modelName string) (*Server, error) { +// NewServer spawns a new MLX runner subprocess and waits until it's ready. 
+func NewServer(modelName string, mode ModelMode) (*Server, error) { // Validate platform support before attempting to start - if err := CheckPlatformSupport(); err != nil { + if err := imagegen.CheckPlatformSupport(); err != nil { return nil, err } @@ -71,8 +71,8 @@ func NewServer(modelName string) (*Server, error) { exe = eval } - // Spawn subprocess: ollama runner --image-engine --model --port - cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port)) + // Spawn subprocess: ollama runner --mlx-engine --model --port + cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port)) cmd.Env = os.Environ() // On Linux, set LD_LIBRARY_PATH to include MLX library directories @@ -105,17 +105,21 @@ func NewServer(modelName string) (*Server, error) { slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal) } - // Get total weight size from manifest - var weightSize uint64 - if manifest, err := LoadManifest(modelName); err == nil { - weightSize = uint64(manifest.TotalTensorSize()) + // Estimate VRAM based on tensor size from manifest + var vramSize uint64 + if manifest, err := imagegen.LoadManifest(modelName); err == nil { + vramSize = uint64(manifest.TotalTensorSize()) + } else { + // Fallback: default to 8GB if manifest can't be loaded + vramSize = 8 * 1024 * 1024 * 1024 } s := &Server{ cmd: cmd, port: port, modelName: modelName, - vramSize: weightSize, + mode: mode, + vramSize: vramSize, done: make(chan error, 1), client: &http.Client{Timeout: 10 * time.Minute}, } @@ -126,23 +130,23 @@ func NewServer(modelName string) (*Server, error) { go func() { scanner := bufio.NewScanner(stdout) for scanner.Scan() { - slog.Info("image-runner", "msg", scanner.Text()) + slog.Info("mlx-runner", "msg", scanner.Text()) } }() go func() { scanner := bufio.NewScanner(stderr) for scanner.Scan() { line := scanner.Text() - slog.Warn("image-runner", "msg", line) + slog.Warn("mlx-runner", 
"msg", line) s.lastErrLock.Lock() s.lastErr = line s.lastErrLock.Unlock() } }() - slog.Info("starting image runner subprocess", "exe", exe, "model", modelName, "port", port) + slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port, "mode", mode) if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("failed to start image runner: %w", err) + return nil, fmt.Errorf("failed to start mlx runner: %w", err) } // Reap subprocess when it exits @@ -165,6 +169,7 @@ func (s *Server) ModelPath() string { return s.modelName } +// Load satisfies the LlamaServer interface. MLX models don't need GPU layer assignment. func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) { return nil, nil } @@ -200,18 +205,18 @@ func (s *Server) waitUntilRunning() error { // Include recent stderr lines for better error context errMsg := s.getLastErr() if errMsg != "" { - return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err) + return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err) } - return fmt.Errorf("image runner exited unexpectedly: %w", err) + return fmt.Errorf("mlx runner exited unexpectedly: %w", err) case <-timeout: errMsg := s.getLastErr() if errMsg != "" { - return fmt.Errorf("timeout waiting for image runner: %s", errMsg) + return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg) } - return errors.New("timeout waiting for image runner to start") + return errors.New("timeout waiting for mlx runner to start") case <-ticker.C: if err := s.Ping(ctx); err == nil { - slog.Info("image runner is ready", "port", s.port) + slog.Info("mlx runner is ready", "port", s.port) return nil } } @@ -225,8 +230,12 @@ func (s *Server) getLastErr() string { return s.lastErr } -func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil } +// WaitUntilRunning satisfies the LlamaServer interface. 
+func (s *Server) WaitUntilRunning(ctx context.Context) error { + return nil +} +// Completion handles both text and image generation requests. func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { seed := req.Seed if seed == 0 { @@ -240,22 +249,26 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f } // Build request for subprocess - creq := struct { - Prompt string `json:"prompt"` - Width int32 `json:"width,omitempty"` - Height int32 `json:"height,omitempty"` - Steps int32 `json:"steps,omitempty"` - Seed int64 `json:"seed,omitempty"` - Images [][]byte `json:"images,omitempty"` - }{ + creq := Request{ Prompt: req.Prompt, Width: req.Width, Height: req.Height, - Steps: req.Steps, + Steps: int(req.Steps), Seed: seed, Images: images, } + // Pass LLM options if present + if req.Options != nil { + creq.Options = &RequestOptions{ + NumPredict: req.Options.NumPredict, + Temperature: float64(req.Options.Temperature), + TopP: float64(req.Options.TopP), + TopK: req.Options.TopK, + Stop: req.Options.Stop, + } + } + body, err := json.Marshal(creq) if err != nil { return err @@ -282,25 +295,40 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f scanner := bufio.NewScanner(resp.Body) scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max for scanner.Scan() { - // Parse subprocess response (has singular "image" field) + // Parse subprocess response var raw struct { - Image string `json:"image,omitempty"` - Content string `json:"content,omitempty"` - Done bool `json:"done"` - Step int `json:"step,omitempty"` - Total int `json:"total,omitempty"` + Image string `json:"image,omitempty"` + Content string `json:"content,omitempty"` + Done bool `json:"done"` + Step int `json:"step,omitempty"` + Total int `json:"total,omitempty"` + StopReason string `json:"stop_reason,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + 
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration int `json:"eval_duration,omitempty"` } if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil { + slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes())) continue } + // Log stop reason when generation completes + if raw.Done && raw.StopReason != "" { + slog.Info("mlx generation completed", "stop_reason", raw.StopReason) + } + // Convert to llm.CompletionResponse cresp := llm.CompletionResponse{ - Content: raw.Content, - Done: raw.Done, - Step: raw.Step, - TotalSteps: raw.Total, - Image: raw.Image, + Content: raw.Content, + Done: raw.Done, + Step: raw.Step, + TotalSteps: raw.Total, + Image: raw.Image, + PromptEvalCount: raw.PromptEvalCount, + PromptEvalDuration: time.Duration(raw.PromptEvalDuration), + EvalCount: raw.EvalCount, + EvalDuration: time.Duration(raw.EvalDuration), } fn(cresp) @@ -309,7 +337,20 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f } } - return scanner.Err() + // Scanner exited without receiving Done - connection was likely closed + scanErr := scanner.Err() + if scanErr != nil { + slog.Error("mlx scanner error", "error", scanErr) + } else { + slog.Warn("mlx scanner EOF without Done response - subprocess may have crashed") + } + + // Check if subprocess is still alive + if s.HasExited() { + slog.Error("mlx subprocess has exited unexpectedly") + } + + return scanErr } // Close terminates the subprocess. 
@@ -318,7 +359,7 @@ func (s *Server) Close() error { defer s.mu.Unlock() if s.cmd != nil && s.cmd.Process != nil { - slog.Info("stopping image runner subprocess", "pid", s.cmd.Process.Pid) + slog.Info("stopping mlx runner subprocess", "pid", s.cmd.Process.Pid) s.cmd.Process.Signal(os.Interrupt) // Wait briefly for graceful shutdown @@ -347,23 +388,56 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramSize } -// Context length is not applicable for image generation. +// ContextLength returns the context length (not applicable for image generation). func (s *Server) ContextLength() int { return 0 } +// Embedding returns embeddings for the input. func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) { - return nil, 0, errors.New("not supported") + return nil, 0, errors.New("embeddings not supported for MLX models") } +// Tokenize tokenizes the input content. func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) { - return nil, errors.New("not supported") + body, err := json.Marshal(map[string]string{"content": content}) + if err != nil { + return nil, err + } + + url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := s.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("tokenize failed: %d", resp.StatusCode) + } + + var result struct { + Tokens []int `json:"tokens"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + + return result.Tokens, nil } +// Detokenize converts tokens back to text. 
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) { - return "", errors.New("not supported") + return "", errors.New("detokenization not supported for MLX models") } +// Pid returns the process ID of the subprocess. func (s *Server) Pid() int { s.mu.Lock() defer s.mu.Unlock() @@ -373,9 +447,17 @@ func (s *Server) Pid() int { return -1 } -func (s *Server) GetPort() int { return s.port } -func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil } +// GetPort returns the port the subprocess is listening on. +func (s *Server) GetPort() int { + return s.port +} +// GetDeviceInfos returns device information. +func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { + return nil +} + +// HasExited returns whether the subprocess has exited. func (s *Server) HasExited() bool { select { case <-s.done: diff --git a/x/mlxrunner/types.go b/x/mlxrunner/types.go new file mode 100644 index 000000000..cd22d3941 --- /dev/null +++ b/x/mlxrunner/types.go @@ -0,0 +1,81 @@ +// Package mlxrunner provides a unified MLX runner for both LLM and image generation models. +// +// This package handles safetensors models created with `ollama create --experimental`, +// supporting both text generation (LLM) and image generation (diffusion) models +// through a single unified interface. +package mlxrunner + +// Request is the request format for completion requests. +type Request struct { + Prompt string `json:"prompt"` + + // LLM-specific fields + Options *RequestOptions `json:"options,omitempty"` + + // Image generation fields + Width int32 `json:"width,omitempty"` + Height int32 `json:"height,omitempty"` + Steps int `json:"steps,omitempty"` + Seed int64 `json:"seed,omitempty"` + Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning +} + +// RequestOptions contains LLM-specific generation options. 
+type RequestOptions struct { + NumPredict int `json:"num_predict,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Stop []string `json:"stop,omitempty"` +} + +// Response is streamed back for each progress update. +type Response struct { + // Text generation response + Content string `json:"content,omitempty"` + + // Image generation response + Image string `json:"image,omitempty"` // Base64-encoded PNG + + // Common fields + Done bool `json:"done"` + DoneReason int `json:"done_reason,omitempty"` + StopReason string `json:"stop_reason,omitempty"` // Debug: why generation stopped + + // Progress fields + Step int `json:"step,omitempty"` + Total int `json:"total,omitempty"` + + // Statistics + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + PromptEvalDuration int `json:"prompt_eval_duration,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration int `json:"eval_duration,omitempty"` +} + +// HealthResponse is returned by the health endpoint. +type HealthResponse struct { + Status string `json:"status"` + Progress float32 `json:"progress,omitempty"` +} + +// ModelMode represents the type of model being run. +type ModelMode int + +const ( + // ModeLLM indicates a text generation model. + ModeLLM ModelMode = iota + // ModeImageGen indicates an image generation model. + ModeImageGen +) + +func (m ModelMode) String() string { + switch m { + case ModeLLM: + return "llm" + case ModeImageGen: + return "imagegen" + default: + return "unknown" + } +} diff --git a/x/model/models/gemma3/model.go b/x/model/models/gemma3/model.go index 23f78f207..4072122c6 100644 --- a/x/model/models/gemma3/model.go +++ b/x/model/models/gemma3/model.go @@ -87,7 +87,7 @@ func New(c fs.Config) (model.Model, error) { // m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift)) // TODO need to implement sliding window... 
- m.Cache = kvcache.NewMLXCausalCache() + m.Cache = kvcache.NewCausalCache() return &m, nil } diff --git a/x/server/show.go b/x/server/show.go index dd95774d3..652293e77 100644 --- a/x/server/show.go +++ b/x/server/show.go @@ -163,9 +163,18 @@ func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) { // getTensorInfoFromManifest extracts tensor info from a manifest. // This is separated for testability. +// For quantized models, groups weight/scale/qbias into single entries with detected quantization type. func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) { var tensors []api.Tensor + // First pass: collect all tensor info and identify scale tensors + type tensorData struct { + info *safetensorsTensorInfo + digest string + } + tensorMap := make(map[string]*tensorData) + scaleMap := make(map[string]*tensorData) // base name -> scale tensor info + for _, layer := range mf.Layers { if layer.MediaType != manifest.MediaTypeImageTensor { continue @@ -178,28 +187,96 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) { } info, err := readSafetensorsHeader(blobPath) if err != nil { - // Skip tensors we can't read continue } - // Convert shape from int to uint64 - shape := make([]uint64, len(info.Shape)) - for i, s := range info.Shape { - shape[i] = uint64(s) + td := &tensorData{info: info, digest: layer.Digest} + + if strings.HasSuffix(layer.Name, "_scale") { + baseName := strings.TrimSuffix(layer.Name, "_scale") + scaleMap[baseName] = td + } else if strings.HasSuffix(layer.Name, "_qbias") { + // Skip qbias tensors - they're included with the quantized weight + continue + } else { + tensorMap[layer.Name] = td + } + } + + // Second pass: build tensor list with quantization info + for _, layer := range mf.Layers { + if layer.MediaType != manifest.MediaTypeImageTensor { + continue } - tensors = append(tensors, api.Tensor{ - Name: layer.Name, - Type: info.Dtype, - Shape: shape, - }) + // Skip scale and qbias 
tensors + if strings.HasSuffix(layer.Name, "_scale") || strings.HasSuffix(layer.Name, "_qbias") { + continue + } + + td := tensorMap[layer.Name] + if td == nil { + continue + } + + // Check if this tensor has a corresponding scale tensor (quantized) + scaleTd := scaleMap[layer.Name] + if scaleTd != nil && len(td.info.Shape) >= 2 && len(scaleTd.info.Shape) >= 2 { + // Quantized tensor - detect bits from shapes + weightCols := td.info.Shape[len(td.info.Shape)-1] + scaleCols := scaleTd.info.Shape[len(scaleTd.info.Shape)-1] + + // Detect quantization: Q4 has pack_factor=8, Q8 has pack_factor=4 + // Q4 uses group_size=32: weightCols * 8 / scaleCols = 32 + // Q8 uses group_size=64: weightCols * 4 / scaleCols = 64 + var bits int + var quantType string + if weightCols*8/scaleCols == 32 { + bits = 4 + quantType = "Q4" + } else if weightCols*4/scaleCols == 64 { + bits = 8 + quantType = "Q8" + } else { + // Unknown quantization, show raw + quantType = td.info.Dtype + } + + // Calculate unpacked shape + shape := make([]uint64, len(td.info.Shape)) + for i, s := range td.info.Shape { + shape[i] = uint64(s) + } + if bits > 0 { + packFactor := int64(32 / bits) + shape[len(shape)-1] = uint64(td.info.Shape[len(td.info.Shape)-1] * packFactor) + } + + tensors = append(tensors, api.Tensor{ + Name: layer.Name, + Type: quantType, + Shape: shape, + }) + } else { + // Non-quantized tensor + shape := make([]uint64, len(td.info.Shape)) + for i, s := range td.info.Shape { + shape[i] = uint64(s) + } + + tensors = append(tensors, api.Tensor{ + Name: layer.Name, + Type: td.info.Dtype, + Shape: shape, + }) + } } return tensors, nil } // GetSafetensorsDtype returns the quantization type for a safetensors model. -// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8"). +// Reads from model_index.json first, falls back to detection from tensor names. // Otherwise returns the torch_dtype from config.json. 
func GetSafetensorsDtype(name model.Name) (string, error) { mf, err := manifest.ParseNamedManifest(name) @@ -207,16 +284,38 @@ func GetSafetensorsDtype(name model.Name) (string, error) { return "", fmt.Errorf("failed to load manifest: %w", err) } - // Check if model is quantized by looking for _scale tensors + // First try to read quantization from model_index.json + var modelIndex struct { + Quantization string `json:"quantization"` + } + if err := mf.ReadConfigJSON("model_index.json", &modelIndex); err == nil && modelIndex.Quantization != "" { + return modelIndex.Quantization, nil + } + + // Fallback: detect from tensor names + hasScales := false + hasQBias := false for _, layer := range mf.Layers { if layer.MediaType == manifest.MediaTypeImageTensor { if strings.HasSuffix(layer.Name, "_scale") { - // Model is quantized - return FP8 (affine quantization) - return "FP8", nil + hasScales = true + } + if strings.HasSuffix(layer.Name, "_qbias") { + hasQBias = true } } } + if hasScales { + if hasQBias { + // Affine mode (has scale + qbias) - could be Q4 or Q8 + // Default to Q4 as it's more common + return "Q4", nil + } + // No qbias = NVFP4 + return "NVFP4", nil + } + // Not quantized - return torch_dtype from config.json var cfg struct { TorchDtype string `json:"torch_dtype"`