diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go
index eedd0c61d..433b2ab1b 100644
--- a/cmd/cmd_test.go
+++ b/cmd/cmd_test.go
@@ -1553,7 +1553,7 @@ func TestShowInfoImageGen(t *testing.T) {
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
- QuantizationLevel: "FP8",
+ QuantizationLevel: "Q8",
},
Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0",
@@ -1565,7 +1565,7 @@ func TestShowInfoImageGen(t *testing.T) {
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
- " quantization FP8 \n" +
+ " quantization Q8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +
diff --git a/runner/runner.go b/runner/runner.go
index 543410798..db50758bb 100644
--- a/runner/runner.go
+++ b/runner/runner.go
@@ -3,7 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
- imagerunner "github.com/ollama/ollama/x/imagegen/runner"
+ "github.com/ollama/ollama/x/mlxrunner"
)
func Execute(args []string) error {
@@ -12,18 +12,18 @@ func Execute(args []string) error {
}
var newRunner bool
- var imageRunner bool
+ var mlxRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
- if len(args) > 0 && args[0] == "--image-engine" {
+ if len(args) > 0 && args[0] == "--mlx-engine" {
args = args[1:]
- imageRunner = true
+ mlxRunner = true
}
- if imageRunner {
- return imagerunner.Execute(args)
+ if mlxRunner {
+ return mlxrunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {
diff --git a/server/sched.go b/server/sched.go
index 5f22e0b87..3aa9969a0 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -21,7 +21,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
- "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/x/mlxrunner"
)
type LlmRequest struct {
@@ -195,14 +195,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
- // Check for image generation model before attempting GGML load
+ // Check for image generation models - all use MLX runner
if slices.Contains(pending.model.Config.Capabilities, "image") {
- if s.loadImageGen(pending) {
+ if s.loadMLX(pending) {
break
}
continue
}
+ // Check for experimental safetensors LLM models
+ if pending.model.Config.ModelFormat == "safetensors" {
+ if slices.Contains(pending.model.Config.Capabilities, "completion") {
+ // LLM model with safetensors format - use MLX runner
+ if s.loadMLX(pending) {
+ break
+ }
+ continue
+ }
+ }
+
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -552,11 +563,20 @@ iGPUScan:
return false
}
-// loadImageGen loads an image generation model.
-func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
- // Use model name for imagegen (it resolves manifests by name, not file path)
+// loadMLX loads an experimental safetensors model using the unified MLX runner.
+// This supports both LLM (completion) and image generation models.
+func (s *Scheduler) loadMLX(req *LlmRequest) bool {
+ // Determine mode based on capabilities
+ var mode mlxrunner.ModelMode
+ if slices.Contains(req.model.Config.Capabilities, "image") {
+ mode = mlxrunner.ModeImageGen
+ } else {
+ mode = mlxrunner.ModeLLM
+ }
+
+ // Use model name for MLX (it resolves manifests by name, not file path)
modelName := req.model.ShortName
- server, err := imagegen.NewServer(modelName)
+ server, err := mlxrunner.NewServer(modelName, mode)
if err != nil {
req.errCh <- err
return true
diff --git a/x/create/client/create.go b/x/create/client/create.go
index c7e51a525..36e7f164b 100644
--- a/x/create/client/create.go
+++ b/x/create/client/create.go
@@ -13,6 +13,7 @@ import (
"io"
"os"
"path/filepath"
+ "strings"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/progress"
@@ -34,7 +35,7 @@ type ModelfileConfig struct {
type CreateOptions struct {
ModelName string
ModelDir string
- Quantize string // "fp8" for quantization
+ Quantize string // "q4", "q8", "nvfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
}
@@ -53,10 +54,20 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
// Determine model type settings
var modelType, spinnerKey string
var capabilities []string
+ var parserName, rendererName string
if isSafetensors {
modelType = "safetensors model"
spinnerKey = "create"
capabilities = []string{"completion"}
+
+ // Check if model supports thinking based on architecture
+ if supportsThinking(opts.ModelDir) {
+ capabilities = append(capabilities, "thinking")
+ }
+
+ // Set parser and renderer name based on architecture
+ parserName = getParserName(opts.ModelDir)
+ rendererName = getRendererName(opts.ModelDir)
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
@@ -81,14 +92,14 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
err = create.CreateSafetensorsModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
- newManifestWriter(opts, capabilities),
+ newManifestWriter(opts, capabilities, parserName, rendererName),
progressFn,
)
} else {
err = create.CreateImageGenModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
- newManifestWriter(opts, capabilities),
+ newManifestWriter(opts, capabilities, "", ""),
progressFn,
)
}
@@ -204,7 +215,7 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error
}
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
-func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
+func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter {
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
@@ -229,6 +240,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
ModelFormat: "safetensors",
Capabilities: caps,
Requires: MinOllamaVersion,
+ Parser: parserName,
+ Renderer: rendererName,
}
configJSON, err := json.Marshal(configData)
if err != nil {
@@ -295,3 +308,146 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
return layers, nil
}
+
+// supportsThinking reports whether the model in modelDir belongs to an
+// architecture family that is listed here as supporting thinking mode.
+//
+// It reads config.json from modelDir and does a case-insensitive substring
+// match of every "architectures" entry, and then of "model_type", against the
+// known families. A missing or unparsable config.json yields false.
+func supportsThinking(modelDir string) bool {
+	raw, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
+	if err != nil {
+		return false
+	}
+
+	var meta struct {
+		Architectures []string `json:"architectures"`
+		ModelType     string   `json:"model_type"`
+	}
+	if json.Unmarshal(raw, &meta) != nil {
+		return false
+	}
+
+	// matchesFamily reports whether id names one of the thinking-capable families.
+	matchesFamily := func(id string) bool {
+		id = strings.ToLower(id)
+		return strings.Contains(id, "glm4moe") || // GLM-4 MoE models
+			strings.Contains(id, "deepseek") || // DeepSeek models
+			strings.Contains(id, "qwen3") // Qwen3 models
+	}
+
+	for _, arch := range meta.Architectures {
+		if matchesFamily(arch) {
+			return true
+		}
+	}
+
+	// An empty model_type contains none of the family names, so it needs no
+	// separate emptiness check.
+	return matchesFamily(meta.ModelType)
+}
+
+// getParserName maps the model in modelDir to the name of its parser, or ""
+// when no known parser applies.
+//
+// The decision is a case-insensitive substring match on config.json: each
+// "architectures" entry is tried in order, then "model_type". A missing or
+// unparsable config.json yields "".
+func getParserName(modelDir string) string {
+	raw, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
+	if err != nil {
+		return ""
+	}
+
+	var meta struct {
+		Architectures []string `json:"architectures"`
+		ModelType     string   `json:"model_type"`
+	}
+	if json.Unmarshal(raw, &meta) != nil {
+		return ""
+	}
+
+	// Candidates in priority order: architectures first, model_type last.
+	// An empty model_type matches nothing, so it is appended unconditionally.
+	candidates := append(append([]string(nil), meta.Architectures...), meta.ModelType)
+	for _, candidate := range candidates {
+		lowered := strings.ToLower(candidate)
+		switch {
+		case strings.Contains(lowered, "glm4"), strings.Contains(lowered, "glm-4"):
+			return "glm-4.7"
+		case strings.Contains(lowered, "deepseek"):
+			return "deepseek3"
+		case strings.Contains(lowered, "qwen3"):
+			return "qwen3-coder"
+		}
+	}
+
+	return ""
+}
+
+// getRendererName maps the model in modelDir to the name of its renderer, or
+// "" when no known renderer applies.
+//
+// config.json is read from modelDir; each "architectures" entry is examined
+// in order and "model_type" is consulted last. Matching is a case-insensitive
+// substring test. A missing or unparsable config.json yields "".
+func getRendererName(modelDir string) string {
+	raw, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
+	if err != nil {
+		return ""
+	}
+
+	var meta struct {
+		Architectures []string `json:"architectures"`
+		ModelType     string   `json:"model_type"`
+	}
+	if json.Unmarshal(raw, &meta) != nil {
+		return ""
+	}
+
+	// rendererFor returns the renderer for one identifier, or "" if it is unknown.
+	rendererFor := func(id string) string {
+		id = strings.ToLower(id)
+		if strings.Contains(id, "glm4") || strings.Contains(id, "glm-4") {
+			return "glm-4.7"
+		}
+		if strings.Contains(id, "deepseek") {
+			return "deepseek3"
+		}
+		if strings.Contains(id, "qwen3") {
+			return "qwen3-coder"
+		}
+		return ""
+	}
+
+	for _, arch := range meta.Architectures {
+		if renderer := rendererFor(arch); renderer != "" {
+			return renderer
+		}
+	}
+
+	// An empty model_type resolves to "", which is also the "not found" result.
+	return rendererFor(meta.ModelType)
+}
diff --git a/x/create/client/quantize.go b/x/create/client/quantize.go
index 3a9f37cfc..e69003f73 100644
--- a/x/create/client/quantize.go
+++ b/x/create/client/quantize.go
@@ -13,7 +13,11 @@ import (
// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns safetensors data for the quantized weights, scales, and biases.
-// Supported quantization types: "fp8" (affine 8-bit)
+// Supported quantization types:
+// - "q4": affine 4-bit, group_size=32 (with qbiases)
+// - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales)
+// - "q8": affine 8-bit, group_size=64 (with qbiases)
+// - "mxfp8": Microscaling (OCP MX) FP8, group_size=32 (no qbiases, E8M0 scales)
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir()
@@ -54,12 +58,18 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
// Quantize based on quantization type
var qweight, scales, qbiases *mlx.Array
switch quantize {
- case "fp4":
- // affine mode: group_size=32, bits=4
+ case "q4":
+ // affine mode: group_size=32, bits=4 (with qbiases for zero-point offset)
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
- case "fp8":
- // affine mode: group_size=32, bits=8
- qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
+ case "nvfp4":
+ // NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales)
+ qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4")
+ case "q8":
+ // affine mode: group_size=64, bits=8 (with qbiases for zero-point offset)
+ qweight, scales, qbiases = mlx.Quantize(arr, 64, 8, "affine")
+ case "mxfp8":
+ // Microscaling (OCP MX) FP8: group_size=32, bits=8, E8M0 scales (no qbiases)
+ qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "mxfp8")
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
}
diff --git a/x/create/create.go b/x/create/create.go
index 823d0f842..2474c8c66 100644
--- a/x/create/create.go
+++ b/x/create/create.go
@@ -228,7 +228,7 @@ type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
-// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
+// When quantize is non-empty (e.g., "q8"), returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
// ManifestWriter writes the manifest file.
@@ -262,36 +262,134 @@ func ShouldQuantize(name, component string) bool {
return strings.HasSuffix(name, ".weight")
}
-// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
+// ShouldQuantizeTensor reports whether a tensor should be quantized based on name, shape, and quantize type.
// This is a more detailed check that also considers tensor dimensions.
-func ShouldQuantizeTensor(name string, shape []int32) bool {
+// The quantize parameter specifies the requested quantization type (e.g., "q4", "nvfp4", "q8", "mxfp8").
+// The result is true exactly when GetTensorQuantization returns a non-empty
+// type for the same arguments; use GetTensorQuantization directly when the
+// per-tensor type itself is needed.
+func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
+ return GetTensorQuantization(name, shape, quantize) != ""
+}
+
+// normalizeQuantType maps a quantization name to its canonical lowercase
+// form, ignoring case:
+//   - q4, int4, fp4 -> "q4"
+//   - q8, int8, fp8 -> "q8"
+//   - nvfp4 -> "nvfp4"
+//   - mxfp8 -> "mxfp8"
+//
+// Any other value is returned unchanged.
+func normalizeQuantType(quantize string) string {
+	canonical := quantize
+	switch strings.ToUpper(quantize) {
+	case "Q4", "INT4", "FP4":
+		canonical = "q4"
+	case "Q8", "INT8", "FP8":
+		canonical = "q8"
+	case "NVFP4":
+		canonical = "nvfp4"
+	case "MXFP8":
+		canonical = "mxfp8"
+	}
+	return canonical
+}
+
+// getQuantGroupSize returns the group size for a given quantization type.
+// The type is normalized first, so aliases resolve to their canonical size.
+// These must match the values used in quantize.go when creating quantized models.
+func getQuantGroupSize(quantize string) int {
+	switch normalizeQuantType(quantize) {
+	case "nvfp4":
+		return 16
+	case "q8":
+		return 64
+	}
+	// q4, mxfp8, and any unrecognized type use a group size of 32.
+	return 32
+}
+
+// GetTensorQuantization returns the quantization type to apply to a tensor,
+// or "" if the tensor should be stored unquantized.
+//
+// quantize is the requested type; aliases are resolved by normalizeQuantType.
+// A tensor is left unquantized when ShouldQuantize rejects its name, when it
+// is not 2D, when it has fewer than 1024 elements, when its last dimension is
+// not a multiple of the requested type's group size, or when it is a routing
+// gate (mlp.gate.weight).
+//
+// nvfp4 and mxfp8 are applied uniformly. Otherwise the result is mixed precision:
+//   - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
+//   - Down projection weights: q8 (more sensitive, would be Q6 in GGML but no
+//     MLX kernel) when the last dimension fits q8's group size, otherwise the
+//     requested type
+//   - Everything else (o_proj, gate/up projections, lm_head, ...): the requested type
+func GetTensorQuantization(name string, shape []int32, quantize string) string {
// Use basic name-based check first
if !ShouldQuantize(name, "") {
- return false
+		return ""
}
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
- return false
+		return ""
}
// Skip small tensors (less than 1024 elements) - not worth quantizing
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
- return false
+		return ""
}
- // MLX quantization requires last dimension to be divisible by group size (32)
- if shape[len(shape)-1]%32 != 0 {
- return false
+	// Normalize quantization type to canonical form
+	quantNorm := normalizeQuantType(quantize)
+
+	// MLX quantization requires last dimension to be divisible by group size
+	// nvfp4: 16, q4/mxfp8: 32, q8: 64 (shared with getQuantGroupSize)
+	lastDim := shape[len(shape)-1]
+	if lastDim%int32(getQuantGroupSize(quantNorm)) != 0 {
+		return ""
}
- return true
+	// Skip routing gate weights (should stay high precision)
+	// In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
+	if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
+		return ""
+	}
+
+	// For NVFP4 or MXFP8, use the same quantization for all (no mixed precision)
+	if quantNorm == "nvfp4" || quantNorm == "mxfp8" {
+		return quantNorm
+	}
+
+	// Attention MLA weights - keep unquantized (bf16)
+	// These are highly sensitive: errors accumulate in the KV cache over time
+	// q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj
+	if strings.Contains(name, "q_a_proj") ||
+		strings.Contains(name, "q_b_proj") ||
+		strings.Contains(name, "kv_a_proj") ||
+		strings.Contains(name, "kv_b_proj") {
+		return "" // No quantization - keep bf16
+	}
+
+	// Down projection weights - use Q8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
+	// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
+	if strings.Contains(name, "down_proj") {
+		// The divisibility check above used the requested type's group size.
+		// q8 needs a multiple of 64, so a tensor that only satisfies a smaller
+		// group size must stay on the requested type; promoting it would
+		// produce a q8 tensor that MLX cannot quantize.
+		if lastDim%int32(getQuantGroupSize("q8")) != 0 {
+			return quantNorm
+		}
+		return "q8"
+	}
+
+	// All remaining weights (o_proj, gate_proj, up_proj, lm_head, and anything
+	// else that passed the checks above) use the requested quantization.
+	return quantNorm
}
// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
-// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
+// If quantize is non-empty (e.g., "q8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
@@ -330,9 +428,10 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
}
// Determine quantization type for this tensor (empty string if not quantizing)
+ // GetTensorQuantization handles mixed precision (e.g., MLA attention projections stay unquantized, down_proj is promoted to Q8)
quantizeType := ""
- if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
- quantizeType = quantize
+ if quantize != "" {
+ quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize)
}
// Store as minimal safetensors format (88 bytes header overhead)
@@ -388,6 +487,23 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
return fmt.Errorf("config.json not found in %s", modelDir)
}
+ // Create model_index.json with quantization info if quantizing
+ if quantize != "" {
+ modelIndex := map[string]any{
+ "quantization": strings.ToUpper(quantize),
+ "group_size": getQuantGroupSize(quantize),
+ }
+ indexData, err := json.MarshalIndent(modelIndex, "", " ")
+ if err != nil {
+ return fmt.Errorf("failed to marshal model_index.json: %w", err)
+ }
+ indexLayer, err := createLayer(strings.NewReader(string(indexData)), "application/vnd.ollama.image.json", "model_index.json")
+ if err != nil {
+ return fmt.Errorf("failed to create model_index.json layer: %w", err)
+ }
+ layers = append(layers, indexLayer)
+ }
+
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {
diff --git a/x/create/create_test.go b/x/create/create_test.go
index c69bb10a8..b5d0a7b34 100644
--- a/x/create/create_test.go
+++ b/x/create/create_test.go
@@ -536,41 +536,51 @@ func TestShouldQuantize(t *testing.T) {
func TestShouldQuantizeTensor(t *testing.T) {
tests := []struct {
- name string
- tensor string
- shape []int32
- want bool
+ name string
+ tensor string
+ shape []int32
+ quantize string
+ want bool
}{
// 2D tensors with sufficient size should be quantized
- {"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
- {"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
+ {"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
+ {"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
+ {"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
// Small tensors should not be quantized (< 1024 elements)
- {"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
- {"small 2D weight", "small.weight", []int32{31, 31}, false},
+ {"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
+ {"small 2D weight", "small.weight", []int32{31, 31}, "fp8", false},
// 1D tensors should not be quantized
- {"1D tensor", "layer_norm.weight", []int32{4096}, false},
+ {"1D tensor", "layer_norm.weight", []int32{4096}, "fp8", false},
// 3D+ tensors should not be quantized
- {"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
- {"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
+ {"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
+ {"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
// Embeddings should not be quantized regardless of shape
- {"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
+ {"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
// Norms should not be quantized regardless of shape
- {"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
+ {"norm 2D", "layer_norm.weight", []int32{4096, 1}, "fp8", false},
// Biases should not be quantized
- {"bias 2D", "proj.bias", []int32{4096, 1}, false},
+ {"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
+
+ // Group size divisibility tests (checked against the last dimension)
+ // fp8 is an alias for q8, which requires divisible by 64
+ {"not divisible by 64 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
+ {"divisible by 32 but not 64 fp8", "proj.weight", []int32{128, 96}, "fp8", false},
+ {"divisible by 64 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
+ // q4 and mxfp8 require divisible by 32
+ {"not divisible by 32 q4", "proj.weight", []int32{128, 48}, "q4", false},
+ {"divisible by 32 q4", "proj.weight", []int32{128, 96}, "q4", true},
+ {"divisible by 32 mxfp8", "proj.weight", []int32{128, 96}, "mxfp8", true},
+ // NVFP4 requires divisible by 16
+ {"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
+ {"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got := ShouldQuantizeTensor(tt.tensor, tt.shape)
+ got := ShouldQuantizeTensor(tt.tensor, tt.shape, tt.quantize)
if got != tt.want {
- t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
+ t.Errorf("ShouldQuantizeTensor(%q, %v, %q) = %v, want %v", tt.tensor, tt.shape, tt.quantize, got, tt.want)
}
})
}
}
@@ -741,7 +751,7 @@ func TestCreateImageGenModel_WithQuantize(t *testing.T) {
progressFn := func(status string) {}
- err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
+ err := CreateImageGenModel("test-imagegen", dir, "q8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}
diff --git a/x/create/imagegen.go b/x/create/imagegen.go
index 595a40417..0da0e764a 100644
--- a/x/create/imagegen.go
+++ b/x/create/imagegen.go
@@ -15,15 +15,15 @@ import (
// CreateImageGenModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
-// Supported quantization types: fp8 (or empty for no quantization).
+// Supported quantization types: q4, q8, nvfp4, mxfp8 (or empty for no quantization).
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Validate quantization type
switch quantize {
- case "", "fp4", "fp8":
+ case "", "q4", "q8", "nvfp4", "mxfp8":
// valid
default:
- return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
+ return fmt.Errorf("unsupported quantization type %q: supported types are q4, q8, nvfp4, mxfp8", quantize)
}
var layers []LayerInfo
@@ -89,7 +89,7 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
// Determine quantization type for this tensor (empty string if not quantizing)
quantizeType := ""
- if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
+ if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape, quantize) {
quantizeType = quantize
}
@@ -213,10 +213,18 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
}
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
-// MLX requires the last dimension to be divisible by the group size (32).
-func canQuantizeShape(shape []int32) bool {
+// MLX requires the last dimension to be divisible by the group size, which
+// depends on the quantization type (nvfp4: 16, q4/mxfp8: 32, q8: 64).
+// The size comes from getQuantGroupSize so that this check and the
+// safetensors path share a single table of group sizes.
+// Shapes with fewer than two dimensions are never quantizable.
+func canQuantizeShape(shape []int32, quantize string) bool {
if len(shape) < 2 {
return false
}
- return shape[len(shape)-1]%32 == 0
+	return shape[len(shape)-1]%int32(getQuantGroupSize(quantize)) == 0
}
diff --git a/x/imagegen/cache/cache.go b/x/imagegen/cache/cache.go
index 4faa2412e..8a25193cd 100644
--- a/x/imagegen/cache/cache.go
+++ b/x/imagegen/cache/cache.go
@@ -9,6 +9,7 @@ type Cache interface {
Offset() int
Len() int
State() []*mlx.Array
+ Reset()
}
type KVCache struct {
@@ -63,6 +64,13 @@ func (c *KVCache) State() []*mlx.Array {
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
+// Reset returns the cache to its empty state so it can be reused for a new
+// generation session: the stored keys and values are dropped and the offset
+// goes back to zero.
+func (c *KVCache) Reset() {
+	c.keys, c.values = nil, nil
+	c.offset = 0
+}
+
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
keys, values *mlx.Array
@@ -154,3 +162,11 @@ func (c *RotatingKVCache) State() []*mlx.Array {
func (c *RotatingKVCache) Offset() int { return c.offset }
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
+
+// Reset returns the cache to its empty state so it can be reused for a new
+// generation session: the stored keys and values are dropped, and both the
+// offset and the idx field go back to zero.
+func (c *RotatingKVCache) Reset() {
+	c.keys, c.values = nil, nil
+	c.offset, c.idx = 0, 0
+}
diff --git a/x/imagegen/manifest.go b/x/imagegen/manifest.go
index 7692d2b09..da55cbe81 100644
--- a/x/imagegen/manifest.go
+++ b/x/imagegen/manifest.go
@@ -102,14 +102,17 @@ func (m *ModelManifest) BlobPath(digest string) string {
return filepath.Join(m.BlobDir, blobName)
}
-// GetTensorLayers returns all tensor layers for a given component.
-// Component should be "text_encoder", "transformer", or "vae".
-// Tensor names are path-style: "component/tensor_name" (e.g., "text_encoder/model.embed_tokens.weight").
+// GetTensorLayers returns tensor layers, optionally filtered by component.
+// If component is empty, returns all tensor layers (for LLM models).
+// If component is specified (e.g., "text_encoder", "transformer", "vae"),
+// returns only layers with that prefix.
func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer {
- prefix := component + "/"
var layers []ManifestLayer
for _, layer := range m.Manifest.Layers {
- if layer.MediaType == "application/vnd.ollama.image.tensor" && strings.HasPrefix(layer.Name, prefix) {
+ if layer.MediaType != "application/vnd.ollama.image.tensor" {
+ continue
+ }
+ if component == "" || strings.HasPrefix(layer.Name, component+"/") {
layers = append(layers, layer)
}
}
@@ -206,7 +209,7 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
- info.Quantization = "FP8"
+ info.Quantization = "Q8"
break
}
}
diff --git a/x/imagegen/mlx/mlx.go b/x/imagegen/mlx/mlx.go
index 2b31aadfb..2232e482b 100644
--- a/x/imagegen/mlx/mlx.go
+++ b/x/imagegen/mlx/mlx.go
@@ -991,6 +991,19 @@ func Concat(a, b *Array, axis int) *Array {
return Concatenate([]*Array{a, b}, axis)
}
+// Stack stacks arrays along a new axis at the given axis index.
+// There is no default axis; callers pass 0 to stack along the leading axis.
+// arrays must contain at least one element: the first handle is addressed
+// directly when the MLX vector is built, so an empty slice panics.
+func Stack(arrays []*Array, axis int) *Array {
+ handles := make([]C.mlx_array, len(arrays))
+ for i, arr := range arrays {
+ handles[i] = arr.c
+ }
+ vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
+ res := C.mlx_array_new()
+ C.mlx_stack_axis(&res, vec, C.int(axis), C.default_stream())
+ // The temporary vector is only needed for the call above.
+ C.mlx_vector_array_free(vec)
+ return newArray(res)
+}
// Slice slices the array
func Slice(a *Array, start, stop []int32) *Array {
n := len(start)
diff --git a/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go b/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go
new file mode 100644
index 000000000..caebbe361
--- /dev/null
+++ b/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go
@@ -0,0 +1,840 @@
+//go:build mlx
+
+// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
+// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
+package glm4_moe_lite
+
+import (
+ "encoding/json"
+ "fmt"
+ "math"
+
+ "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/x/imagegen/cache"
+ "github.com/ollama/ollama/x/imagegen/mlx"
+ "github.com/ollama/ollama/x/imagegen/nn"
+ "github.com/ollama/ollama/x/imagegen/safetensors"
+ "github.com/ollama/ollama/x/imagegen/tokenizer"
+)
+
+// RopeScaling holds RoPE scaling configuration, decoded from the
+// "rope_scaling" object of the model config.
+// computeScale only applies its mscale adjustment when MscaleAllDim > 0 and
+// Factor > 1.
+type RopeScaling struct {
+ Factor float32 `json:"factor"`
+ MscaleAllDim float32 `json:"mscale_all_dim"`
+}
+
+// Config holds GLM4-MoE-Lite model configuration.
+// Fields with a JSON name are decoded from config.json by LoadFromManifest;
+// fields tagged `json:"-"` are filled in during loading rather than decoded.
+type Config struct {
+ HiddenSize int32 `json:"hidden_size"`
+ NumHiddenLayers int32 `json:"num_hidden_layers"`
+ IntermediateSize int32 `json:"intermediate_size"`
+ MoEIntermediateSize int32 `json:"moe_intermediate_size"`
+ NumAttentionHeads int32 `json:"num_attention_heads"`
+ NumKeyValueHeads int32 `json:"num_key_value_heads"`
+ VocabSize int32 `json:"vocab_size"`
+ RMSNormEps float32 `json:"rms_norm_eps"`
+ RopeTheta float32 `json:"rope_theta"`
+ MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
+ AttentionBias bool `json:"attention_bias"`
+
+ // MLA (Multi-head Latent Attention) parameters
+ QLoraRank int32 `json:"q_lora_rank"`
+ KVLoraRank int32 `json:"kv_lora_rank"`
+ QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
+ QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
+ VHeadDim int32 `json:"v_head_dim"`
+
+ // MoE parameters
+ // NOTE: NGroup and TopKGroup are decoded but MoEGate.Forward in this file
+ // does not consult them (it performs a single global top-k selection).
+ NRoutedExperts int32 `json:"n_routed_experts"`
+ NSharedExperts int32 `json:"n_shared_experts"`
+ NumExpertsPerTok int32 `json:"num_experts_per_tok"`
+ RoutedScalingFactor float32 `json:"routed_scaling_factor"`
+ NormTopKProb bool `json:"norm_topk_prob"`
+ FirstKDenseReplace int32 `json:"first_k_dense_replace"`
+ NGroup int32 `json:"n_group"`
+ TopKGroup int32 `json:"topk_group"`
+
+ // RoPE scaling
+ RopeScaling *RopeScaling `json:"rope_scaling"`
+
+ // Quantization parameters (set during load based on model quantization)
+ // QuantBits and QuantMode are set by LoadFromManifest; QuantGroupSize is
+ // set by loadExpertWeight the first time a quantized expert weight is seen.
+ QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
+ QuantBits int `json:"-"` // Bits per weight (4 or 8)
+ QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
+
+ // Computed fields
+ QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
+ Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
+}
+
+// MLAAttention implements Multi-head Latent Attention with absorption.
+// This uses absorbed MLA which operates in latent space for reduced KV cache.
+//
+// Fields carrying a weight path are filled by safetensors.LoadModule; EmbedQ
+// and UnembedOut are tagged "-" and are assigned by LoadFromManifest from the
+// result of sanitizeMLAWeights.
+type MLAAttention struct {
+ // Low-rank query projections
+ QAProj nn.LinearLayer `weight:"self_attn.q_a_proj"`
+ QALayerNorm *nn.RMSNorm `weight:"self_attn.q_a_layernorm"`
+ QBProj nn.LinearLayer `weight:"self_attn.q_b_proj"`
+
+ // Low-rank KV projections (with shared rope component)
+ KVAProjWithMQA nn.LinearLayer `weight:"self_attn.kv_a_proj_with_mqa"`
+ KVALayerNorm *nn.RMSNorm `weight:"self_attn.kv_a_layernorm"`
+
+ // Absorbed MLA projections (derived from kv_b_proj)
+ // EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
+ // UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
+ EmbedQ *nn.MultiLinear `weight:"-"`
+ UnembedOut *nn.MultiLinear `weight:"-"`
+
+ // Output projection
+ OProj nn.LinearLayer `weight:"self_attn.o_proj"`
+}
+
+// Forward computes absorbed MLA attention output.
+// This operates in latent space for reduced KV cache memory.
+//
+// x is the block input (the callers in this file pass the output of the input
+// layer norm), c is this layer's cache and may be nil (then the RoPE offset is
+// 0 and nothing is cached), B and L are the first two dimensions of the token
+// tensor, and cfg supplies head dimensions, RoPE theta, epsilon and the
+// attention scale. The result is the output of o_proj.
+func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
+ // Query path: q_a_proj -> layernorm -> q_b_proj
+ q := a.QAProj.Forward(x)
+ q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
+ q = a.QBProj.Forward(q)
+
+ // Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim]
+ q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
+ q = mlx.Transpose(q, 0, 2, 1, 3)
+
+ // Split Q into nope and rope parts
+ qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
+ qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
+
+ // KV path: get compressed KV and k_pe
+ compressedKV := a.KVAProjWithMQA.Forward(x)
+
+ // Split into compressed_kv and k_pe (shared rope component)
+ kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
+ kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
+
+ // k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim]
+ kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
+ kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
+
+ // Apply layernorm to get kv latent representation
+ kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
+ // kvLatent: [B, L, kv_lora_rank] -> [B, 1, L, kv_lora_rank] for broadcasting
+ kvLatent = mlx.ExpandDims(kvLatent, 1)
+
+ // Apply RoPE to the rope parts
+ // The position offset comes from the cache so that new tokens continue
+ // after the tokens already cached; without a cache it stays 0.
+ offset := 0
+ if c != nil {
+ offset = c.Offset()
+ }
+ qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
+ kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
+
+ // ABSORBED MLA: project q_nope to latent space
+ // qNope: [B, num_heads, L, qk_nope_head_dim]
+ // EmbedQ: [num_heads, kv_lora_rank, qk_nope_head_dim]
+ // Result: [B, num_heads, L, kv_lora_rank]
+ qLatent := a.EmbedQ.Forward(qNope)
+
+ // Keys = concat(kvLatent, kPE)
+ // kvLatent: [B, 1, L, kv_lora_rank]
+ // kPE: [B, 1, L, qk_rope_head_dim]
+ // keys: [B, 1, L, kv_lora_rank + qk_rope_head_dim]
+ keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)
+
+ // Cache the smaller latent representation
+ // We cache keys (latent + rope) and use empty values since values are derived from keys
+ cachedL := L
+ if c != nil {
+ // Create placeholder values with 0 dims for cache (we don't actually use cached values)
+ placeholderValues := mlx.Zeros([]int32{B, 1, L, 0}, mlx.DtypeFloat32)
+ keys, _ = c.Update(keys, placeholderValues, int(L))
+ // After the update, keys covers every cached position, not just the new L.
+ cachedL = int32(keys.Shape()[2])
+ }
+
+ // Values are the first kv_lora_rank dims of keys (slice off rope part)
+ values := mlx.Slice(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})
+
+ // Queries = concat(qLatent, qPE)
+ // qLatent: [B, num_heads, L, kv_lora_rank]
+ // qPE: [B, num_heads, L, qk_rope_head_dim]
+ // queries: [B, num_heads, L, kv_lora_rank + qk_rope_head_dim]
+ queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
+
+ // Attention in latent space
+ // queries: [B, num_heads, L, kv_lora_rank + rope_dim]
+ // keys: [B, 1, cachedL, kv_lora_rank + rope_dim]
+ // values: [B, 1, cachedL, kv_lora_rank]
+ // The final argument is true only when more than one new token is processed
+ // (presumably the causal-mask switch; confirm against mlx.ScaledDotProductAttention).
+ out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1)
+
+ // ABSORBED MLA: unembed from latent space
+ // out: [B, num_heads, L, kv_lora_rank]
+ // UnembedOut: [num_heads, v_head_dim, kv_lora_rank]
+ // Result: [B, num_heads, L, v_head_dim]
+ out = a.UnembedOut.Forward(out)
+
+ // Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim]
+ out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
+
+ return a.OProj.Forward(out)
+}
+
+// DenseMLP is the gated feed-forward network used by dense layers:
+// down_proj(SiLU(gate_proj(x)) * up_proj(x)).
+type DenseMLP struct {
+	GateProj nn.LinearLayer `weight:"mlp.gate_proj"`
+	UpProj   nn.LinearLayer `weight:"mlp.up_proj"`
+	DownProj nn.LinearLayer `weight:"mlp.down_proj"`
+}
+
+// Forward runs x through the gated MLP and returns the down-projected result.
+func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
+	activated := mlx.SiLU(m.GateProj.Forward(x))
+	gated := mlx.Mul(activated, m.UpProj.Forward(x))
+	return m.DownProj.Forward(gated)
+}
+
+// MoEGate is the router that picks which experts handle each token.
+type MoEGate struct {
+	Gate                 nn.LinearLayer `weight:"mlp.gate"`
+	EScoreCorrectionBias *mlx.Array     `weight:"mlp.gate.e_score_correction_bias,optional"`
+}
+
+// Forward returns, for every token, the indices of the selected experts and
+// the routing weight assigned to each of them.
+//
+// Selection ranks experts by sigmoid score plus the optional correction bias,
+// while the returned weights are taken from the unbiased sigmoid scores.
+// A single global top-k is performed; cfg.NGroup and cfg.TopKGroup are not
+// consulted (equivalent to n_group=1).
+func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
+	k := cfg.NumExpertsPerTok
+
+	// The linear layer handles both quantized and non-quantized gate weights.
+	probs := mlx.Sigmoid(g.Gate.Forward(x))
+
+	// The bias only influences which experts are chosen, not their weights.
+	ranking := probs
+	if bias := g.EScoreCorrectionBias; bias != nil {
+		ranking = mlx.Add(probs, bias)
+	}
+
+	// Partitioning the negated scores puts the k largest scores first.
+	partitioned := mlx.Argpartition(mlx.Neg(ranking), int(k)-1, -1)
+	dims := partitioned.Shape()
+	selected := mlx.Slice(partitioned, []int32{0, 0, 0}, []int32{dims[0], dims[1], k})
+
+	weights := mlx.TakeAlongAxis(probs, selected, -1)
+	if cfg.NormTopKProb && k > 1 {
+		weights = mlx.Div(weights, mlx.Sum(weights, -1, true))
+	}
+
+	return selected, mlx.MulScalar(weights, cfg.RoutedScalingFactor)
+}
+
+// SwitchMLP implements the MoE expert computation using stacked weights
+// Note: No weight tags - these are populated manually by stacking expert weights
+// (LoadFromManifest fills either the dequantized or the quantized group,
+// depending on UseQuantized).
+type SwitchMLP struct {
+ // Dequantized weights (used when GatherQMM not available)
+ GateWeight *mlx.Array
+ UpWeight *mlx.Array
+ DownWeight *mlx.Array
+
+ // Quantized weights (used with GatherQMM for 4/8-bit affine)
+ GateWeightQ, GateScales, GateBiases *mlx.Array
+ UpWeightQ, UpScales, UpBiases *mlx.Array
+ DownWeightQ, DownScales, DownBiases *mlx.Array
+
+ // Quantization bits per projection (supports mixed precision Q4/Q8)
+ GateBits int
+ UpBits int
+ DownBits int
+
+ // Quantization group size per projection
+ // NOTE(review): in this file the values are copied from the stacked expert
+ // weights, which take them from model-level metadata via getQuantParams —
+ // no shape-based detection is visible here.
+ GateGroupSize int
+ UpGroupSize int
+ DownGroupSize int
+
+ // If true, use GatherQMM with quantized weights
+ UseQuantized bool
+}
+
+// Forward applies the switched expert MLP
+//
+// x is the [B, L, D] layer input and indices holds, per token, the topK expert
+// ids chosen by the gate ([B, L, topK]). The result has shape
+// [B, L, topK, hidden_size]: one expert output per selected expert, which the
+// caller is expected to combine with the routing scores.
+func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
+ shape := x.Shape()
+ B, L := shape[0], shape[1]
+ topK := cfg.NumExpertsPerTok
+
+ // Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D]
+ xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
+
+ // Flatten for gather_mm: [B*L, 1, 1, D]
+ xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
+
+ // Flatten indices: [B, L, topK] -> [B*L, topK]
+ idxFlat := mlx.Reshape(indices, B*L, topK)
+
+ // Sort for efficient gather (when we have many tokens)
+ // 64 tokens is the cutoff below which the sort/unsort overhead is skipped.
+ doSort := B*L >= 64
+ var invOrder *mlx.Array
+ n := B * L * topK
+
+ if doSort {
+ idxAll := mlx.Flatten(idxFlat)
+ order := mlx.Argsort(idxAll, 0)
+ // Sorting the permutation itself yields its inverse, used to unsort later.
+ invOrder = mlx.Argsort(order, 0)
+ // Reorder x based on sorted indices
+ // order indexes the flattened (token, slot) pairs, so dividing by topK
+ // recovers the token row each sorted entry came from.
+ xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
+ idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
+ }
+
+ var gate, up, hidden, down *mlx.Array
+
+ if s.UseQuantized {
+ // Use GatherQMM for quantized weights (faster, keeps weights quantized)
+ // Each projection may have different bits and group sizes (mixed precision: Q4 for gate/up, Q8 for down)
+ gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
+ nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
+ up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
+ nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
+
+ hidden = mlx.Mul(mlx.SiLU(gate), up)
+
+ down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
+ nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
+ } else {
+ // Use GatherMM for dequantized/non-quantized weights
+ // Stacked weights are transposed on their last two axes before the matmul.
+ gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
+ up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
+
+ hidden = mlx.Mul(mlx.SiLU(gate), up)
+
+ down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
+ }
+
+ // Unsort if we sorted
+ if doSort {
+ // Drop the two singleton axes, restore the original (token, slot) order,
+ // then regroup the topK slots of each token.
+ down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
+ } else {
+ down = mlx.Squeeze(down, 2)
+ }
+
+ return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
+}
+
+// SharedExperts is the gated MLP that every token passes through in addition
+// to its routed experts.
+type SharedExperts struct {
+	GateProj nn.LinearLayer `weight:"mlp.shared_experts.gate_proj"`
+	UpProj   nn.LinearLayer `weight:"mlp.shared_experts.up_proj"`
+	DownProj nn.LinearLayer `weight:"mlp.shared_experts.down_proj"`
+}
+
+// Forward computes down_proj(SiLU(gate_proj(x)) * up_proj(x)).
+func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
+	activated := mlx.SiLU(s.GateProj.Forward(x))
+	gated := mlx.Mul(activated, s.UpProj.Forward(x))
+	return s.DownProj.Forward(gated)
+}
+
+// MoE is the complete Mixture of Experts layer: a router, the routed experts
+// and, optionally, shared experts.
+type MoE struct {
+	Gate          *MoEGate
+	SwitchMLP     *SwitchMLP
+	SharedExperts *SharedExperts
+}
+
+// Forward routes x through the selected experts, combines their outputs using
+// the routing weights, and adds the shared-expert output when one is configured.
+func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
+	dims := x.Shape()
+	batch, seqLen := dims[0], dims[1]
+
+	experts, weights := m.Gate.Forward(x, cfg)
+	routed := m.SwitchMLP.Forward(x, experts, cfg)
+
+	// [B, L, topK, D] * [B, L, topK, 1], then reduce over the topK axis.
+	weighted := mlx.Mul(routed, mlx.ExpandDims(weights, -1))
+	mixed := mlx.Sum(weighted, 2, false)
+
+	if shared := m.SharedExperts; shared != nil {
+		mixed = mlx.Add(mixed, shared.Forward(x))
+	}
+
+	return mlx.Reshape(mixed, batch, seqLen, cfg.HiddenSize)
+}
+
+// DenseBlock is a transformer block with a dense MLP, used for the first
+// first_k_dense_replace layers.
+type DenseBlock struct {
+	Attention              *MLAAttention
+	MLP                    *DenseMLP
+	InputLayerNorm         *nn.RMSNorm `weight:"input_layernorm"`
+	PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
+}
+
+// Forward runs pre-norm attention and a pre-norm MLP, each wrapped in a
+// residual connection.
+func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
+	eps := cfg.RMSNormEps
+
+	attnOut := b.Attention.Forward(b.InputLayerNorm.Forward(x, eps), c, B, L, cfg)
+	afterAttn := mlx.Add(x, attnOut)
+
+	mlpOut := b.MLP.Forward(b.PostAttentionLayerNorm.Forward(afterAttn, eps))
+	return mlx.Add(afterAttn, mlpOut)
+}
+
+// MoEBlock is a transformer block whose feed-forward part is a MoE layer.
+type MoEBlock struct {
+	Attention              *MLAAttention
+	MoE                    *MoE
+	InputLayerNorm         *nn.RMSNorm `weight:"input_layernorm"`
+	PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
+}
+
+// Forward runs pre-norm attention and a pre-norm MoE layer, each wrapped in a
+// residual connection.
+func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
+	eps := cfg.RMSNormEps
+
+	attnOut := b.Attention.Forward(b.InputLayerNorm.Forward(x, eps), c, B, L, cfg)
+	afterAttn := mlx.Add(x, attnOut)
+
+	moeOut := b.MoE.Forward(b.PostAttentionLayerNorm.Forward(afterAttn, eps), cfg)
+	return mlx.Add(afterAttn, moeOut)
+}
+
+// Block interface for both dense and MoE blocks
+// (*DenseBlock and *MoEBlock), so Model can hold them in one slice and run
+// them uniformly.
+type Block interface {
+ // Forward applies the block to x using layer cache c (may be nil); B and L
+ // are the first two dimensions of the token tensor.
+ Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
+}
+
+// Model represents the complete GLM4-MoE-Lite model
+// It embeds *Config, so configuration fields such as RMSNormEps and
+// MaxPositionEmbeddings are reachable directly on the model.
+type Model struct {
+ EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
+ Layers []Block `weight:"-"` // Loaded manually due to different block types
+ Norm *nn.RMSNorm `weight:"model.norm"`
+ LMHead nn.LinearLayer `weight:"lm_head"`
+
+ tok *tokenizer.Tokenizer // built by LoadFromManifest from tokenizer.json
+ *Config
+}
+
+// computeScale returns the attention scale 1/sqrt(qk_nope_head_dim + qk_rope_head_dim).
+// The full key head dimension is used to match the Ollama runner.
+//
+// When RoPE scaling is configured with mscale_all_dim > 0 and factor > 1, the
+// scale is additionally multiplied by mscale^2, where
+// mscale = 0.1*mscale_all_dim*ln(factor) + 1.
+func computeScale(cfg *Config) float32 {
+	headDim := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
+	base := float32(1.0 / math.Sqrt(float64(headDim)))
+
+	if rs := cfg.RopeScaling; rs != nil && rs.MscaleAllDim > 0 && rs.Factor > 1 {
+		mscale := 0.1*rs.MscaleAllDim*float32(math.Log(float64(rs.Factor))) + 1.0
+		base *= mscale * mscale
+	}
+	return base
+}
+
+// supportsGatherQMM reports whether a GatherQMM kernel is available for the
+// given quantization mode and bit width. Only "affine" quantization with 4 or
+// 8 bits qualifies.
+func supportsGatherQMM(mode string, bits int) bool {
+	if mode != "affine" {
+		return false
+	}
+	switch bits {
+	case 4, 8:
+		return true
+	default:
+		return false
+	}
+}
+
+// ExpertWeight holds a single expert's weight with optional quantization components.
+// loadExpertWeight leaves Scales, Biases, Bits and GroupSize at their zero
+// values whenever it returns a plain or dequantized weight.
+type ExpertWeight struct {
+ Weight *mlx.Array // Quantized weight (if quantized) or dequantized weight
+ Scales *mlx.Array // Quantization scales (nil if not quantized)
+ Biases *mlx.Array // Quantization biases (nil if not quantized or mode doesn't use biases)
+ Bits int // Quantization bits (4 or 8), 0 if not quantized
+ GroupSize int // Quantization group size, 0 if not quantized
+}
+
+// getQuantParams derives the quantization group size, bit width and mode for
+// the weight source. The values come from the source's quantization type; a
+// positive group size recorded in the metadata replaces the type's default.
+func getQuantParams(weights safetensors.WeightSource) (groupSize, bits int, mode string) {
+	quant := weights.Quantization()
+	groupSize, bits, mode = safetensors.QuantizationParams(quant)
+
+	metaGroupSize := weights.GroupSize()
+	if metaGroupSize <= 0 {
+		return groupSize, bits, mode
+	}
+	return metaGroupSize, bits, mode
+}
+
+// loadExpertWeight fetches the tensor stored at path+".weight".
+//
+// It returns nil when that tensor cannot be obtained. A weight without a
+// companion ".weight_scale" tensor is returned as-is. A quantized weight is
+// returned with its scales/biases when useQuantized is set and the model's
+// quantization is supported by GatherQMM; otherwise it is dequantized and only
+// Weight is populated.
+//
+// Side effect: the first quantized weight seen records its group size in
+// cfg.QuantGroupSize (only while that field is still 0).
+func loadExpertWeight(weights safetensors.WeightSource, path string, useQuantized bool, cfg *Config) *ExpertWeight {
+	tensor, _ := weights.GetTensor(path + ".weight")
+	if tensor == nil {
+		return nil
+	}
+
+	// A scale tensor next to the weight marks it as quantized.
+	scaleName := path + ".weight_scale"
+	if !weights.HasTensor(scaleName) {
+		return &ExpertWeight{Weight: tensor}
+	}
+
+	scales, _ := weights.GetTensor(scaleName)
+	var biases *mlx.Array
+	if biasName := path + ".weight_qbias"; weights.HasTensor(biasName) {
+		biases, _ = weights.GetTensor(biasName)
+	}
+
+	groupSize, bits, mode := getQuantParams(weights)
+	if cfg.QuantGroupSize == 0 {
+		cfg.QuantGroupSize = groupSize
+	}
+
+	if !useQuantized || !supportsGatherQMM(mode, bits) {
+		// No usable quantized kernel: fall back to a dequantized weight.
+		return &ExpertWeight{Weight: mlx.Dequantize(tensor, scales, biases, groupSize, bits, mode)}
+	}
+
+	return &ExpertWeight{
+		Weight:    tensor,
+		Scales:    scales,
+		Biases:    biases,
+		Bits:      bits,
+		GroupSize: groupSize,
+	}
+}
+
+// sanitizeMLAWeights converts a layer's kv_b_proj weight into the two per-head
+// projections used by absorbed MLA. It returns (nil, nil) when the weight
+// cannot be loaded.
+//
+// Input, kv_b_proj.weight: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
+// (dequantized first when a ".weight_scale" tensor is present).
+// Outputs:
+//   - embed_q:     [num_heads, kv_lora_rank, qk_nope_head_dim], maps q_nope into latent space
+//   - unembed_out: [num_heads, v_head_dim, kv_lora_rank], maps latent output to value space
+func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
+	base := prefix + ".self_attn.kv_b_proj"
+	kvB, err := weights.GetTensor(base + ".weight")
+	if err != nil || kvB == nil {
+		return nil, nil
+	}
+
+	if scaleName := base + ".weight_scale"; weights.HasTensor(scaleName) {
+		scales, _ := weights.GetTensor(scaleName)
+		var biases *mlx.Array
+		if biasName := base + ".weight_qbias"; weights.HasTensor(biasName) {
+			biases, _ = weights.GetTensor(biasName)
+		}
+		groupSize, bits, mode := getQuantParams(weights)
+		kvB = mlx.Dequantize(kvB, scales, biases, groupSize, bits, mode)
+	}
+
+	heads, nope, rank := cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank
+	perHead := nope + cfg.VHeadDim
+
+	// Separate the heads: [num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank].
+	kvB = mlx.Reshape(kvB, heads, perHead, rank)
+
+	// The first qk_nope_head_dim rows of each head form the key projection,
+	// the remaining v_head_dim rows the value projection.
+	keyPart := mlx.Slice(kvB, []int32{0, 0, 0}, []int32{heads, nope, rank})
+	valuePart := mlx.Slice(kvB, []int32{0, nope, 0}, []int32{heads, perHead, rank})
+
+	// embed_q is the key projection transposed to [num_heads, kv_lora_rank, qk_nope_head_dim];
+	// unembed_out is the value projection unchanged.
+	return mlx.Transpose(keyPart, 0, 2, 1), valuePart
+}
+
+// StackedExpertWeights holds stacked weights for all experts.
+// It is produced by collectAndStackExpertWeights; Weight stays nil when no
+// expert weight could be loaded for the projection.
+type StackedExpertWeights struct {
+ Weight *mlx.Array // Stacked weights [num_experts, out, in] or [num_experts, out, in_packed]
+ Scales *mlx.Array // Stacked scales (nil if not quantized)
+ Biases *mlx.Array // Stacked biases (nil if not quantized)
+ Bits int // Quantization bits (4 or 8), 0 if not quantized
+ GroupSize int // Quantization group size, 0 if not quantized
+}
+
+// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
+//
+// For every expert index in [0, numExperts) it loads
+// "<prefix>.mlp.experts.<e>.<projName>" via loadExpertWeight and stacks the
+// weights (and scales/biases, when present) along a new leading axis.
+// Experts whose weight cannot be loaded are skipped, so the leading dimension
+// of the result can be smaller than numExperts; callers that index by expert
+// id should verify it. Weight stays nil when nothing could be loaded.
+//
+// Bits and GroupSize are taken from the first expert that loads successfully,
+// not strictly from expert 0, so they are still populated when expert 0 is
+// absent.
+func collectAndStackExpertWeights(
+	weights safetensors.WeightSource,
+	prefix string,
+	projName string,
+	numExperts int32,
+	useQuantized bool,
+	cfg *Config,
+) *StackedExpertWeights {
+	var w, s, b []*mlx.Array
+	result := &StackedExpertWeights{}
+
+	for e := int32(0); e < numExperts; e++ {
+		path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
+		ew := loadExpertWeight(weights, path, useQuantized, cfg)
+		if ew == nil {
+			continue
+		}
+		if len(w) == 0 {
+			// First expert actually found: record its quantization layout.
+			result.Bits = ew.Bits
+			result.GroupSize = ew.GroupSize
+		}
+		w = append(w, ew.Weight)
+		if ew.Scales != nil {
+			s = append(s, ew.Scales)
+		}
+		if ew.Biases != nil {
+			b = append(b, ew.Biases)
+		}
+	}
+
+	if len(w) == 0 {
+		return result
+	}
+
+	result.Weight = mlx.Stack(w, 0)
+	if len(s) > 0 {
+		result.Scales = mlx.Stack(s, 0)
+	}
+	if len(b) > 0 {
+		result.Biases = mlx.Stack(b, 0)
+	}
+	return result
+}
+
+// sanitizeExpertWeights builds the stacked gate_proj, up_proj and down_proj
+// tensors for one MoE layer, in that order.
+// With useQuantized set and a GatherQMM-capable quantization, the stacks keep
+// their quantized components; otherwise they hold dequantized weights with nil
+// scales/biases. Each stack carries its own Bits and GroupSize, as reported by
+// collectAndStackExpertWeights.
+func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
+	stack := func(projName string) *StackedExpertWeights {
+		return collectAndStackExpertWeights(weights, prefix, projName, numExperts, useQuantized, cfg)
+	}
+	// Calls in a return list are evaluated left to right: gate, up, down.
+	return stack("gate_proj"), stack("up_proj"), stack("down_proj")
+}
+
+// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
+//
+// It parses config.json, loads all tensor blobs, builds the tokenizer from
+// tokenizer.json (plus generation_config.json and tokenizer_config.json when
+// they can be read), and assembles every layer: layers below
+// first_k_dense_replace become dense blocks, the rest become MoE blocks with
+// stacked expert weights.
+//
+// Missing kv_b_proj weights and missing or incomplete routed expert weights
+// are reported as errors here, rather than leaving nil tensors in the model
+// that would only fail once Forward runs.
+func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
+	// Read config from manifest
+	configData, err := manifest.ReadConfig("config.json")
+	if err != nil {
+		return nil, fmt.Errorf("load config: %w", err)
+	}
+
+	var cfg Config
+	if err := json.Unmarshal(configData, &cfg); err != nil {
+		return nil, fmt.Errorf("parse config: %w", err)
+	}
+
+	// Compute derived fields
+	cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
+	cfg.Scale = computeScale(&cfg)
+
+	// Load weights from manifest blobs
+	weights, err := imagegen.LoadWeightsFromManifest(manifest, "")
+	if err != nil {
+		return nil, fmt.Errorf("load weights: %w", err)
+	}
+
+	if err := weights.Load(0); err != nil {
+		return nil, fmt.Errorf("load weight data: %w", err)
+	}
+
+	// Set up quantization parameters (only if model is actually quantized).
+	// cfg.QuantGroupSize is filled in later by loadExpertWeight.
+	useQuantized := false
+	if quantization := weights.Quantization(); quantization != "" {
+		_, cfg.QuantBits, cfg.QuantMode = safetensors.QuantizationParams(quantization)
+		useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
+	}
+
+	// Load tokenizer from manifest with config files for EOS token detection
+	tokData, err := manifest.ReadConfig("tokenizer.json")
+	if err != nil {
+		return nil, fmt.Errorf("load tokenizer config: %w", err)
+	}
+
+	// Build tokenizer config with companion files for EOS/BOS token loading
+	tokConfig := &tokenizer.TokenizerConfig{
+		ConfigJSON: configData, // Already loaded above, contains eos_token_id
+	}
+
+	// The companion files are optional, so read errors are deliberately ignored.
+	// generation_config.json is the preferred source for EOS.
+	if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
+		tokConfig.GenerationConfigJSON = genConfigData
+	}
+	if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
+		tokConfig.TokenizerConfigJSON = tokConfigData
+	}
+
+	tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
+	if err != nil {
+		return nil, fmt.Errorf("parse tokenizer: %w", err)
+	}
+
+	m := &Model{
+		Layers: make([]Block, cfg.NumHiddenLayers),
+		Config: &cfg,
+		tok:    tok,
+	}
+
+	// Load embedding, norm, and lm_head
+	if err := safetensors.LoadModule(m, weights, ""); err != nil {
+		return nil, fmt.Errorf("load top-level modules: %w", err)
+	}
+
+	// Load layers manually due to different block types
+	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
+		prefix := fmt.Sprintf("model.layers.%d", i)
+
+		// Load attention (same for both block types)
+		attn := &MLAAttention{}
+		if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
+			return nil, fmt.Errorf("layer %d attention: %w", i, err)
+		}
+
+		// Sanitize MLA weights for absorbed attention. A nil result means
+		// kv_b_proj could not be loaded; fail now instead of inside Forward.
+		embedQ, unembedOut := sanitizeMLAWeights(weights, prefix, &cfg)
+		if embedQ == nil || unembedOut == nil {
+			return nil, fmt.Errorf("layer %d attention: missing %s.self_attn.kv_b_proj.weight", i, prefix)
+		}
+		attn.EmbedQ = nn.NewMultiLinear(embedQ)
+		attn.UnembedOut = nn.NewMultiLinear(unembedOut)
+
+		if i < cfg.FirstKDenseReplace {
+			// Dense block
+			block := &DenseBlock{Attention: attn}
+			if err := safetensors.LoadModule(block, weights, prefix); err != nil {
+				return nil, fmt.Errorf("layer %d dense: %w", i, err)
+			}
+			m.Layers[i] = block
+			continue
+		}
+
+		// MoE block
+		block := &MoEBlock{Attention: attn}
+		if err := safetensors.LoadModule(block, weights, prefix); err != nil {
+			return nil, fmt.Errorf("layer %d moe block: %w", i, err)
+		}
+
+		// Stack expert weights (cfg is passed so the group size can be recorded)
+		gate, up, down := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts, useQuantized, &cfg)
+		stacks := []struct {
+			name  string
+			stack *StackedExpertWeights
+		}{
+			{"gate_proj", gate},
+			{"up_proj", up},
+			{"down_proj", down},
+		}
+		for _, st := range stacks {
+			if err := checkStackedExperts(st.stack, cfg.NRoutedExperts); err != nil {
+				return nil, fmt.Errorf("layer %d experts %s: %w", i, st.name, err)
+			}
+		}
+
+		block.MoE = &MoE{
+			Gate:      &MoEGate{},
+			SwitchMLP: newSwitchMLP(gate, up, down, useQuantized),
+		}
+
+		// Load gate weights
+		if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
+			return nil, fmt.Errorf("layer %d gate: %w", i, err)
+		}
+
+		// Load shared experts if present
+		if cfg.NSharedExperts > 0 {
+			block.MoE.SharedExperts = &SharedExperts{}
+			if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
+				return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
+			}
+		}
+
+		m.Layers[i] = block
+	}
+
+	// Evaluate every model array before the weight source is released.
+	mlx.Eval(mlx.Collect(m)...)
+	weights.ReleaseAll()
+
+	return m, nil
+}
+
+// checkStackedExperts verifies that a stacked expert tensor exists and that
+// its leading dimension covers all numExperts experts. Skipped experts would
+// otherwise shift expert ids and make the gather pick the wrong weights.
+func checkStackedExperts(stack *StackedExpertWeights, numExperts int32) error {
+	if stack == nil || stack.Weight == nil {
+		return fmt.Errorf("no expert weights found")
+	}
+	if got := int32(stack.Weight.Shape()[0]); got != numExperts {
+		return fmt.Errorf("found %d of %d expert weights", got, numExperts)
+	}
+	return nil
+}
+
+// newSwitchMLP wires the stacked gate/up/down expert weights into a SwitchMLP,
+// filling the quantized fields when useQuantized is set and the plain weight
+// fields otherwise.
+func newSwitchMLP(gate, up, down *StackedExpertWeights, useQuantized bool) *SwitchMLP {
+	s := &SwitchMLP{UseQuantized: useQuantized}
+	if !useQuantized {
+		s.GateWeight, s.UpWeight, s.DownWeight = gate.Weight, up.Weight, down.Weight
+		return s
+	}
+
+	s.GateWeightQ, s.GateScales, s.GateBiases = gate.Weight, gate.Scales, gate.Biases
+	s.GateBits, s.GateGroupSize = gate.Bits, gate.GroupSize
+
+	s.UpWeightQ, s.UpScales, s.UpBiases = up.Weight, up.Scales, up.Biases
+	s.UpBits, s.UpGroupSize = up.Bits, up.GroupSize
+
+	s.DownWeightQ, s.DownScales, s.DownBiases = down.Weight, down.Scales, down.Biases
+	s.DownBits, s.DownGroupSize = down.Bits, down.GroupSize
+	return s
+}
+
+// Forward runs the whole model on tokens and returns the output of lm_head.
+// tokens must have at least two dimensions; the first two are passed to every
+// layer as B and L. caches may be nil; otherwise caches[i] is handed to layer i.
+func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
+	dims := tokens.Shape()
+	B, L := dims[0], dims[1]
+
+	h := m.EmbedTokens.Forward(tokens)
+	for i := range m.Layers {
+		var layerCache cache.Cache
+		if caches != nil {
+			layerCache = caches[i]
+		}
+		h = m.Layers[i].Forward(h, layerCache, B, L, m.Config)
+	}
+
+	normed := m.Norm.Forward(h, m.RMSNormEps)
+	return m.LMHead.Forward(normed)
+}
+
+// Interface methods
+
+// NumLayers reports how many transformer layers the model holds.
+func (m *Model) NumLayers() int {
+	return len(m.Layers)
+}
+
+// MaxContextLength reports the configured max_position_embeddings.
+func (m *Model) MaxContextLength() int32 {
+	return m.Config.MaxPositionEmbeddings
+}
+
+// VocabSize reports the configured vocabulary size.
+func (m *Model) VocabSize() int32 {
+	return m.Config.VocabSize
+}
+
+// Tokenizer returns the tokenizer that was loaded together with the model.
+func (m *Model) Tokenizer() *tokenizer.Tokenizer {
+	return m.tok
+}
+
+// NewCache returns one fresh KV cache per layer.
+// maxSeqLen is accepted for interface compatibility but is not used.
+func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
+	n := len(m.Layers)
+	caches := make([]cache.Cache, 0, n)
+	for i := 0; i < n; i++ {
+		caches = append(caches, cache.NewKVCache())
+	}
+	return caches
+}
+
+// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
+// This follows the GLM-4.7 format: the prompt ends with the <think> tag so the
+// model begins its reply in reasoning mode.
+//
+// NOTE(review): the trailing <think> tag is reconstructed from this function's
+// documented contract; confirm against the upstream GLM-4.7 chat template,
+// including whether a <sop> token is expected right after [gMASK].
+func (m *Model) FormatPrompt(prompt string) string {
+	return "[gMASK]<|user|>" + prompt + "<|assistant|><think>"
+}
+
+// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
+// When think is true, the prompt ends with to enable reasoning mode.
+// When think is false, the prompt ends with to skip reasoning.
+func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
+ if think {
+ return "[gMASK]<|user|>" + prompt + "<|assistant|>"
+ }
+ return "[gMASK]<|user|>" + prompt + "<|assistant|>"
+}
+
+// NewRenderer returns a zero-valued Renderer for formatting multi-turn
+// conversations.
+func (m *Model) NewRenderer() *Renderer {
+	return new(Renderer)
+}
+
+// NewParser returns a zero-valued Parser for extracting thinking and tool
+// calls from model output.
+func (m *Model) NewParser() *Parser {
+	return new(Parser)
+}
diff --git a/x/imagegen/models/glm4_moe_lite/parser.go b/x/imagegen/models/glm4_moe_lite/parser.go
new file mode 100644
index 000000000..c81ec5a40
--- /dev/null
+++ b/x/imagegen/models/glm4_moe_lite/parser.go
@@ -0,0 +1,479 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+	"context"
+	"encoding/json"
+	"encoding/xml"
+	"fmt"
+	"log/slog"
+	"strconv"
+	"strings"
+	"unicode"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/logutil"
+)
+
+// parserState identifies which part of the model output the Parser is
+// currently consuming. The zero value is parserState_LookingForThinkingOpen,
+// so a zero-valued Parser starts there.
+type parserState int
+
+const (
+ // Skipping leading whitespace and checking for an opening thinking tag;
+ // falls through to content collection when none is present.
+ parserState_LookingForThinkingOpen parserState = iota
+ // Opening thinking tag consumed; discarding whitespace before the thinking text.
+ parserState_ThinkingStartedEatingWhitespace
+ // Emitting thinking text until the closing thinking tag is seen.
+ parserState_CollectingThinking
+ // Closing thinking tag consumed; discarding whitespace before the content.
+ parserState_ThinkingDoneEatingWhitespace
+ // Emitting regular content until a tool-call opening tag is seen.
+ parserState_CollectingContent
+ // Tool-call opening tag consumed; discarding whitespace before the call body.
+ parserState_ToolStartedEatingWhitespace
+ // Buffering the raw tool call until the tool-call closing tag is seen.
+ parserState_CollectingToolContent
+)
+
+// Tags that delimit reasoning and tool-call sections in GLM-4 output.
+// They must be non-empty: strings.Contains(s, "") is always true, so an empty
+// tag makes every state in the parser fire immediately.
+const (
+	thinkingOpenTag  = "<think>"
+	thinkingCloseTag = "</think>"
+	toolOpenTag      = "<tool_call>"
+	toolCloseTag     = "</tool_call>"
+)
+
+// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
+// GLM-4's prompt ends with <think> when thinking is enabled, so the parser
+// must start in CollectingThinking state (the model outputs thinking content directly).
+type Parser struct {
+ // state is the current position in the parsing state machine.
+ state parserState
+ // buffer holds output that has been received but not yet emitted as events.
+ buffer strings.Builder
+ // tools are the tool definitions passed to Init; used to look up argument types.
+ tools []api.Tool
+}
+
+// HasToolSupport always reports true: this parser extracts GLM4 tool calls.
+func (p *Parser) HasToolSupport() bool {
+ return true
+}
+
+// HasThinkingSupport always reports true: this parser extracts GLM4 thinking content.
+func (p *Parser) HasThinkingSupport() bool {
+ return true
+}
+
+// Init records the available tools and selects the starting state from the
+// thinking setting. It returns tools unchanged. lastMessage is not consulted.
+//
+// With thinking enabled (thinkValue nil or true) the parser starts in
+// CollectingThinking, because the model output then begins with thinking
+// text rather than an opening tag. With thinking disabled the state is left
+// as it was.
+func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
+	p.tools = tools
+
+	thinkingDisabled := thinkValue != nil && !thinkValue.Bool()
+	if thinkingDisabled {
+		return tools
+	}
+
+	p.state = parserState_CollectingThinking
+	return tools
+}
+
+// parserEvent is the closed set of events produced by the state machine.
+// The unexported marker method restricts implementations to this package.
+type parserEvent interface {
+ isParserEvent()
+}
+
+// eventContent carries a chunk of regular (non-thinking) output text.
+type eventContent struct {
+ content string
+}
+
+func (eventContent) isParserEvent() {}
+
+// eventRawToolCall carries the unparsed text found between the tool-call tags.
+type eventRawToolCall struct {
+ raw string
+}
+
+func (eventRawToolCall) isParserEvent() {}
+
+// eventThinkingContent carries a chunk of thinking (reasoning) text.
+type eventThinkingContent struct {
+ content string
+}
+
+func (eventThinkingContent) isParserEvent() {}
+
+// Add appends s to the internal buffer, runs the state machine, and returns
+// the content text, thinking text and tool calls that became unambiguous.
+// Text that might still be part of a tag stays buffered for a later call.
+//
+// If a tool call cannot be parsed, the failure is logged and returned, and
+// everything else produced by this call is discarded. done is accepted but
+// not consulted by this implementation.
+func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
+	p.buffer.WriteString(s)
+
+	var (
+		visible   strings.Builder
+		reasoning strings.Builder
+		parsed    []api.ToolCall
+	)
+
+	for _, ev := range p.parseEvents() {
+		switch ev := ev.(type) {
+		case eventContent:
+			visible.WriteString(ev.content)
+		case eventThinkingContent:
+			reasoning.WriteString(ev.content)
+		case eventRawToolCall:
+			call, parseErr := parseToolCall(ev, p.tools)
+			if parseErr != nil {
+				slog.Warn("glm-4 tool call parsing failed", "error", parseErr)
+				return "", "", nil, parseErr
+			}
+			parsed = append(parsed, call)
+		}
+	}
+
+	return visible.String(), reasoning.String(), parsed, nil
+}
+
+// parseEvents calls eat repeatedly until it reports that no further progress
+// is possible with the buffered input, and returns every event produced.
+// A trace-level log line is written when at least one event was produced.
+func (p *Parser) parseEvents() []parserEvent {
+	var collected []parserEvent
+
+	for {
+		batch, more := p.eat()
+		collected = append(collected, batch...)
+		if !more {
+			break
+		}
+	}
+
+	if len(collected) == 0 {
+		return collected
+	}
+
+	slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", collected, "state", p.state, "buffer", p.buffer.String())
+	return collected
+}
+
+// eatLeadingWhitespaceAndTransitionTo drops leading whitespace from the
+// buffer. If anything is left, it is written back, the parser moves to
+// nextState, and (nil, true) is returned. If the buffer held only whitespace,
+// the buffer ends up empty, the state is unchanged, and (nil, false) is
+// returned to signal that more input is needed.
+func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
+	rest := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
+	p.buffer.Reset()
+
+	if rest != "" {
+		p.buffer.WriteString(rest)
+		p.state = nextState
+		return nil, true
+	}
+
+	// Nothing but whitespace so far; wait for more input.
+	return nil, false
+}
+
+// splitAtTag cuts the buffer at the first occurrence of tag. It returns the
+// text before the tag with trailing whitespace removed, and the text after
+// the tag (with leading whitespace removed when trimAfter is set). The buffer
+// is replaced by the returned "after" text.
+//
+// The tag must be present in the buffer; callers check for it first.
+func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
+	parts := strings.SplitN(p.buffer.String(), tag, 2)
+	head, tail := parts[0], parts[1]
+
+	if trimAfter {
+		tail = strings.TrimLeftFunc(tail, unicode.IsSpace)
+	}
+
+	p.buffer.Reset()
+	p.buffer.WriteString(tail)
+
+	return strings.TrimRightFunc(head, unicode.IsSpace), tail
+}
+
+// eat performs one step of the state machine over the current buffer.
+// It returns the events produced by that step and whether the caller should
+// call eat again: true after a state transition that may leave more to
+// process, false when the remaining buffer is ambiguous or empty and more
+// input is needed.
+func (p *Parser) eat() ([]parserEvent, bool) {
+ var events []parserEvent
+
+ switch p.state {
+ case parserState_LookingForThinkingOpen:
+ trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
+ if strings.HasPrefix(trimmed, thinkingOpenTag) {
+ // Found opening tag
+ after := strings.TrimPrefix(trimmed, thinkingOpenTag)
+ after = strings.TrimLeftFunc(after, unicode.IsSpace)
+ p.buffer.Reset()
+ p.buffer.WriteString(after)
+ if after == "" {
+ p.state = parserState_ThinkingStartedEatingWhitespace
+ } else {
+ p.state = parserState_CollectingThinking
+ }
+ return events, true
+ } else if strings.HasPrefix(thinkingOpenTag, trimmed) {
+ // Partial opening tag seen, keep accumulating
+ return events, false
+ } else if trimmed == "" {
+ // Only whitespace, keep accumulating
+ return events, false
+ } else {
+ // No thinking tag found, skip to content collection
+ p.state = parserState_CollectingContent
+ // Don't trim - we want to keep the original content
+ return events, true
+ }
+
+ case parserState_ThinkingStartedEatingWhitespace:
+ return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)
+
+ case parserState_CollectingThinking:
+ acc := p.buffer.String()
+ if strings.Contains(acc, thinkingCloseTag) {
+ // Complete closing tag: emit the thinking text before it and move on.
+ thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
+ if len(thinking) > 0 {
+ events = append(events, eventThinkingContent{content: thinking})
+ }
+ if remaining == "" {
+ p.state = parserState_ThinkingDoneEatingWhitespace
+ } else {
+ p.state = parserState_CollectingContent
+ }
+ return events, true
+ } else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
+ // Partial closing tag - withhold it along with any trailing whitespace before it
+ beforePartialTag := acc[:len(acc)-overlapLen]
+ trailingWsLen := trailingWhitespaceLen(beforePartialTag)
+ ambiguousStart := len(beforePartialTag) - trailingWsLen
+
+ unambiguous := acc[:ambiguousStart]
+ ambiguous := acc[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, eventThinkingContent{content: unambiguous})
+ }
+ return events, false
+ } else {
+ // Pure thinking content - withhold trailing whitespace (might precede closing tag)
+ whitespaceLen := trailingWhitespaceLen(acc)
+ ambiguousStart := len(acc) - whitespaceLen
+
+ unambiguous := acc[:ambiguousStart]
+ ambiguous := acc[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, eventThinkingContent{content: unambiguous})
+ }
+ return events, false
+ }
+
+ case parserState_ThinkingDoneEatingWhitespace:
+ return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)
+
+ case parserState_CollectingContent:
+ if strings.Contains(p.buffer.String(), toolOpenTag) {
+ // Complete tool-call opening tag: emit the content before it.
+ before, after := p.splitAtTag(toolOpenTag, true)
+ if len(before) > 0 {
+ events = append(events, eventContent{content: before})
+ }
+ if after == "" {
+ p.state = parserState_ToolStartedEatingWhitespace
+ } else {
+ p.state = parserState_CollectingToolContent
+ }
+ return events, true
+ } else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
+ // Partial tool-call opening tag: withhold it and any whitespace before it.
+ beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
+ trailingWsLen := trailingWhitespaceLen(beforePartialTag)
+ ambiguousStart := len(beforePartialTag) - trailingWsLen
+
+ unambiguous := p.buffer.String()[:ambiguousStart]
+ ambiguous := p.buffer.String()[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, eventContent{content: unambiguous})
+ }
+ return events, false
+ } else {
+ // Plain content: emit it, but withhold trailing whitespace, which
+ // splitAtTag would trim if a tool-call tag follows.
+ whitespaceLen := trailingWhitespaceLen(p.buffer.String())
+ ambiguousStart := len(p.buffer.String()) - whitespaceLen
+
+ unambiguous := p.buffer.String()[:ambiguousStart]
+ ambiguous := p.buffer.String()[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, eventContent{content: unambiguous})
+ }
+ return events, false
+ }
+
+ case parserState_ToolStartedEatingWhitespace:
+ return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)
+
+ case parserState_CollectingToolContent:
+ acc := p.buffer.String()
+ if strings.Contains(acc, toolCloseTag) {
+ toolContent, _ := p.splitAtTag(toolCloseTag, true)
+ if len(toolContent) == 0 {
+ slog.Warn("glm4 tool call closing tag found but no content before it")
+ }
+ // The raw call is emitted even when empty; parsing it is the caller's job.
+ events = append(events, eventRawToolCall{raw: toolContent})
+ p.state = parserState_CollectingContent
+ return events, true
+ } else {
+ // Keep accumulating - tool calls are not streamed
+ // We just wait for the closing tag
+ return events, false
+ }
+
+ default:
+ panic("unreachable")
+ }
+}
+
+// overlap returns the length of the longest suffix of s that is also a
+// prefix of tag, or 0 if there is none.
+//
+// The search runs from the longest candidate down so that a tag whose prefix
+// repeats (for example "<<x" against a buffer ending in "<<") reports the
+// full overlap; stopping at the shortest match would withhold too little and
+// let part of a tag leak out as content.
+func overlap(s, tag string) int {
+	for i := min(len(tag), len(s)); i > 0; i-- {
+		if strings.HasSuffix(s, tag[:i]) {
+			return i
+		}
+	}
+	return 0
+}
+
+// trailingWhitespaceLen reports how many bytes of Unicode whitespace sit at
+// the end of s.
+func trailingWhitespaceLen(s string) int {
+	kept := strings.TrimRightFunc(s, unicode.IsSpace)
+	return len(s[len(kept):])
+}
+
+// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing.
+// Keys and Values are collected as two separate lists; parseToolCall pairs
+// them by index and rejects calls where the counts differ.
+type ToolCallXML struct {
+ XMLName xml.Name `xml:"tool_call"`
+ Content string `xml:",chardata"` // Function name (text nodes between tags)
+ Keys []string `xml:"arg_key"` // All arg_key elements in document order
+ Values []string `xml:"arg_value"` // All arg_value elements in document order
+}
+
+// escapeContent escapes XML special characters (&, <, >) in text content
+// while copying the arg_key/arg_value opening and closing tags through
+// verbatim, so that raw model output can be handed to an XML decoder.
+//
+// The input is scanned byte by byte; bytes of multi-byte UTF-8 sequences are
+// never special characters and are copied unchanged.
+func escapeContent(s string) string {
+	// Only these tags are structural; any other '<' is literal text.
+	knownTags := [...]string{"<arg_key>", "</arg_key>", "<arg_value>", "</arg_value>"}
+	startsKnownTag := func(rest string) bool {
+		for _, tag := range knownTags {
+			if strings.HasPrefix(rest, tag) {
+				return true
+			}
+		}
+		return false
+	}
+
+	var result strings.Builder
+	result.Grow(len(s))
+	inTag := false
+
+	for i := 0; i < len(s); i++ {
+		ch := s[i]
+
+		if ch == '<' && startsKnownTag(s[i:]) {
+			inTag = true
+		}
+
+		if inTag {
+			// Copy the tag verbatim up to and including its closing '>'.
+			result.WriteByte(ch)
+			if ch == '>' {
+				inTag = false
+			}
+			continue
+		}
+
+		// Escape special characters in text content.
+		switch ch {
+		case '&':
+			result.WriteString("&amp;")
+		case '<':
+			result.WriteString("&lt;")
+		case '>':
+			result.WriteString("&gt;")
+		default:
+			result.WriteByte(ch)
+		}
+	}
+
+	return result.String()
+}
+
+// parseToolCall converts the raw text found between the tool-call tags into
+// an api.ToolCall.
+//
+// The text is escaped, wrapped in a <tool_call> root element (the element
+// ToolCallXML declares) and decoded as XML. The element's character data is
+// the function name; arg_key/arg_value elements are paired by position.
+// When tools contains a definition with the same name, each value is coerced
+// to the declared parameter type (the union of all anyOf types when anyOf is
+// used); otherwise values stay strings.
+//
+// It returns an error when the XML is malformed, the function name is empty,
+// or the numbers of keys and values differ.
+func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
+	// Escape any unescaped entities in text content.
+	escaped := escapeContent(raw.raw)
+
+	// The tags themselves were consumed by the state machine, so add the
+	// root element back to obtain a well-formed document.
+	xmlString := "<tool_call>" + escaped + "</tool_call>"
+
+	var parsed ToolCallXML
+	if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
+		return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
+	}
+
+	functionName := strings.TrimSpace(parsed.Content)
+	if functionName == "" {
+		return api.ToolCall{}, fmt.Errorf("empty function name")
+	}
+
+	// Keys and values are collected separately, so a mismatch means they
+	// cannot be paired reliably.
+	if len(parsed.Keys) != len(parsed.Values) {
+		return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
+	}
+
+	// Find the matching tool to get parameter types.
+	var matchedTool *api.Tool
+	for i := range tools {
+		if tools[i].Function.Name == functionName {
+			matchedTool = &tools[i]
+			break
+		}
+	}
+
+	toolCall := api.ToolCall{
+		Function: api.ToolCallFunction{
+			Name:      functionName,
+			Arguments: api.NewToolCallFunctionArguments(),
+		},
+	}
+
+	for i := range parsed.Keys {
+		key := strings.TrimSpace(parsed.Keys[i])
+		value := parsed.Values[i] // Not trimmed here; parseValue handles it.
+
+		// Look up the declared parameter type, if any.
+		var paramType api.PropertyType
+		if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
+			if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
+				if len(prop.AnyOf) > 0 {
+					// anyOf: collect every type from the union.
+					for _, anyOfProp := range prop.AnyOf {
+						paramType = append(paramType, anyOfProp.Type...)
+					}
+				} else {
+					paramType = prop.Type
+				}
+			}
+		}
+
+		toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
+	}
+
+	return toolCall, nil
+}
+
+// parseValue trims value and coerces it to the first type in paramType that
+// it fully matches. Recognized types are "boolean", "integer", "number",
+// "array" and "object" (the last two are decoded as JSON). If paramType is
+// empty, or no listed type matches, the trimmed string is returned.
+//
+// Numeric parsing requires the whole string to be a number. A prefix match
+// (as fmt.Sscanf would accept) would silently turn "3.14" into the integer 3
+// or "12abc" into 12.
+func parseValue(value string, paramType api.PropertyType) any {
+	value = strings.TrimSpace(value)
+
+	for _, t := range paramType {
+		switch t {
+		case "boolean":
+			switch value {
+			case "true":
+				return true
+			case "false":
+				return false
+			}
+		case "integer":
+			if i, err := strconv.ParseInt(value, 10, 64); err == nil {
+				return i
+			}
+		case "number":
+			if f, err := strconv.ParseFloat(value, 64); err == nil {
+				return f
+			}
+		case "array", "object":
+			var result any
+			if err := json.Unmarshal([]byte(value), &result); err == nil {
+				return result
+			}
+		}
+	}
+
+	// No type given or none matched: keep the value as a string.
+	return value
+}
diff --git a/x/imagegen/models/glm4_moe_lite/parser_test.go b/x/imagegen/models/glm4_moe_lite/parser_test.go
new file mode 100644
index 000000000..0ce382709
--- /dev/null
+++ b/x/imagegen/models/glm4_moe_lite/parser_test.go
@@ -0,0 +1,192 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+ "testing"
+
+ "github.com/ollama/ollama/api"
+)
+
+func TestParserThinking(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ thinkEnabled bool
+ wantContent string
+ wantThinking string
+ wantToolCalls int
+ }{
+ {
+ name: "thinking enabled - simple thinking then content",
+ input: "Let me think about this...Here is my answer.",
+ thinkEnabled: true,
+ wantThinking: "Let me think about this...",
+ wantContent: "Here is my answer.",
+ },
+ {
+ name: "thinking enabled - only thinking",
+ input: "I need to consider multiple factors...",
+ thinkEnabled: true,
+ wantThinking: "I need to consider multiple factors...",
+ wantContent: "",
+ },
+ {
+ name: "thinking disabled - direct content",
+ input: "Here is my direct answer.",
+ thinkEnabled: false,
+ wantThinking: "",
+ wantContent: "Here is my direct answer.",
+ },
+ {
+ name: "thinking with tool call",
+ input: "Let me search for that...I'll use a tool.searchquerytest",
+ thinkEnabled: true,
+ wantThinking: "Let me search for that...",
+ wantContent: "I'll use a tool.",
+ wantToolCalls: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ p := &Parser{}
+
+ var thinkValue *api.ThinkValue
+ if tt.thinkEnabled {
+ thinkValue = &api.ThinkValue{Value: true}
+ } else {
+ thinkValue = &api.ThinkValue{Value: false}
+ }
+
+ // Define tools for tool call tests
+ props := api.NewToolPropertiesMap()
+ props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
+ tools := []api.Tool{
+ {
+ Function: api.ToolFunction{
+ Name: "search",
+ Parameters: api.ToolFunctionParameters{
+ Properties: props,
+ },
+ },
+ },
+ }
+
+ p.Init(tools, nil, thinkValue)
+
+ content, thinking, calls, err := p.Add(tt.input, true)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if thinking != tt.wantThinking {
+ t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
+ }
+ if content != tt.wantContent {
+ t.Errorf("content = %q, want %q", content, tt.wantContent)
+ }
+ if len(calls) != tt.wantToolCalls {
+ t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
+ }
+ })
+ }
+}
+
+// TestParserToolCall checks that a tool call with two string arguments is
+// parsed into the expected function name and argument values when thinking
+// is disabled.
+func TestParserToolCall(t *testing.T) {
+	p := &Parser{}
+
+	props := api.NewToolPropertiesMap()
+	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
+	props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
+	tools := []api.Tool{
+		{
+			Function: api.ToolFunction{
+				Name: "get_weather",
+				Parameters: api.ToolFunctionParameters{
+					Properties: props,
+				},
+			},
+		},
+	}
+
+	// Initialize with thinking disabled
+	tv := &api.ThinkValue{Value: false}
+	p.Init(tools, nil, tv)
+
+	input := "<tool_call>get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call>"
+
+	_, _, calls, err := p.Add(input, true)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if len(calls) != 1 {
+		t.Fatalf("expected 1 tool call, got %d", len(calls))
+	}
+
+	call := calls[0]
+	if call.Function.Name != "get_weather" {
+		t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
+	}
+
+	location, ok := call.Function.Arguments.Get("location")
+	if !ok || location != "San Francisco" {
+		t.Errorf("location = %v, want %q", location, "San Francisco")
+	}
+
+	unit, ok := call.Function.Arguments.Get("unit")
+	if !ok || unit != "celsius" {
+		t.Errorf("unit = %v, want %q", unit, "celsius")
+	}
+}
+
+// TestOverlap checks overlap against every proper prefix of the closing
+// thinking tag, the complete tag, and inputs with no overlap.
+func TestOverlap(t *testing.T) {
+	tests := []struct {
+		s    string
+		tag  string
+		want int
+	}{
+		{"hello<", "</think>", 1},
+		{"hello</", "</think>", 2},
+		{"hello</t", "</think>", 3},
+		{"hello</th", "</think>", 4},
+		{"hello</thi", "</think>", 5},
+		{"hello</thin", "</think>", 6},
+		{"hello</think", "</think>", 7},
+		{"hello</think>", "</think>", 8}, // Complete tag at end returns full length
+		{"hello", "</think>", 0},
+		{"", "</think>", 0},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
+			got := overlap(tt.s, tt.tag)
+			if got != tt.want {
+				t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
+			}
+		})
+	}
+}
+
+// TestTrailingWhitespaceLen checks the trailing-whitespace byte count for
+// space runs, mixed whitespace, no whitespace, and empty or all-space input.
+func TestTrailingWhitespaceLen(t *testing.T) {
+	tests := []struct {
+		s    string
+		want int
+	}{
+		{"hello   ", 3},
+		{"hello\n\t ", 3},
+		{"hello", 0},
+		{"", 0},
+		{"   ", 3},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.s, func(t *testing.T) {
+			got := trailingWhitespaceLen(tt.s)
+			if got != tt.want {
+				t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
+			}
+		})
+	}
+}
diff --git a/x/imagegen/models/glm4_moe_lite/render.go b/x/imagegen/models/glm4_moe_lite/render.go
new file mode 100644
index 000000000..4998604bf
--- /dev/null
+++ b/x/imagegen/models/glm4_moe_lite/render.go
@@ -0,0 +1,175 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "github.com/ollama/ollama/api"
+)
+
+// Renderer renders messages for GLM4-MoE-Lite models.
+//
+// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
+//
+// 1. INTERLEAVED THINKING
+// The model thinks between tool calls and after receiving tool results.
+// This enables complex step-by-step reasoning: interpreting each tool output
+// before deciding what to do next. Thinking blocks are preserved and returned
+// with tool results to maintain reasoning continuity.
+//
+// 2. PRESERVED THINKING
+// The model retains reasoning content from previous assistant turns in context.
+// This preserves reasoning continuity across multi-turn conversations. The
+// upstream API has a "clear_thinking" parameter to control this:
+// - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
+// - clear_thinking=false: preserves <think>...</think> blocks from previous turns
+//
+// 3. TURN-LEVEL THINKING
+// Controls whether the model should reason on each turn. The upstream API
+// uses "enable_thinking" parameter:
+// - enable_thinking=true: outputs <think> to start reasoning
+// - enable_thinking=false: outputs </think> to skip reasoning
+//
+// OLLAMA DEFAULTS:
+// - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
+// - Thinking is PRESERVED by default (reasoning content from previous turns is always
+// included in <think>...</think> blocks, equivalent to clear_thinking=false)
+// - Users can disable thinking per-turn via thinkValue=false
+type Renderer struct{}
+
+// Render renders messages into the GLM4 chat format.
+//
+// When tools are supplied, a system block listing each tool as JSON inside
+// <tools>...</tools> is emitted first. Each message is then written with its
+// role marker. Assistant turns carry their reasoning in <think>...</think>
+// (or a bare </think> when there is none), followed by content and any tool
+// calls. Consecutive tool messages share one <|observation|> marker. The
+// prompt ends with <|assistant|> followed by <think> when thinking is enabled
+// (thinkValue nil or true) or </think> when it is disabled.
+//
+// An error is returned if a tool definition cannot be marshaled to JSON.
+func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
+	var sb strings.Builder
+
+	// NOTE(review): upstream GLM-4 chat templates typically place <sop> right
+	// after [gMASK]; confirm whether that token was also lost from this literal.
+	sb.WriteString("[gMASK]")
+
+	if len(tools) > 0 {
+		sb.WriteString("<|system|>\n")
+		sb.WriteString("# Tools\n\n")
+		sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
+		sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
+		sb.WriteString("<tools>\n")
+		for _, tool := range tools {
+			d, err := json.Marshal(tool)
+			if err != nil {
+				return "", fmt.Errorf("marshaling tool %q: %w", tool.Function.Name, err)
+			}
+			sb.WriteString(formatToolJSON(d))
+			sb.WriteString("\n")
+		}
+		sb.WriteString("</tools>\n\n")
+		sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
+		sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
+	}
+
+	// Thinking is on unless the caller explicitly disabled it.
+	think := thinkValue == nil || thinkValue.Bool()
+
+	for i, message := range messages {
+		switch message.Role {
+		case "user":
+			sb.WriteString("<|user|>")
+			sb.WriteString(message.Content)
+		case "assistant":
+			sb.WriteString("<|assistant|>")
+			if message.Thinking != "" {
+				sb.WriteString("<think>" + message.Thinking + "</think>")
+			} else {
+				// No reasoning recorded for this turn: close the block immediately.
+				sb.WriteString("</think>")
+			}
+			if message.Content != "" {
+				sb.WriteString(message.Content)
+			}
+			for _, toolCall := range message.ToolCalls {
+				sb.WriteString("<tool_call>" + toolCall.Function.Name)
+				sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
+				sb.WriteString("</tool_call>")
+			}
+		case "tool":
+			// Only the first of a run of tool messages opens the observation turn.
+			if i == 0 || messages[i-1].Role != "tool" {
+				sb.WriteString("<|observation|>")
+			}
+			sb.WriteString("<tool_response>")
+			sb.WriteString(message.Content)
+			sb.WriteString("</tool_response>")
+		case "system":
+			sb.WriteString("<|system|>")
+			sb.WriteString(message.Content)
+		}
+	}
+
+	sb.WriteString("<|assistant|>")
+	if think {
+		sb.WriteString("<think>")
+	} else {
+		sb.WriteString("</think>")
+	}
+
+	return sb.String(), nil
+}
+
+// renderToolArguments converts tool call arguments to GLM4 XML format: each
+// argument becomes <arg_key>name</arg_key><arg_value>value</arg_value>, in
+// the iteration order of args.All().
+//
+// String values are written as-is. Other values are JSON-encoded, falling
+// back to fmt's default formatting if encoding fails.
+func renderToolArguments(args api.ToolCallFunctionArguments) string {
+	var sb strings.Builder
+	for key, value := range args.All() {
+		sb.WriteString("<arg_key>" + key + "</arg_key>")
+
+		var valueStr string
+		if str, ok := value.(string); ok {
+			valueStr = str
+		} else if jsonBytes, err := json.Marshal(value); err == nil {
+			valueStr = string(jsonBytes)
+		} else {
+			valueStr = fmt.Sprintf("%v", value)
+		}
+
+		sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
+	}
+
+	return sb.String()
+}
+
+// formatToolJSON re-spaces compact JSON for GLM4 tool definitions: a single
+// space is inserted after every ':' and ',' that appears outside a string
+// literal. Bytes inside string literals, including backslash-escaped quotes,
+// are copied unchanged.
+func formatToolJSON(raw []byte) string {
+	var out strings.Builder
+	out.Grow(len(raw) + len(raw)/10)
+
+	var quoted, skipNext bool
+	for _, b := range raw {
+		out.WriteByte(b)
+
+		switch {
+		case !quoted:
+			switch b {
+			case '"':
+				quoted = true
+			case ':', ',':
+				out.WriteByte(' ')
+			}
+		case skipNext:
+			// This byte was escaped by a preceding backslash.
+			skipNext = false
+		case b == '\\':
+			skipNext = true
+		case b == '"':
+			quoted = false
+		}
+	}
+
+	return out.String()
+}
diff --git a/x/imagegen/models/glm4_moe_lite/render_test.go b/x/imagegen/models/glm4_moe_lite/render_test.go
new file mode 100644
index 000000000..f0d576bec
--- /dev/null
+++ b/x/imagegen/models/glm4_moe_lite/render_test.go
@@ -0,0 +1,205 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/ollama/ollama/api"
+)
+
+// TestRendererSimple checks the exact prompt for a single user message with
+// the default (enabled) thinking mode: it must end with the opening <think> tag.
+func TestRendererSimple(t *testing.T) {
+	r := &Renderer{}
+
+	messages := []api.Message{
+		{Role: "user", Content: "Hello"},
+	}
+
+	// Thinking enabled (default)
+	result, err := r.Render(messages, nil, nil)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	expected := "[gMASK]<|user|>Hello<|assistant|><think>"
+	if result != expected {
+		t.Errorf("result = %q, want %q", result, expected)
+	}
+}
+
+// TestRendererThinkingDisabled checks the exact prompt for a single user
+// message with thinking disabled: it must end with the closing </think> tag.
+func TestRendererThinkingDisabled(t *testing.T) {
+	r := &Renderer{}
+
+	messages := []api.Message{
+		{Role: "user", Content: "Hello"},
+	}
+
+	tv := &api.ThinkValue{Value: false}
+
+	result, err := r.Render(messages, nil, tv)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	expected := "[gMASK]<|user|>Hello<|assistant|></think>"
+	if result != expected {
+		t.Errorf("result = %q, want %q", result, expected)
+	}
+}
+
+// TestRendererMultiTurn checks that a previous assistant turn is rendered
+// with its reasoning wrapped in <think>...</think> ahead of its content, and
+// that the prompt ends by opening a new thinking block.
+func TestRendererMultiTurn(t *testing.T) {
+	r := &Renderer{}
+
+	messages := []api.Message{
+		{Role: "user", Content: "What is 2+2?"},
+		{Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
+		{Role: "user", Content: "And 3+3?"},
+	}
+
+	result, err := r.Render(messages, nil, nil)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// Check key parts
+	if !strings.Contains(result, "[gMASK]") {
+		t.Error("missing [gMASK] prefix")
+	}
+	if !strings.Contains(result, "<|user|>What is 2+2?") {
+		t.Error("missing first user message")
+	}
+	if !strings.Contains(result, "<|assistant|><think>Let me calculate: 2+2=4</think>4") {
+		t.Error("missing assistant message with thinking")
+	}
+	if !strings.Contains(result, "<|user|>And 3+3?") {
+		t.Error("missing second user message")
+	}
+	if !strings.HasSuffix(result, "<|assistant|><think>") {
+		// Print the whole result: slicing a fixed-length tail could panic on short output.
+		t.Errorf("should end with <|assistant|><think>, got: %q", result)
+	}
+}
+
+// TestRendererWithSystem verifies that a system message is rendered behind
+// the <|system|> role marker.
+func TestRendererWithSystem(t *testing.T) {
+	conversation := []api.Message{
+		{Role: "system", Content: "You are a helpful assistant."},
+		{Role: "user", Content: "Hello"},
+	}
+
+	renderer := &Renderer{}
+	rendered, err := renderer.Render(conversation, nil, nil)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if found := strings.Contains(rendered, "<|system|>You are a helpful assistant."); !found {
+		t.Error("missing system message")
+	}
+}
+
+// TestRendererWithTools checks that supplying tools produces the tool system
+// block: the <|system|> marker, the "# Tools" header, and the tool JSON
+// enclosed in <tools>...</tools>.
+func TestRendererWithTools(t *testing.T) {
+	r := &Renderer{}
+
+	messages := []api.Message{
+		{Role: "user", Content: "What's the weather?"},
+	}
+
+	props := api.NewToolPropertiesMap()
+	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
+	tools := []api.Tool{
+		{
+			Function: api.ToolFunction{
+				Name:        "get_weather",
+				Description: "Get the weather for a location",
+				Parameters: api.ToolFunctionParameters{
+					Type:       "object",
+					Properties: props,
+					Required:   []string{"location"},
+				},
+			},
+		},
+	}
+
+	result, err := r.Render(messages, tools, nil)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// Check for tool system prompt
+	if !strings.Contains(result, "<|system|>") {
+		t.Error("missing system tag for tools")
+	}
+	if !strings.Contains(result, "# Tools") {
+		t.Error("missing tools header")
+	}
+	if !strings.Contains(result, "<tools>") {
+		t.Error("missing tools tag")
+	}
+	if !strings.Contains(result, "get_weather") {
+		t.Error("missing tool name")
+	}
+	if !strings.Contains(result, "</tools>") {
+		t.Error("missing closing tools tag")
+	}
+}
+
+// TestRendererWithToolCalls checks that an assistant tool call is rendered
+// as <tool_call>name<arg_key>..</arg_key><arg_value>..</arg_value></tool_call>
+// and that the following tool message is rendered as a <tool_response>
+// inside an <|observation|> turn.
+func TestRendererWithToolCalls(t *testing.T) {
+	r := &Renderer{}
+
+	args := api.NewToolCallFunctionArguments()
+	args.Set("location", "San Francisco")
+
+	messages := []api.Message{
+		{Role: "user", Content: "What's the weather in SF?"},
+		{
+			Role: "assistant",
+			ToolCalls: []api.ToolCall{
+				{
+					Function: api.ToolCallFunction{
+						Name:      "get_weather",
+						Arguments: args,
+					},
+				},
+			},
+		},
+		{Role: "tool", Content: "Sunny, 72F"},
+	}
+
+	result, err := r.Render(messages, nil, nil)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if !strings.Contains(result, "<tool_call>get_weather") {
+		t.Error("missing tool call")
+	}
+	if !strings.Contains(result, "<arg_key>location</arg_key>") {
+		t.Error("missing arg_key")
+	}
+	if !strings.Contains(result, "<arg_value>San Francisco</arg_value>") {
+		t.Error("missing arg_value")
+	}
+	if !strings.Contains(result, "</tool_call>") {
+		t.Error("missing tool call closing tag")
+	}
+	if !strings.Contains(result, "<|observation|>") {
+		t.Error("missing observation tag")
+	}
+	if !strings.Contains(result, "<tool_response>Sunny, 72F</tool_response>") {
+		t.Error("missing tool response")
+	}
+}
+
+// TestFormatToolJSON verifies that formatToolJSON inserts a space after the
+// ':' and ',' separators of compact JSON.
+func TestFormatToolJSON(t *testing.T) {
+	formatted := formatToolJSON([]byte(`{"name":"test","value":123}`))
+
+	checks := []struct {
+		needle  string
+		failure string
+	}{
+		{": ", "should add space after colon"},
+		{", ", "should add space after comma"},
+	}
+	for _, c := range checks {
+		if !strings.Contains(formatted, c.needle) {
+			t.Error(c.failure)
+		}
+	}
+}
diff --git a/x/imagegen/nn/nn.go b/x/imagegen/nn/nn.go
index 65bf7fa22..d72474358 100644
--- a/x/imagegen/nn/nn.go
+++ b/x/imagegen/nn/nn.go
@@ -32,10 +32,16 @@ func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
// Quantizes the weight immediately and evaluates to break lazy dependencies.
+// Note: For modes like "nvfp4", qbiases will be nil.
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
// Eval immediately so bf16 weight can be freed
- mlx.Eval(qw, scales, qbiases)
+ // Handle modes that don't return qbiases (e.g., nvfp4)
+ if qbiases != nil {
+ mlx.Eval(qw, scales, qbiases)
+ } else {
+ mlx.Eval(qw, scales)
+ }
return &QuantizedLinear{
Weight: qw,
Scales: scales,
@@ -77,10 +83,13 @@ func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear
// QuantizedLinear applies an affine transformation using quantized weights.
// Equivalent to mlx.nn.QuantizedLinear.
+// Supports multiple quantization modes:
+// - "affine": scale + zero-point bias (QBiases required)
+// - "nvfp4": NVIDIA FP4 with E4M3 scales (QBiases nil)
type QuantizedLinear struct {
Weight *mlx.Array // Quantized weight data
Scales *mlx.Array // Scale factors for dequantization
- QBiases *mlx.Array // Quantization biases (NOT layer bias)
+ QBiases *mlx.Array // Quantization biases (NOT layer bias), nil for nvfp4
Bias *mlx.Array // Layer bias [output_dims] or nil
GroupSize int
Bits int
@@ -220,3 +229,32 @@ func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
}
return out
}
+
+// MultiLinearLayer is an interface for per-head linear layers.
+// This allows swapping between MultiLinear (bf16) and pre-dequantized weights.
+type MultiLinearLayer interface {
+ // Forward applies the layer to x and returns the projected array.
+ Forward(x *mlx.Array) *mlx.Array
+}
+
+// MultiLinear performs per-head linear projections.
+// Weight shape: [num_heads, output_dims, input_dims]
+// Input shape: [B, num_heads, L, input_dims]
+// Output shape: [B, num_heads, L, output_dims]
+type MultiLinear struct {
+ // Weight is the per-head projection matrix; the struct tag names the
+ // tensor it is populated from.
+ Weight *mlx.Array `weight:"weight"`
+}
+
+// NewMultiLinear wraps weight in a MultiLinear layer. The array is stored
+// as given, without copying.
+func NewMultiLinear(weight *mlx.Array) *MultiLinear {
+	layer := new(MultiLinear)
+	layer.Weight = weight
+	return layer
+}
+
+// Forward applies the per-head projection x @ weight.T.
+//
+// The last two axes of Weight are swapped so that
+// [num_heads, output_dims, input_dims] becomes
+// [num_heads, input_dims, output_dims]. Multiplying x
+// ([B, num_heads, L, input_dims]) by that yields
+// [B, num_heads, L, output_dims], with the head axis broadcast across the batch.
+func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
+	return mlx.Matmul(x, mlx.Transpose(ml.Weight, 0, 2, 1))
+}
diff --git a/x/imagegen/runner/runner.go b/x/imagegen/runner/runner.go
deleted file mode 100644
index f43276468..000000000
--- a/x/imagegen/runner/runner.go
+++ /dev/null
@@ -1,284 +0,0 @@
-//go:build mlx
-
-// Package runner provides a subprocess server for image generation.
-// It listens on a port and handles HTTP requests for image generation.
-package runner
-
-import (
- "context"
- "encoding/json"
- "flag"
- "fmt"
- "image"
- "log/slog"
- "net/http"
- "os"
- "os/signal"
- "sync"
- "syscall"
- "time"
-
- "github.com/ollama/ollama/x/imagegen"
- "github.com/ollama/ollama/x/imagegen/mlx"
- "github.com/ollama/ollama/x/imagegen/models/flux2"
- "github.com/ollama/ollama/x/imagegen/models/zimage"
-)
-
-// Request is the image generation request format
-type Request struct {
- Prompt string `json:"prompt"`
- Width int32 `json:"width,omitempty"`
- Height int32 `json:"height,omitempty"`
- Steps int `json:"steps,omitempty"`
- Seed int64 `json:"seed,omitempty"`
- Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
-}
-
-// Response is streamed back for each progress update
-type Response struct {
- Content string `json:"content,omitempty"`
- Image string `json:"image,omitempty"` // Base64-encoded PNG
- Done bool `json:"done"`
- Step int `json:"step,omitempty"`
- Total int `json:"total,omitempty"`
-}
-
-// ImageModel is the interface for image generation models
-type ImageModel interface {
- GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
-}
-
-// ImageEditModel extends ImageModel with image editing/conditioning capability.
-// Models that support input images for editing should implement this interface.
-type ImageEditModel interface {
- ImageModel
- GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error)
-}
-
-// Server holds the model and handles requests
-type Server struct {
- mu sync.Mutex
- model ImageModel
- modelName string
-}
-
-// Execute is the entry point for the image runner subprocess
-func Execute(args []string) error {
- fs := flag.NewFlagSet("image-runner", flag.ExitOnError)
- modelName := fs.String("model", "", "path to image model")
- port := fs.Int("port", 0, "port to listen on")
-
- if err := fs.Parse(args); err != nil {
- return err
- }
-
- if *modelName == "" {
- return fmt.Errorf("--model is required")
- }
- if *port == 0 {
- return fmt.Errorf("--port is required")
- }
-
- err := mlx.InitMLX()
- if err != nil {
- slog.Error("unable to initialize MLX", "error", err)
- return err
- }
- slog.Info("MLX library initialized")
- slog.Info("starting image runner", "model", *modelName, "port", *port)
-
- // Detect model type and load appropriate model
- modelType := imagegen.DetectModelType(*modelName)
- slog.Info("detected model type", "type", modelType)
-
- var model ImageModel
- switch modelType {
- case "Flux2KleinPipeline":
- m := &flux2.Model{}
- if err := m.Load(*modelName); err != nil {
- return fmt.Errorf("failed to load model: %w", err)
- }
- model = m
- default:
- // Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
- m := &zimage.Model{}
- if err := m.Load(*modelName); err != nil {
- return fmt.Errorf("failed to load model: %w", err)
- }
- model = m
- }
-
- server := &Server{
- model: model,
- modelName: *modelName,
- }
-
- // Set up HTTP handlers
- mux := http.NewServeMux()
- mux.HandleFunc("/health", server.healthHandler)
- mux.HandleFunc("/completion", server.completionHandler)
-
- httpServer := &http.Server{
- Addr: fmt.Sprintf("127.0.0.1:%d", *port),
- Handler: mux,
- }
-
- // Handle shutdown
- done := make(chan struct{})
- go func() {
- sigCh := make(chan os.Signal, 1)
- signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
- <-sigCh
- slog.Info("shutting down image runner")
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- httpServer.Shutdown(ctx)
- close(done)
- }()
-
- slog.Info("image runner listening", "addr", httpServer.Addr)
- if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
- return err
- }
-
- <-done
- return nil
-}
-
-func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
-}
-
-func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodPost {
- http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
- return
- }
-
- var req Request
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
-
- // Validate and decode input images
- const maxInputImages = 2
- if len(req.Images) > maxInputImages {
- http.Error(w, fmt.Sprintf("too many input images, maximum is %d", maxInputImages), http.StatusBadRequest)
- return
- }
-
- var inputImages []image.Image
- if len(req.Images) > 0 {
- // TODO: add memory check for input images
-
- inputImages = make([]image.Image, len(req.Images))
- for i, imgBytes := range req.Images {
- img, err := imagegen.DecodeImage(imgBytes)
- if err != nil {
- http.Error(w, fmt.Sprintf("invalid image %d: %v", i, err), http.StatusBadRequest)
- return
- }
- inputImages[i] = img
- }
- slog.Info("decoded input images", "count", len(inputImages))
-
- // Default width/height to first input image dimensions, scaled to max 1024
- bounds := inputImages[0].Bounds()
- w, h := bounds.Dx(), bounds.Dy()
- if w > 1024 || h > 1024 {
- if w > h {
- h = h * 1024 / w
- w = 1024
- } else {
- w = w * 1024 / h
- h = 1024
- }
- }
- req.Width = int32(w)
- req.Height = int32(h)
- }
-
- // Serialize generation requests - MLX model may not handle concurrent generation
- s.mu.Lock()
- defer s.mu.Unlock()
-
- // Model applies its own defaults for width/height/steps
- // Only seed needs to be set here if not provided
- if req.Seed <= 0 {
- req.Seed = time.Now().UnixNano()
- }
-
- // Set up streaming response
- w.Header().Set("Content-Type", "application/x-ndjson")
- w.Header().Set("Transfer-Encoding", "chunked")
- flusher, ok := w.(http.Flusher)
- if !ok {
- http.Error(w, "streaming not supported", http.StatusInternalServerError)
- return
- }
-
- // Generate image using the common interface
- ctx := r.Context()
- enc := json.NewEncoder(w)
-
- // Progress callback streams step updates
- progress := func(step, total int) {
- resp := Response{Step: step, Total: total}
- enc.Encode(resp)
- w.Write([]byte("\n"))
- flusher.Flush()
- }
-
- // Use ImageEditModel if available and images provided, otherwise use basic ImageModel
- var img *mlx.Array
- var err error
- if len(inputImages) > 0 {
- editModel, ok := s.model.(ImageEditModel)
- if !ok {
- http.Error(w, "model does not support image editing", http.StatusBadRequest)
- return
- }
- img, err = editModel.GenerateImageWithInputs(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, inputImages, progress)
- } else {
- img, err = s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
- }
-
- if err != nil {
- // Don't send error for cancellation
- if ctx.Err() != nil {
- return
- }
- resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
- data, _ := json.Marshal(resp)
- w.Write(data)
- w.Write([]byte("\n"))
- return
- }
-
- // Encode image as base64 PNG
- imageData, err := imagegen.EncodeImageBase64(img)
- if err != nil {
- resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
- data, _ := json.Marshal(resp)
- w.Write(data)
- w.Write([]byte("\n"))
- return
- }
-
- // Free the generated image array and clean up MLX state
- img.Free()
- mlx.ClearCache()
- mlx.MetalResetPeakMemory()
-
- // Send final response with image data
- resp := Response{
- Image: imageData,
- Done: true,
- }
- data, _ := json.Marshal(resp)
- w.Write(data)
- w.Write([]byte("\n"))
- flusher.Flush()
-}
diff --git a/x/imagegen/safetensors/loader.go b/x/imagegen/safetensors/loader.go
index 7f8860b06..4c1d0a9af 100644
--- a/x/imagegen/safetensors/loader.go
+++ b/x/imagegen/safetensors/loader.go
@@ -17,17 +17,31 @@ type WeightSource interface {
GetTensor(name string) (*mlx.Array, error)
ListTensors() []string
HasTensor(name string) bool
- Quantization() string // Returns "FP4", "FP8", or ""
+ Quantization() string // Returns "NVFP4", "Q4", "Q8", or ""
+ GroupSize() int // Returns quantization group size, or 0 if not specified
}
-// quantizationParams returns groupSize, bits, mode for a quantization type.
-// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
-func quantizationParams(quantization string) (groupSize, bits int, mode string) {
+// QuantizationParams returns groupSize, bits, mode for a quantization type.
+// MLX quantization modes:
+// - "affine": scale + zero-point bias, group_size=32/64/128
+// - "nvfp4": NVIDIA FP4 with E4M3 scales, group_size=16 (no bias)
+// - "mxfp8": Microscaling (MX) FP8 with E8M0 scales, group_size=32 (no bias)
+func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
- case "FP4":
+ case "NVFP4":
+ // NVIDIA FP4: group_size=16, bits=4, E4M3 scales (no qbias)
+ return 16, 4, "nvfp4"
+ case "FP4", "Q4", "INT4":
+ // 4-bit quantization with affine mode (scale + qbias)
return 32, 4, "affine"
+ case "MXFP8":
+ // Microscaling (MX) FP8: group_size=32, bits=8, E8M0 scales (no qbias)
+ return 32, 8, "mxfp8"
+ case "FP8", "Q8", "INT8", "":
+ // 8-bit quantization with affine mode (default for quantized models)
+ return 64, 8, "affine"
default:
- return 32, 8, "affine" // FP8 or unknown
+ return 32, 8, "affine" // Unrecognized type: 8-bit affine with group_size=32
}
}
@@ -122,7 +136,8 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
}
// Handle nn.LinearLayer interface fields specially
- if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
+ linearLayerType := reflect.TypeOf((*nn.LinearLayer)(nil)).Elem()
+ if field.Type == linearLayerType {
if !hasTag {
continue // no tag = skip
}
@@ -137,6 +152,23 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
continue
}
+ // Handle nn.MultiLinearLayer interface fields specially
+ multiLinearLayerType := reflect.TypeOf((*nn.MultiLinearLayer)(nil)).Elem()
+ if field.Type == multiLinearLayerType {
+ if !hasTag {
+ continue // no tag = skip
+ }
+ layer, err := LoadMultiLinearLayer(weights, fullPath)
+ if err != nil {
+ if !optional {
+ *errs = append(*errs, fullPath+": "+err.Error())
+ }
+ continue
+ }
+ fieldVal.Set(reflect.ValueOf(layer))
+ continue
+ }
+
// Handle by kind
switch fieldVal.Kind() {
case reflect.Ptr:
@@ -216,12 +248,86 @@ func joinPath(prefix, suffix string) string {
return prefix + "." + suffix
}
+// LoadMultiLinearLayer loads a per-head linear layer from weights.
+// Weight shape should be [num_heads, output_dims, input_dims].
+// If quantized, always dequantizes since batched quantized matmul isn't supported.
+// Bits and group size are detected from tensor shapes; when the shapes match no known
+// affine layout, bits and mode come from QuantizationParams(weights.Quantization()).
+// Returns an error if a tensor cannot be loaded or the quantized shapes are unusable.
+func LoadMultiLinearLayer(weights WeightSource, path string) (nn.MultiLinearLayer, error) {
+ weight, err := weights.GetTensor(path + ".weight")
+ if err != nil {
+ return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
+ }
+
+ // Check if this is a quantized layer by looking for scale tensor
+ scalePath := path + ".weight_scale"
+ if !weights.HasTensor(scalePath) {
+ return nn.NewMultiLinear(weight), nil
+ }
+ scales, err := weights.GetTensor(scalePath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
+ }
+
+ // qbias is optional, but a listed tensor that fails to load must not silently become nil.
+ var qbiases *mlx.Array
+ if qbiasPath := path + ".weight_qbias"; weights.HasTensor(qbiasPath) {
+ if qbiases, err = weights.GetTensor(qbiasPath); err != nil {
+ return nil, fmt.Errorf("failed to load qbias %s: %w", qbiasPath, err)
+ }
+ }
+
+ // Guard the shape arithmetic below: an empty shape would index out of
+ // range and a zero scale width would divide by zero.
+ weightShape := weight.Shape()
+ scalesShape := scales.Shape()
+ if len(weightShape) == 0 || len(scalesShape) == 0 || scalesShape[len(scalesShape)-1] == 0 {
+ return nil, fmt.Errorf("invalid quantized tensor shapes for %s", path)
+ }
+ weightCols := int(weightShape[len(weightShape)-1])
+ scalesCols := int(scalesShape[len(scalesShape)-1])
+
+ // Detect quantization from tensor shapes (supports mixed-precision Q4/Q8)
+ // groupSize = weightCols * packFactor / scalesCols, packFactor = 32 / bits
+ // Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
+ groupSize4 := weightCols * 8 / scalesCols
+ groupSize8 := weightCols * 4 / scalesCols
+
+ quantization := weights.Quantization()
+ mode := "affine"
+ var bits, groupSize int
+ switch {
+ case groupSize4 == 32:
+ // Unambiguous: Q4 with group_size=32
+ bits, groupSize = 4, 32
+ case groupSize8 == 64:
+ // Unambiguous: Q8 with group_size=64
+ bits, groupSize = 8, 64
+ case groupSize4 == 64 && groupSize8 == 32:
+ // Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
+ switch strings.ToUpper(quantization) {
+ case "Q8", "FP8", "INT8":
+ bits, groupSize = 8, 32
+ default:
+ bits, groupSize = 4, 64
+ }
+ default:
+ // Fallback: use global quantization params. The mode is taken too
+ // (as LoadLinearLayer does) so nvfp4 is not dequantized as affine.
+ _, bits, mode = QuantizationParams(quantization)
+ groupSize = weightCols * (32 / bits) / scalesCols
+ }
+ return nn.NewMultiLinear(mlx.Dequantize(weight, scales, qbiases, groupSize, bits, mode)), nil
+}
+
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
-// If {path}.weight_scale exists, dequantizes the weights.
+// If {path}.weight_scale exists, creates a QuantizedLinear layer (or dequantizes if no kernel support).
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
- if weights.HasTensor(scalePath) {
+ hasScale := weights.HasTensor(scalePath)
+ if hasScale {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
@@ -245,9 +351,52 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
qbiases, _ = weights.GetTensor(qbiasPath)
}
- groupSize, bits, mode := quantizationParams(weights.Quantization())
+ // Detect bits from tensor shapes (supports mixed-precision Q4/Q8)
+ weightShape := weight.Shape()
+ scalesShape := scales.Shape()
+ weightCols := int(weightShape[len(weightShape)-1])
+ scalesCols := int(scalesShape[len(scalesShape)-1])
- if mlx.MetalIsAvailable() {
+ // Detect quantization from tensor shapes
+ // groupSize = weightCols * packFactor / scalesCols
+ // Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
+ groupSize4 := weightCols * 8 / scalesCols
+ groupSize8 := weightCols * 4 / scalesCols
+
+ var bits, groupSize int
+ mode := "affine"
+ // Use metadata to help disambiguate when shapes are ambiguous
+ // (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32)
+ quantType := strings.ToUpper(weights.Quantization())
+ isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8"
+
+ if groupSize4 == 32 {
+ // Unambiguous: Q4 with group_size=32
+ bits = 4
+ groupSize = 32
+ } else if groupSize8 == 64 {
+ // Unambiguous: Q8 with group_size=64
+ bits = 8
+ groupSize = 64
+ } else if groupSize4 == 64 && groupSize8 == 32 {
+ // Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
+ if isQ8Type {
+ bits = 8
+ groupSize = 32
+ } else {
+ bits = 4
+ groupSize = 64
+ }
+ } else {
+ // Fallback: use global quantization params
+ _, bits, mode = QuantizationParams(weights.Quantization())
+ packFactor := 32 / bits
+ groupSize = weightCols * packFactor / scalesCols
+ }
+
+ // NVFP4 and MXFP8 don't have native quantized matmul kernels in MLX,
+ // so we always dequantize at load time. Affine modes (FP4, FP8) have kernel support.
+ if mlx.MetalIsAvailable() && mode != "nvfp4" && mode != "mxfp8" {
return &nn.QuantizedLinear{
Weight: weight,
Scales: scales,
diff --git a/x/imagegen/safetensors/safetensors.go b/x/imagegen/safetensors/safetensors.go
index a36052fce..4dbcf59a3 100644
--- a/x/imagegen/safetensors/safetensors.go
+++ b/x/imagegen/safetensors/safetensors.go
@@ -303,6 +303,11 @@ func (mw *ModelWeights) Quantization() string {
return ""
}
+// GroupSize always returns 0 for directory-based weights, meaning "not specified" (use default).
+func (mw *ModelWeights) GroupSize() int {
+ return 0
+}
+
// ReleaseAll releases all cached native file handles.
func (mw *ModelWeights) ReleaseAll() {
for path, native := range mw.nativeCache {
diff --git a/x/imagegen/server_test.go b/x/imagegen/server_test.go
deleted file mode 100644
index 396aa140b..000000000
--- a/x/imagegen/server_test.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package imagegen
-
-import (
- "runtime"
- "testing"
-)
-
-// TestPlatformSupport verifies platform validation works correctly.
-func TestPlatformSupport(t *testing.T) {
- err := CheckPlatformSupport()
-
- switch runtime.GOOS {
- case "darwin":
- if runtime.GOARCH == "arm64" {
- // Apple Silicon should be supported
- if err != nil {
- t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
- }
- } else {
- // Intel Mac should fail
- if err == nil {
- t.Error("Expected error on darwin/amd64 (Intel), got nil")
- }
- if err != nil && err.Error() == "" {
- t.Error("Expected meaningful error message for unsupported platform")
- }
- }
- case "linux", "windows":
- // Linux/Windows are allowed (CUDA support checked at runtime)
- if err != nil {
- t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
- }
- default:
- // Other platforms should fail
- if err == nil {
- t.Errorf("Expected error on unsupported platform %s, got nil", runtime.GOOS)
- }
- }
-}
-
-// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
-// This is a compile-time check but we document it as a test.
-func TestServerInterfaceCompliance(t *testing.T) {
- // The var _ llm.LlamaServer = (*Server)(nil) line in server.go
- // ensures compile-time interface compliance.
- // This test documents that requirement.
- t.Log("Server implements llm.LlamaServer interface (compile-time checked)")
-}
diff --git a/x/imagegen/weights.go b/x/imagegen/weights.go
index f49c7e77e..eb60c9895 100644
--- a/x/imagegen/weights.go
+++ b/x/imagegen/weights.go
@@ -20,20 +20,28 @@ type ManifestWeights struct {
nativeCache []*mlx.SafetensorsFile // keep native handles alive
}
-// LoadWeightsFromManifest creates a weight loader for a component from manifest storage.
+// LoadWeightsFromManifest creates a weight loader from manifest storage.
+// If component is empty, loads all tensors (for LLM models).
+// If component is specified, loads only tensors for that component and strips the prefix.
func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*ManifestWeights, error) {
layers := manifest.GetTensorLayers(component)
if len(layers) == 0 {
+ if component == "" {
+ return nil, fmt.Errorf("no tensor layers found in manifest")
+ }
return nil, fmt.Errorf("no tensor layers found for component %q", component)
}
// Strip component prefix from tensor names for model loading
// e.g., "text_encoder/model.embed_tokens.weight" -> "model.embed_tokens.weight"
- prefix := component + "/"
tensors := make(map[string]ManifestLayer, len(layers))
for _, layer := range layers {
- tensorName := strings.TrimPrefix(layer.Name, prefix)
- tensors[tensorName] = layer
+ if component == "" {
+ tensors[layer.Name] = layer
+ } else {
+ tensorName := strings.TrimPrefix(layer.Name, component+"/")
+ tensors[tensorName] = layer
+ }
}
return &ManifestWeights{
@@ -48,19 +56,30 @@ func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*Manife
// Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
// If dtype is non-zero, tensors are converted to the specified dtype.
func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
+ // Track native handles to free after batch eval
+ nativeHandles := make([]*mlx.SafetensorsFile, 0, len(mw.tensors))
+ arrays := make([]*mlx.Array, 0, len(mw.tensors))
+
for name, layer := range mw.tensors {
path := mw.manifest.BlobPath(layer.Digest)
// Load blob as safetensors (native mmap, zero-copy)
sf, err := mlx.LoadSafetensorsNative(path)
if err != nil {
+ // Free any handles we've accumulated
+ for _, h := range nativeHandles {
+ h.Free()
+ }
return fmt.Errorf("load %s: %w", name, err)
}
+ nativeHandles = append(nativeHandles, sf)
// Blob contains single tensor named "data"
arr := sf.Get("data")
if arr == nil {
- sf.Free()
+ for _, h := range nativeHandles {
+ h.Free()
+ }
return fmt.Errorf("tensor 'data' not found in blob for %s", name)
}
@@ -68,11 +87,18 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
if dtype != 0 && arr.Dtype() != dtype {
arr = mlx.AsType(arr, dtype)
}
- // ALWAYS make a contiguous copy to ensure independence from mmap
+ // Make contiguous copy to ensure independence from mmap
arr = mlx.Contiguous(arr)
- mlx.Eval(arr)
mw.cache[name] = arr
- sf.Free() // Safe to free - arr is now an independent copy
+ arrays = append(arrays, arr)
+ }
+
+ // Batch evaluate all tensors at once (much faster than one at a time)
+ mlx.Eval(arrays...)
+
+ // Now safe to free all native handles
+ for _, sf := range nativeHandles {
+ sf.Free()
}
return nil
@@ -107,18 +133,112 @@ func (mw *ManifestWeights) HasTensor(name string) bool {
}
// Quantization returns the model's quantization type from model_index.json.
-// Returns empty string if not quantized or unknown.
+// Returns empty string if not quantized.
+// Falls back to detecting from tensor names and blob sizes if not in config.
func (mw *ManifestWeights) Quantization() string {
if mw.manifest == nil {
return ""
}
+
+ // Try to read from model_index.json first
var index struct {
Quantization string `json:"quantization"`
}
- if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil {
+ if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.Quantization != "" {
+ return index.Quantization
+ }
+
+ // Fallback: detect from tensor names
+ // Check if any tensors have _scale suffix (indicates quantization)
+ hasScales := false
+ hasQBias := false
+ for name := range mw.tensors {
+ if strings.HasSuffix(name, ".weight_scale") {
+ hasScales = true
+ }
+ if strings.HasSuffix(name, ".weight_qbias") {
+ hasQBias = true
+ }
+ }
+
+ if !hasScales {
+ // No scales = not quantized
return ""
}
- return index.Quantization
+
+ // Has scales but no qbias = NVFP4 (or other non-affine mode)
+ if !hasQBias {
+ return "NVFP4"
+ }
+
+ // Has both scales and qbias = affine mode.
+ // Tell Q4 from Q8 by comparing the stored size of a weight blob with the
+ // size of its scale blob; no tensor has to be loaded for this.
+ // NOTE(review): Size is assumed to be the blob's byte size — confirm against ManifestLayer.
+ //
+ // Packed weights hold 32/bits values per uint32 and there is one scale per
+ // group. For n logical elements with group_size=32 and float16 scales:
+ // Q4: weight_bytes ≈ n/2, scale_bytes ≈ n/16 -> ratio ≈ 8
+ // Q8: weight_bytes ≈ n, scale_bytes ≈ n/16 -> ratio ≈ 16
+ //
+ // NOTE: this is only a heuristic. Other group sizes or scale dtypes shift
+ // the ratio (Q4 with group_size=64 also lands near 16), so callers that
+ // can see tensor shapes should prefer those over this answer.
+ //
+ // Go randomizes map iteration order, so probe the lexicographically smallest
+ // weight/scale pair. Taking whichever pair the map yields first would let a
+ // mixed-precision model report a different type from one call to the next.
+ probe := ""
+ for name := range mw.tensors {
+ // Only plain "<layer>.weight" entries are candidates.
+ if !strings.HasSuffix(name, ".weight") || strings.Contains(name, "_scale") || strings.Contains(name, "_qbias") {
+ continue
+ }
+ // The weight must have a matching "<layer>.weight_scale" entry.
+ if _, ok := mw.tensors[name+"_scale"]; !ok {
+ continue
+ }
+ if probe == "" || name < probe {
+ probe = name
+ }
+ }
+
+ if probe == "" {
+ // No weight/scale pair found: default to Q4 for affine mode (most common)
+ return "Q4"
+ }
+
+ weightLayer := mw.tensors[probe]
+ scaleLayer := mw.tensors[probe+"_scale"]
+
+ // The threshold sits midway between the two expected ratios (8 and 16).
+ // A zero-sized scale blob yields +Inf or NaN here, both of which fall
+ // through to "Q8" without panicking.
+ ratio := float64(weightLayer.Size) / float64(scaleLayer.Size)
+ if ratio < 12 {
+ // Closer to 8 = Q4
+ return "Q4"
+ }
+
+ // Closer to 16 = Q8
+ return "Q8"
}
+
+// GroupSize reports the quantization group size recorded in model_index.json.
+// A result of 0 means "not specified" (no manifest, ReadConfigJSON failed, or no
+// positive group_size); the caller should use the default for the quantization type.
+func (mw *ManifestWeights) GroupSize() int {
+ if mw.manifest == nil {
+ return 0
+ }
+ var cfg struct {
+ GroupSize int `json:"group_size"`
+ }
+ err := mw.manifest.ReadConfigJSON("model_index.json", &cfg)
+ if err != nil || cfg.GroupSize <= 0 {
+ return 0
+ }
+ return cfg.GroupSize
}
// ReleaseAll frees all native handles and clears the tensor cache.
diff --git a/x/kvcache/causal.go b/x/kvcache/causal.go
index 967fed674..eb5ef6f96 100644
--- a/x/kvcache/causal.go
+++ b/x/kvcache/causal.go
@@ -1,797 +1,144 @@
+//go:build mlx
+
package kvcache
-// import (
-// "errors"
-// "fmt"
-// "log/slog"
-// "math"
-// "slices"
-
-// "github.com/ollama/ollama/ml"
-// "github.com/ollama/ollama/model/input"
-// )
-
-// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
-
-// // Causal cache stores K and V tensors according to their position in the
-// // sequence. Returns the history and a mask for attending to past tokens
-// //
-// // The tensors are of shape embed dim, kv heads, batch size
-// // The mask is of shape history size, batch size
-// type Causal struct {
-// DType ml.DType
-
-// // swaWindowSize is the number of tokens that will be included in the mask
-// // during attention operations. swaMemorySize is the number of tokens that
-// // will be retained in memory for partial prefix caching. Set to math.MaxInt32
-// // for unlimited or if sliding window attention is not being used.
-// swaWindowSize int32
-// swaMemorySize int32
-
-// chunkSize int32
-
-// opts CausalOptions
-
-// // maxBatch is the largest batch that we might receive
-// maxBatch int
-
-// // config controls mostly backend-specific optimizations
-// config *ml.CacheConfig
-
-// // ** current forward pass **
-
-// // size of the current batch
-// curBatchSize int
-
-// // locations for data storage for this batch
-// curLoc ml.Tensor
-
-// // mask of the cache as used by this batch
-// curMask ml.Tensor
-
-// // the active layer for Get and Put
-// curLayer int
-
-// // locations in the cache that are needed for this batch
-// curCellRange cellRange
-
-// // curSequences is the sequences corresponding to this pass's entries in the cache
-// curSequences []int
-
-// // curPositions is the positions corresponding to this pass's entries in the cache
-// curPositions []int32
-
-// // ** cache metadata **
-
-// // for each possible location in the cache, stores the position and set of sequences
-// // that reference the data there
-// cells []cacheCell
-
-// // maps from sequence to the range of locations where it is stored in the cache
-// cellRanges map[int]cellRange
-
-// // ** cache data storage **
-
-// shiftFn shiftFn
-// backend ml.Backend
-// ctxs map[int]ml.Context
-// keys, values map[int]ml.Tensor
-
-// kHeadDims, vHeadDims, numKVHeads map[int]int
-// }
-
-// type cacheCell struct {
-// pos int32
-// sequences []int
-// }
-
-// type cellRange struct {
-// min int
-// max int
-// }
-
-// func NewCausalCache(shift shiftFn) *Causal {
-// return &Causal{
-// shiftFn: shift,
-// ctxs: make(map[int]ml.Context),
-// keys: make(map[int]ml.Tensor),
-// values: make(map[int]ml.Tensor),
-// kHeadDims: make(map[int]int),
-// vHeadDims: make(map[int]int),
-// numKVHeads: make(map[int]int),
-// }
-// }
-
-// func NewSWACache(windowSize int32, shift shiftFn) *Causal {
-// return &Causal{
-// swaWindowSize: windowSize,
-// shiftFn: shift,
-// ctxs: make(map[int]ml.Context),
-// keys: make(map[int]ml.Tensor),
-// values: make(map[int]ml.Tensor),
-// kHeadDims: make(map[int]int),
-// vHeadDims: make(map[int]int),
-// numKVHeads: make(map[int]int),
-// }
-// }
-
-// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
-// return &Causal{
-// swaWindowSize: windowSize,
-// swaMemorySize: memorySize,
-// shiftFn: shift,
-// ctxs: make(map[int]ml.Context),
-// keys: make(map[int]ml.Tensor),
-// values: make(map[int]ml.Tensor),
-// kHeadDims: make(map[int]int),
-// vHeadDims: make(map[int]int),
-// numKVHeads: make(map[int]int),
-// }
-// }
-
-// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
-// return &Causal{
-// chunkSize: chunkSize,
-// shiftFn: shift,
-// ctxs: make(map[int]ml.Context),
-// keys: make(map[int]ml.Tensor),
-// values: make(map[int]ml.Tensor),
-// kHeadDims: make(map[int]int),
-// vHeadDims: make(map[int]int),
-// numKVHeads: make(map[int]int),
-// }
-// }
-
-// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
-// if c.config == nil {
-// var config ml.CacheConfig
-// if cc, ok := backend.(ml.BackendCacheConfig); ok {
-// config = cc.CacheConfig()
-// }
-// c.config = &config
-// }
-
-// if c.config.CachePadding == 0 {
-// c.config.CachePadding = 1
-// }
-
-// if c.config.MaskBatchPadding == 0 {
-// c.config.MaskBatchPadding = 1
-// }
-
-// // TODO what types do we handle here?
-// // if c.config.MaskDType == ml.DTypeOther {
-// // c.config.MaskDType = ml.DTypeFloat32
-// // }
-
-// if c.swaWindowSize == 0 {
-// c.swaWindowSize = math.MaxInt32
-// }
-// if c.swaMemorySize == 0 {
-// c.swaMemorySize = c.swaWindowSize
-// }
-// // We will allocate space in the cache for the stop token, which won't be part of a follow on
-// // sequence, so allocate an extra token of storage to ensure that we can jump back without
-// // causing a cache break. As an optimization, only do this when we have parallel sequences
-// // because the extra token will live in the batch buffer and won't get overwritten if we
-// // only have a single sequence.
-// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
-// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
-// }
-// if int(c.swaMemorySize) >= capacity {
-// c.swaMemorySize = math.MaxInt32
-// }
-
-// if c.swaMemorySize < c.swaWindowSize {
-// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
-// }
-
-// var cacheSize int
-// if c.swaMemorySize == math.MaxInt32 {
-// cacheSize = maxSequences * capacity
-// } else {
-// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
-// }
-// cacheSize = roundUp(cacheSize, c.config.CachePadding)
-// c.cells = make([]cacheCell, cacheSize)
-
-// c.DType = dtype
-// c.cellRanges = make(map[int]cellRange)
-// c.backend = backend
-// c.maxBatch = maxBatch
-// }
-
-// func (c *Causal) SetConfig(config ml.CacheConfig) {
-// if c.config != nil {
-// panic("config cannot be changed after being previously set, either by the model or backend")
-// }
-
-// c.config = &config
-// }
-
-// func (c *Causal) Close() {
-// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
-// for _, ctx := range c.ctxs {
-// ctx.Close()
-// }
-// }
-
-// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
-// slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch)
-// // panic("XXX Causal.StartForward")
-// c.curBatchSize = len(batch.Positions)
-// c.curSequences = batch.Sequences
-// c.curPositions = batch.Positions
-// c.opts.Except = nil
-
-// var locs []int32
-// if !reserve {
-// c.updateSlidingWindow()
-
-// var err error
-// locs, err = c.findLocs()
-// if err != nil {
-// return err
-// }
-// slog.Info("XXX Causal.StartForward", "findLocs len", len(locs))
-
-// for i, pos := range batch.Positions {
-// seq := batch.Sequences[i]
-// loc := int(locs[i])
-
-// c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
-
-// seqRange, ok := c.cellRanges[seq]
-// if !ok {
-// seqRange = newRange()
-// }
-
-// seqRange.min = min(seqRange.min, loc)
-// c.curCellRange.min = min(c.curCellRange.min, loc)
-
-// seqRange.max = max(seqRange.max, loc)
-// c.curCellRange.max = max(c.curCellRange.max, loc)
-
-// c.cellRanges[seq] = seqRange
-// }
-// } else {
-// // If we are reserving memory, don't update any of the cache metadata but set the size
-// // to the worst case.
-// locs = make([]int32, c.curBatchSize)
-// for i := range locs {
-// locs[i] = int32(i)
-// }
-// c.curCellRange.min = 0
-// c.curCellRange.max = len(c.cells) - 1
-// }
-
-// // XXX Building up the locs for what's already processed (if any)
-// dummyLocs := []int{}
-// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
-// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
-
-// for i := range c.curBatchSize {
-// enabled := !slices.Contains(c.opts.Except, i)
-// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
-// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
-// (enabled && c.cells[j].pos > c.curPositions[i]) ||
-// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
-// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
-// // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
-// } else {
-// if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i {
-// dummyLocs = append(dummyLocs, i)
-// }
-// }
-// }
-// }
-// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs)
-
-// slog.Info("XXX Causal.StartForward", "locs", locs)
-// c.curLoc = ctx.Input().FromInts(locs, len(locs))
-// c.curMask = c.buildMask(ctx)
-
-// return nil
-// }
-
-// func newRange() cellRange {
-// return cellRange{
-// min: math.MaxInt,
-// max: 0,
-// }
-// }
-
-// // Returns a slice of locations where each token in the batch should be stored
-// func (c *Causal) findLocs() ([]int32, error) {
-// loc := make([]int32, 0, c.curBatchSize)
-
-// for i := range c.cells {
-// if len(c.cells[i].sequences) == 0 {
-// loc = append(loc, int32(i))
-// if len(loc) >= c.curBatchSize {
-// return loc, nil
-// }
-// }
-// }
-
-// return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
-// }
-
-// func (c *Causal) updateSlidingWindow() {
-// c.curCellRange = newRange()
-
-// if c.swaMemorySize == math.MaxInt32 {
-// for _, seq := range c.curSequences {
-// if seqRange, ok := c.cellRanges[seq]; ok {
-// c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
-// c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
-// }
-// }
-
-// return
-// }
-
-// type lowestPosition struct {
-// pos int32
-// curBatch bool
-// }
-
-// // create a map of unique sequences to the lowest position in that sequence
-// lowestPos := make(map[int]lowestPosition)
-// for i := range c.curPositions {
-// seq := c.curSequences[i]
-
-// lowest, ok := lowestPos[seq]
-// if !ok {
-// lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
-// } else if c.curPositions[i] < lowest.pos {
-// lowest.pos = c.curPositions[i]
-// }
-
-// lowestPos[seq] = lowest
-// }
-
-// // for any sequences are not part of this batch, clean up any tokens
-// // that are no longer needed after the processing of the previous
-// // batch
-// for seq, seqRange := range c.cellRanges {
-// if _, ok := lowestPos[seq]; !ok {
-// var last int32
-// for i := seqRange.min; i <= seqRange.max; i++ {
-// if slices.Contains(c.cells[i].sequences, seq) {
-// last = max(last, c.cells[i].pos)
-// }
-// }
-
-// lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
-// }
-// }
-
-// // delete any entries that are beyond the window of the oldest position in the sequence
-// for seq, lowest := range lowestPos {
-// oldRange, ok := c.cellRanges[seq]
-// if !ok {
-// continue
-// }
-
-// newRange := newRange()
-
-// for i := oldRange.min; i <= oldRange.max; i++ {
-// if slices.Contains(c.cells[i].sequences, seq) {
-// if c.cells[i].pos < lowest.pos-c.swaMemorySize {
-// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
-// } else {
-// newRange.min = min(newRange.min, i)
-// newRange.max = max(newRange.max, i)
-// }
-// if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
-// c.curCellRange.min = min(c.curCellRange.min, i)
-// c.curCellRange.max = max(c.curCellRange.max, i)
-// }
-// }
-// }
-
-// c.cellRanges[seq] = newRange
-// }
-// }
-
-// func roundDown(length, pad int) int {
-// return (length / pad) * pad
-// }
-
-// func roundUp(length, pad int) int {
-// return ((length + pad - 1) / pad) * pad
-// }
-
-// // Builds a mask of history x batch indicating whether for each token in the batch the
-// // token in the history should apply. This is based on both the sequence and causality (the
-// // position of the history is not ahead of the token in the batch).
-// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
-// // Align and pad the two dimensions as required by the backend
-// batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
-
-// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
-// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
-
-// length := c.curCellRange.max - c.curCellRange.min + 1
-
-// mask := make([]float32, batchSize*length)
-
-// for i := range c.curBatchSize {
-// enabled := !slices.Contains(c.opts.Except, i)
-// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
-// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
-// (enabled && c.cells[j].pos > c.curPositions[i]) ||
-// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
-// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
-// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
-// }
-// }
-// }
-
-// // Mask out any padding tokens we added. For padding that we added to the cache history, this
-// // has already been masked out because the sequence doesn't match.
-// for i := c.curBatchSize * length; i < len(mask); i++ {
-// mask[i] = float32(math.Inf(-1))
-// }
-
-// maskTensor := ctx.Input().FromFloats(mask, batchSize, length)
-
-// // if c.config.MaskDType != ml.DTypeFloat32 {
-// // maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
-// // }
-
-// slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length})
-
-// return maskTensor
-// }
-
-// func (c *Causal) SetLayer(layer int) {
-// c.curLayer = layer
-// }
-
-// type CausalOptions struct {
-// // Enabled controls whether the causal mask is generated for a particular index in a batch
-// Except []int
-// }
-
-// // SetCausal disables causal mask generation for a particular range of indicies in
-// // the current batch for subsequent calls to Get. The state resets for the next forward pass.
-// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
-// if !slices.Equal(c.opts.Except, opts.Except) {
-// c.opts = opts
-// if ctx != nil {
-// c.curMask = c.buildMask(ctx)
-// }
-// }
-// }
-
-// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
-// key := c.keys[c.curLayer]
-// value := c.values[c.curLayer]
-
-// kHeadDim := c.kHeadDims[c.curLayer]
-// vHeadDim := c.vHeadDims[c.curLayer]
-// numKVHeads := c.numKVHeads[c.curLayer]
-// // rowSize := numKVHeads * c.curBatchSize
-// // cachedSize := c.curMask.Dim(1)
-// cachedSize := c.curLoc.Dim(0)
-// // kCellSize := kHeadDim * numKVHeads
-// // vCellSize := vHeadDim * numKVHeads
-
-// slog.Info("XXX Causal.Get full cache", "key", key)
-// slog.Info("XXX Causal.Get full cache", "value", value)
-// slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc)
-// slog.Info("XXX Causal.Get", "curMask", c.curMask)
-// slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim)
-// // panic("XXX")
-
-// // fmt.Fprintln(os.Stderr, key.ToString())
-// // panic("full cache value")
-
-// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
-// key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
-// // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min)
-
-// // slog.Info("XXX Causal.Get after AsStrided", "key", key)
-// // panic("XXX")
-
-// // if c.config.PermutedV {
-// // panic("permuted")
-// // // TODO not converted
-// // vHeadDim := value.Dim(1)
-// // elemSize := value.Stride(2)
-
-// // value = value.AsStrided(ctx,
-// // []int{numKVHeads, vHeadDim, cachedSize},
-// // []int{value.Stride(0), value.Stride(1)},
-// // elemSize*c.curCellRange.min,
-// // )
-// // } else {
-// // vHeadDim := c.vHeadDims[c.curLayer]
-// // rowSize := value.Stride(2)
-// // slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize)
-// // panic("XXX")
-
-// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
-// value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
-// // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min)
-
-// // slog.Info("XXX Causal.Get after AsStrided", "value", value)
-// // panic("XXX")
-
-// // }
-
-// // // TODO The mask changes from X,X to 1,X, and with the Row-order change
-// // // the 1 becomes trailing and messes up later operations
-// // // This isn't the right solution, but works around it...
-// // if c.curMask.Dim(1) == 1 {
-// // return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3)
-// // }
-// // fmt.Fprintln(os.Stderr, key.ToString())
-// // fmt.Fprintln(os.Stderr, value.ToString())
-// // panic("XXX")
-// slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape())
-
-// return key, value, c.curMask
-// }
-
-// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
-// kHeadDim := key.Dim(3)
-// vHeadDim := value.Dim(3)
-// numKVHeads := key.Dim(1)
-// batchSize := key.Dim(2)
-// kCellSize := kHeadDim * numKVHeads
-// vCellSize := vHeadDim * numKVHeads
-
-// // slog.Info("XXX Causal.Put", "key", key, "value", value)
-// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize)
-// // panic("XXX")
-
-// if c.curBatchSize != batchSize {
-// panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
-// }
-
-// // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend)
-// if _, ok := c.ctxs[c.curLayer]; !ok {
-// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
-// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
-// }
-
-// if _, ok := c.keys[c.curLayer]; !ok {
-// slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize})
-
-// c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize)
-// c.kHeadDims[c.curLayer] = kHeadDim
-// c.vHeadDims[c.curLayer] = vHeadDim
-// c.numKVHeads[c.curLayer] = numKVHeads
-// }
-
-// if _, ok := c.values[c.curLayer]; !ok {
-// // if c.config.PermutedV {
-// // c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells))
-// // } else {
-// c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize)
-// // }
-// }
-
-// key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
-
-// // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache)
-// // panic("XXX")
-// // curLoc := 0 // TODO c.curLoc is now a tensor
-// // kSize := numKVHeads * kHeadDim
-// // vSize := numKVHeads * vHeadDim
-// // start := []int{int(curLoc), 0}
-// // kStop := []int{int(curLoc + batchSize), int(kSize)}
-// // vStop := []int{int(curLoc + batchSize), int(vSize)}
-// // strides := []int{1, 1}
-
-// // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache)
-// // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key)
-
-// // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides)
-
-// // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides))
-// ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0}))
-// // fmt.Fprintln(os.Stderr, keyCache.ToString())
-// // panic("input value")
-
-// // fmt.Fprintln(os.Stderr, t.ToString())
-// // panic("XXX")
-
-// // if c.config.PermutedV {
-// // panic("permuted")
-// // // TODO not adjusted
-// // value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
-// // value = value.Transpose(ctx, 2, 0, 1, 3)
-
-// // valueCache := c.values[c.curLayer]
-// // valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
-
-// // ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides))
-// // } else {
-// value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
-// // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache)
-// // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value)
-// // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides)
-
-// ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0}))
-// // }
-// // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString())
-// // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString())
-// // panic("XXX")
-
-// }
-
-// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
-// seqRange := newRange()
-
-// for i := range c.cells {
-// // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
-// if slices.Contains(c.cells[i].sequences, dstSeq) {
-// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
-// }
-
-// if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
-// c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
-// if i < seqRange.min {
-// seqRange.min = i
-// }
-// if i > seqRange.max {
-// seqRange.max = i
-// }
-// }
-// }
-
-// c.cellRanges[dstSeq] = seqRange
-// }
-
-// func (c *Causal) CanResume(seq int, pos int32) bool {
-// if c.swaMemorySize == math.MaxInt32 {
-// return true
-// }
-
-// seqRange, ok := c.cellRanges[seq]
-// if !ok {
-// return false
-// }
-
-// // for sliding window, check that the window of the new sequence is contained in
-// // the window of what we are storing
-// var first int32 = math.MaxInt32
-// var last int32 = -1
-// for i := seqRange.min; i <= seqRange.max; i++ {
-// if slices.Contains(c.cells[i].sequences, seq) {
-// first = min(first, c.cells[i].pos)
-// last = max(last, c.cells[i].pos)
-// }
-// }
-
-// if last == -1 {
-// return false
-// }
-
-// posWindowStart := max(0, pos-c.swaWindowSize)
-// return posWindowStart >= first && pos <= last+1
-// }
-
-// func (c *Causal) shift(seq int, beginIndex, offset int32) error {
-// if c.shiftFn == nil {
-// return ErrNotSupported
-// }
-
-// seqRange := c.cellRanges[seq]
-
-// for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
-// size := min(seqRange.max-start+1, c.maxBatch)
-// offsets := make([]int32, size)
-
-// var batchFirst, batchLast int
-
-// batchFirst = -1
-// for i := range offsets {
-// cell := c.cells[start+i]
-
-// if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
-// offsets[i] = offset
-// if batchFirst < 0 {
-// batchFirst = i
-// }
-// batchLast = i
-// }
-// }
-
-// if batchFirst < 0 {
-// continue
-// }
-
-// offsets = offsets[batchFirst : batchLast+1]
-
-// slog.Info("XXX Causal.shift creating new temporary context")
-// ctx := c.backend.NewContext()
-// kShift := ctx.Input().FromInts(offsets, len(offsets))
-
-// for i, key := range c.keys {
-// if key == nil {
-// continue
-// }
-
-// kHeadDim := key.Dim(2)
-// numKVHeads := key.Dim(1)
-// rowSize := key.Stride(0)
-
-// key = key.AsStrided(ctx,
-// []int{len(offsets), numKVHeads, kHeadDim},
-// []int{key.Stride(0), key.Stride(1)},
-// rowSize*(start+batchFirst),
-// )
-
-// roped, err := c.shiftFn(ctx, i, key, kShift)
-// if err != nil {
-// ctx.Close()
-// return err
-// }
-
-// ctx.Forward(roped.Copy(ctx, key))
-// }
-
-// ctx.Compute()
-// ctx.Close()
-// }
-
-// return nil
-// }
-
-// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
-// // TODO(jessegross): We should check to see if removing the middle of the sequence will
-// // cause the sliding window to encompass tokens that we no longer have. If so, then we
-// // should return an error, which will trigger the runner to evaluate the full history and
-// // rebuild the window. However, if we have multimodal inputs in our history, this reuse
-// // results in use after free, so we don't do it for now.
-
-// var offset int32
-// if endIndex != math.MaxInt32 {
-// offset = beginIndex - endIndex
-// }
-
-// seqRange := newRange()
-
-// for i := range c.cells {
-// if slices.Contains(c.cells[i].sequences, seq) {
-// if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
-// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
-// } else {
-// if c.cells[i].pos >= endIndex {
-// if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
-// return errors.New("shifting cells shared by multiple sequences not supported")
-// }
-
-// c.cells[i].pos += offset
-// }
-// if i < seqRange.min {
-// seqRange.min = i
-// }
-// if i > seqRange.max {
-// seqRange.max = i
-// }
-// }
-// }
-// }
-
-// if seqRange == newRange() {
-// delete(c.cellRanges, seq)
-// return nil
-// }
-
-// c.cellRanges[seq] = seqRange
-
-// if endIndex != math.MaxInt32 {
-// err := c.shift(seq, endIndex+offset, offset)
-// if err != nil {
-// return err
-// }
-// }
-
-// return nil
-// }
+import (
+ "github.com/ollama/ollama/x/ml"
+ "github.com/ollama/ollama/x/model/input"
+)
+
+// Causal is a per-layer key/value cache. Put scatters each batch's K and V
+// tensors into preallocated storage and Get gathers the accumulated history;
+// the mask result of Get is currently always nil.
+type Causal struct {
+	// DType is the element type used when allocating key/value storage.
+	DType ml.DType
+
+	// curLocPut holds the row indices Put writes the current batch to.
+	curLocPut ml.Tensor
+
+	// curLocGet holds the row indices Get reads the history from.
+	curLocGet ml.Tensor
+
+	// curLayer is the active layer for Get and Put, chosen by SetLayer.
+	curLayer int
+
+	// capacity is the number of rows allocated per layer for keys and values.
+	capacity int
+
+	// offset is the total number of positions passed to StartForward so far.
+	offset int
+
+	// backend creates the per-layer contexts used to allocate cache storage.
+	backend ml.Backend
+
+	// ctxs, keys and values are indexed by layer and filled lazily by Put.
+	ctxs   map[int]ml.Context
+	keys   map[int]ml.Tensor
+	values map[int]ml.Tensor
+
+	// Per-layer shapes recorded on the first Put and reused by Get.
+	// TODO is this needed per layer, or will it always be consistent?
+	kHeadDims  map[int]int
+	vHeadDims  map[int]int
+	numKVHeads map[int]int
+}
+
+// NewCausalCache returns a Causal whose per-layer maps are allocated and
+// empty. Init must still be called to supply the backend, dtype and capacity
+// before Put is used.
+func NewCausalCache() *Causal {
+	c := new(Causal)
+
+	c.ctxs = map[int]ml.Context{}
+	c.keys = map[int]ml.Tensor{}
+	c.values = map[int]ml.Tensor{}
+
+	c.kHeadDims = map[int]int{}
+	c.vHeadDims = map[int]int{}
+	c.numKVHeads = map[int]int{}
+
+	return c
+}
+
+// Init records the backend, element type and row capacity that Put uses when
+// it allocates per-layer storage. maxSequences and maxBatch are accepted but
+// not used by this implementation.
+func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
+	c.backend, c.DType, c.capacity = backend, dtype, capacity
+}
+
+// SetConfig is a no-op: the supplied config is discarded.
+func (c *Causal) SetConfig(config ml.CacheConfig) {
+	// Intentionally empty.
+}
+
+// SetLayer selects the layer that subsequent Put and Get calls operate on.
+func (c *Causal) SetLayer(layer int) {
+	c.curLayer = layer
+}
+
+// Close calls Close on every per-layer context that Put has created.
+func (c *Causal) Close() {
+	for layer := range c.ctxs {
+		c.ctxs[layer].Close()
+	}
+}
+
+// StartForward prepares the cache for one forward pass over batch.
+//
+// It builds two index tensors on ctx.Input():
+//   - curLocPut: the rows Put writes this batch's keys/values into, i.e. the
+//     len(batch.Positions) rows starting at the current offset;
+//   - curLocGet: every row from 0 up to and including this batch, which Get
+//     reads back as the history.
+//
+// It then advances offset by the batch length. The returned error is always
+// nil.
+//
+// NOTE(review): reserve is ignored, so a reservation pass advances offset like
+// a real one, and offset is never checked against capacity — confirm callers
+// never reserve or overrun the cache.
+func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
+	batchLen := len(batch.Positions)
+
+	// New entries are appended after everything already stored, so each row
+	// index is shifted by the current offset. Starting the loop at c.offset
+	// and bounding it by the batch length (as before) left the slice zeroed
+	// for every batch after the first, overwriting row 0.
+	locsPut := make([]int32, batchLen)
+	for i := range locsPut {
+		locsPut[i] = int32(c.offset + i)
+	}
+	c.offset += batchLen
+
+	// History is the contiguous range [0, offset), including this batch.
+	locsGet := make([]int32, c.offset)
+	for i := range locsGet {
+		locsGet[i] = int32(i)
+	}
+
+	c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
+	c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
+
+	return nil
+}
+
+// Put writes this batch's key and value tensors into the cache for the layer
+// selected by SetLayer, at the rows held in curLocPut.
+//
+// The first Put for a layer creates that layer's context, allocates
+// zero-filled key/value storage of capacity rows, and records the head
+// dimensions and KV head count that Get later uses to reshape.
+//
+// key and value are read as 4-D tensors with dim 1 = KV heads, dim 2 = batch
+// size and dim 3 = head size (dim 0 is presumably a leading 1 — TODO confirm).
+func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
+	layer := c.curLayer
+
+	kHeadDim, vHeadDim := key.Dim(3), value.Dim(3)
+	numKVHeads, batchSize := key.Dim(1), key.Dim(2)
+	kCellSize, vCellSize := kHeadDim*numKVHeads, vHeadDim*numKVHeads
+
+	// Each layer gets its own context, created on first use.
+	layerCtx, haveCtx := c.ctxs[layer]
+	if !haveCtx {
+		layerCtx = c.backend.NewContext().Layer(layer)
+		c.ctxs[layer] = layerCtx
+	}
+
+	// Keys and values are always allocated together, so checking keys suffices.
+	if _, allocated := c.keys[layer]; !allocated {
+		c.keys[layer] = layerCtx.Zeros(c.DType, c.capacity, kCellSize)
+		c.values[layer] = layerCtx.Zeros(c.DType, c.capacity, vCellSize)
+		c.kHeadDims[layer] = kHeadDim
+		c.vHeadDims[layer] = vHeadDim
+		c.numKVHeads[layer] = numKVHeads
+	}
+
+	// Flatten heads into one cell per token, then scatter along axis 0.
+	flatKey := key.Reshape(ctx, batchSize, 1, kCellSize)
+	ctx.Forward(c.keys[layer].Scatter(ctx, []ml.Tensor{c.curLocPut}, flatKey, []int{0}))
+
+	flatValue := value.Reshape(ctx, batchSize, 1, vCellSize)
+	ctx.Forward(c.values[layer].Scatter(ctx, []ml.Tensor{c.curLocPut}, flatValue, []int{0}))
+}
+
+// Get returns the cached keys and values for the layer selected by SetLayer,
+// gathered at the rows in curLocGet and reshaped to
+// (1, numKVHeads, cachedSize, headDim). The third result, the attention mask,
+// is always nil in this implementation.
+//
+// Put must have been called for the layer first; otherwise the map lookups
+// yield nil tensors and the calls on them panic.
+func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
+	layer := c.curLayer
+	heads := c.numKVHeads[layer]
+
+	// One cached row per index in curLocGet.
+	cachedSize := c.curLocGet.Dim(0)
+
+	keys := c.keys[layer].TakeAxes(ctx, c.curLocGet, 0)
+	keys = keys.Reshape(ctx, 1, heads, cachedSize, c.kHeadDims[layer])
+
+	values := c.values[layer].TakeAxes(ctx, c.curLocGet, 0)
+	values = values.Reshape(ctx, 1, heads, cachedSize, c.vHeadDims[layer])
+
+	return keys, values, nil
+}
+
+// CopyPrefix is not supported by this cache; every call panics with
+// "not implemented".
+func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
+	// Stub: no arguments are inspected.
+	panic("not implemented")
+}
+
+// CanResume is not supported by this cache; every call panics with
+// "not implemented" instead of returning a result.
+func (c *Causal) CanResume(seq int, pos int32) bool {
+	// Stub: no arguments are inspected.
+	panic("not implemented")
+}
+
+// Remove is not supported by this cache; every call panics with
+// "not implemented" instead of returning an error.
+func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
+	// Stub: no arguments are inspected.
+	panic("not implemented")
+}
diff --git a/x/kvcache/causal_test.go b/x/kvcache/causal_test.go
deleted file mode 100644
index d7ac430b1..000000000
--- a/x/kvcache/causal_test.go
+++ /dev/null
@@ -1,973 +0,0 @@
-package kvcache
-
-// import (
-// "fmt"
-// "math"
-// "slices"
-// "testing"
-
-// "github.com/ollama/ollama/ml"
-// "github.com/ollama/ollama/model/input"
-// )
-
-// type testCase struct {
-// name string
-// in []float32
-// inShape []int
-// seqs []int
-// pos []int32
-// expected []float32
-// expectedShape []int
-// expectedMask []float32
-// }
-
-// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
-// t.Helper()
-// for _, permuted := range []bool{false, true} {
-// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
-// fn(t, &testBackend{permutedV: permuted})
-// })
-// }
-// }
-
-// func TestStore(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewCausalCache(nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// tests := []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
-// inShape: []int{2, 3, 4},
-// seqs: []int{0, 0, 0, 0},
-// pos: []int32{0, 1, 2, 3},
-// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
-// expectedShape: []int{2, 3, 4},
-// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
-// },
-// {
-// name: "SecondBatch",
-// in: []float32{115, 215, 125, 225, 135, 235},
-// inShape: []int{2, 3, 1},
-// seqs: []int{0},
-// pos: []int32{4},
-// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
-// expectedShape: []int{2, 3, 5},
-// expectedMask: []float32{0, 0, 0, 0, 0},
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func TestSWA(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewSWACache(1, nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// x := float32(math.Inf(-1))
-
-// tests := []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{1, 2, 3, 4},
-// inShape: []int{1, 1, 4},
-// seqs: []int{0, 0, 0, 0},
-// pos: []int32{0, 1, 2, 3},
-// expected: []float32{1, 2, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{
-// 0, x, x, x,
-// 0, 0, x, x,
-// x, 0, 0, x,
-// x, x, 0, 0,
-// },
-// },
-// {
-// name: "SecondBatch",
-// in: []float32{5, 6},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{4, 5},
-// expected: []float32{5, 6, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{
-// 0, x, x, 0,
-// 0, 0, x, x,
-// },
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func TestSWASeparateBatches(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewSWACache(1, nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 2, 16, 2)
-
-// x := float32(math.Inf(-1))
-
-// tests := []testCase{
-// {
-// name: "First seq 0",
-// in: []float32{1, 2},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{0, 1},
-// expected: []float32{1, 2},
-// expectedShape: []int{1, 1, 2},
-// expectedMask: []float32{
-// 0, x,
-// 0, 0,
-// },
-// },
-// {
-// name: "Second seq 0",
-// in: []float32{3, 4},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{2, 3},
-// expected: []float32{2, 3, 4},
-// expectedShape: []int{1, 1, 3},
-// expectedMask: []float32{
-// 0, 0, x,
-// x, 0, 0,
-// },
-// },
-// {
-// name: "First seq 1",
-// in: []float32{5, 6},
-// inShape: []int{1, 1, 2},
-// seqs: []int{1, 1},
-// pos: []int32{0, 1},
-// expected: []float32{5, 6},
-// expectedShape: []int{1, 1, 2},
-// expectedMask: []float32{
-// 0, x,
-// 0, 0,
-// },
-// },
-// {
-// name: "Second seq 1",
-// in: []float32{7, 8},
-// inShape: []int{1, 1, 2},
-// seqs: []int{1, 1},
-// pos: []int32{2, 3},
-// expected: []float32{6, 3, 4, 7, 8},
-// expectedShape: []int{1, 1, 5},
-// expectedMask: []float32{
-// 0, x, x, 0, x,
-// x, x, x, 0, 0,
-// },
-// },
-// {
-// name: "Third seq 0",
-// in: []float32{9, 10},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{4, 5},
-// expected: []float32{9, 10, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{
-// 0, x, x, 0,
-// 0, 0, x, x,
-// },
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func TestSWAMem(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewSWAMemCache(1, 3, nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// x := float32(math.Inf(-1))
-
-// tests := []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{1, 2, 3, 4},
-// inShape: []int{1, 1, 4},
-// seqs: []int{0, 0, 0, 0},
-// pos: []int32{0, 1, 2, 3},
-// expected: []float32{1, 2, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{
-// 0, x, x, x,
-// 0, 0, x, x,
-// x, 0, 0, x,
-// x, x, 0, 0,
-// },
-// },
-// {
-// name: "SecondBatch",
-// in: []float32{5, 6},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{4, 5},
-// expected: []float32{5, 2, 3, 4, 6},
-// expectedShape: []int{1, 1, 5},
-// expectedMask: []float32{
-// 0, x, x, 0, x,
-// 0, x, x, x, 0,
-// },
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func TestChunkedAttention(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewChunkedAttentionCache(2, nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// x := float32(math.Inf(-1))
-
-// testCache(
-// t, backend, cache,
-// []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{1, 2, 3, 4},
-// inShape: []int{1, 1, 4},
-// seqs: []int{0, 0, 0, 0},
-// pos: []int32{0, 1, 2, 3},
-// expected: []float32{1, 2, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{
-// 0, x, x, x,
-// 0, 0, x, x,
-// x, x, 0, x,
-// x, x, 0, 0,
-// },
-// },
-// {
-// name: "SecondBatch",
-// in: []float32{5, 6, 7},
-// inShape: []int{1, 1, 3},
-// seqs: []int{0, 0, 0},
-// pos: []int32{4, 5, 6},
-// expected: []float32{1, 2, 3, 4, 5, 6, 7},
-// expectedShape: []int{1, 1, 7},
-// expectedMask: []float32{
-// x, x, x, x, 0, x, x,
-// x, x, x, x, 0, 0, x,
-// x, x, x, x, x, x, 0,
-// },
-// },
-// {
-// name: "ThirdBatch",
-// in: []float32{8, 9},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{7, 8},
-// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
-// expectedShape: []int{1, 1, 9},
-// expectedMask: []float32{
-// x, x, x, x, x, x, 0, 0, x,
-// x, x, x, x, x, x, x, x, 0,
-// },
-// },
-// },
-// )
-// })
-// }
-
-// func TestSequences(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewCausalCache(nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// tests := []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{1, 2, 3, 4},
-// inShape: []int{1, 1, 4},
-// seqs: []int{0, 0, 1, 1},
-// pos: []int32{0, 1, 0, 1},
-// expected: []float32{1, 2, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
-// },
-// {
-// name: "SecondBatch",
-// in: []float32{5, 6},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 1},
-// pos: []int32{2, 2},
-// expected: []float32{1, 2, 3, 4, 5, 6},
-// expectedShape: []int{1, 1, 6},
-// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func TestRemove(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-// return key.Add(ctx, shift), nil
-// })
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// x := float32(math.Inf(-1))
-
-// tests := []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{1, 2, 3, 4},
-// inShape: []int{1, 1, 4},
-// seqs: []int{0, 0, 1, 1},
-// pos: []int32{0, 1, 0, 1},
-// expected: []float32{1, 2, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{
-// 0, x, x, x,
-// 0, 0, x, x,
-// x, x, 0, x,
-// x, x, 0, 0,
-// },
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-
-// err := cache.Remove(0, 1, math.MaxInt32)
-// if err != nil {
-// panic(err)
-// }
-
-// tests = []testCase{
-// {
-// name: "RemoveEnd",
-// in: []float32{5, 6},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 1},
-// pos: []int32{1, 2},
-// expected: []float32{1, 5, 3, 4, 6},
-// expectedShape: []int{1, 1, 5},
-// expectedMask: []float32{
-// 0, 0, x, x, x,
-// x, x, 0, 0, 0,
-// },
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-
-// err = cache.Remove(0, 0, 1)
-// if err != nil {
-// panic(err)
-// }
-
-// tests = []testCase{
-// {
-// name: "RemoveMiddle",
-// in: []float32{7, 8},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{1, 2},
-// expected: []float32{7, 4, 3, 4, 6, 8},
-// expectedShape: []int{1, 1, 6},
-// expectedMask: []float32{
-// 0, 0, x, x, x, x,
-// 0, 0, x, x, x, 0,
-// },
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func TestCopy(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// tests := []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{1, 2, 3, 4},
-// inShape: []int{1, 1, 4},
-// seqs: []int{0, 0, 0, 0},
-// pos: []int32{0, 1, 2, 3},
-// expected: []float32{1, 2, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-
-// cache.CopyPrefix(0, 1, 2)
-
-// tests = []testCase{
-// {
-// name: "Copy",
-// in: []float32{5, 6},
-// inShape: []int{1, 1, 2},
-// seqs: []int{1, 1},
-// pos: []int32{3, 4},
-// expected: []float32{1, 2, 3, 4, 5, 6},
-// expectedShape: []int{1, 1, 6},
-// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
-// for _, test := range tests {
-// t.Run(test.name, func(t *testing.T) {
-// context := backend.NewContext()
-// defer context.Close()
-
-// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
-// if err != nil {
-// panic(err)
-// }
-
-// cache.SetLayer(0)
-// tensor := context.FromFloats(test.in, test.inShape...)
-// cache.Put(context, tensor, tensor)
-
-// out, _, mask := cache.Get(context)
-
-// context.Forward(out, mask).Compute(out, mask)
-
-// if !slices.Equal(out.Floats(), test.expected) {
-// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
-// }
-
-// if !slices.Equal(out.Shape(), test.expectedShape) {
-// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
-// }
-
-// if !slices.Equal(mask.Floats(), test.expectedMask) {
-// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
-// }
-// })
-// }
-// }
-
-// func TestCanResume(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// windowSize := int32(4)
-// cache := NewSWACache(windowSize, nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// context := backend.NewContext()
-// defer context.Close()
-
-// err := cache.StartForward(context, input.Batch{
-// Positions: []int32{0, 1, 2, 3, 4},
-// Sequences: []int{0, 0, 0, 0, 0},
-// }, false)
-// if err != nil {
-// t.Fatalf("StartForward failed: %v", err)
-// }
-
-// cache.SetLayer(0)
-// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
-// cache.Put(context, tensor, tensor)
-
-// // with window size 4, nothing has slid out of the window yet
-// if !cache.CanResume(0, 0) {
-// t.Errorf("CanResume(0, 0) = false, want true (within window)")
-// }
-// if !cache.CanResume(0, 1) {
-// t.Errorf("CanResume(0, 1) = false, want true (within window)")
-// }
-// if !cache.CanResume(0, 2) {
-// t.Errorf("CanResume(0, 2) = false, want true (within window)")
-// }
-// if !cache.CanResume(0, 3) {
-// t.Errorf("CanResume(0, 3) = false, want true (latest position)")
-// }
-// if !cache.CanResume(0, 4) {
-// t.Errorf("CanResume(0, 4) = false, want true (latest position)")
-// }
-
-// // shift window by adding position 5
-// err = cache.StartForward(context, input.Batch{
-// Positions: []int32{5},
-// Sequences: []int{0},
-// }, false)
-// if err != nil {
-// t.Fatalf("StartForward failed: %v", err)
-// }
-
-// cache.SetLayer(0)
-// tensor = context.FromFloats([]float32{6}, 1, 1, 1)
-// cache.Put(context, tensor, tensor)
-
-// // only the latest position has overlapping windows
-// if cache.CanResume(0, 0) {
-// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 1) {
-// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 2) {
-// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 3) {
-// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 4) {
-// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
-// }
-// if !cache.CanResume(0, 5) {
-// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
-// }
-// })
-// }
-
-// func TestCanResumeSWAMem(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// windowSize := int32(4)
-// memSize := int32(5)
-// cache := NewSWAMemCache(windowSize, memSize, nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// context := backend.NewContext()
-// defer context.Close()
-
-// err := cache.StartForward(context, input.Batch{
-// Positions: []int32{0, 1, 2, 3, 4, 5, 6},
-// Sequences: []int{0, 0, 0, 0, 0, 0, 0},
-// }, false)
-// if err != nil {
-// t.Fatalf("StartForward failed: %v", err)
-// }
-
-// cache.SetLayer(0)
-// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
-// cache.Put(context, tensor, tensor)
-
-// // shift window by adding position 7
-// err = cache.StartForward(context, input.Batch{
-// Positions: []int32{7},
-// Sequences: []int{0},
-// }, false)
-// if err != nil {
-// t.Fatalf("StartForward failed: %v", err)
-// }
-
-// cache.SetLayer(0)
-// tensor = context.FromFloats([]float32{8}, 1, 1, 1)
-// cache.Put(context, tensor, tensor)
-
-// // only the latest position has overlapping windows
-// if cache.CanResume(0, 0) {
-// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 1) {
-// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 2) {
-// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 3) {
-// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 4) {
-// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 5) {
-// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
-// }
-// if !cache.CanResume(0, 6) {
-// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
-// }
-// if !cache.CanResume(0, 7) {
-// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
-// }
-// })
-// }
-
-// type testBackend struct {
-// ml.Backend
-// permutedV bool
-// }
-
-// func (b *testBackend) NewContext() ml.Context {
-// return &testContext{}
-// }
-
-// func (b *testBackend) NewContextSize(int) ml.Context {
-// return &testContext{}
-// }
-
-// func (b *testBackend) CacheConfig() ml.CacheConfig {
-// return ml.CacheConfig{PermutedV: b.permutedV}
-// }
-
-// type testContext struct {
-// ml.Context
-// }
-
-// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
-// total := 0
-
-// if len(shape) > 0 {
-// total = 1
-// for _, s := range shape {
-// total *= s
-// }
-// }
-
-// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
-// }
-
-// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
-// return c.Empty(dtype, shape...)
-// }
-
-// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
-// t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
-
-// copy(t.data, s)
-
-// return t
-// }
-
-// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
-// f := make([]float32, len(s))
-// for i := range f {
-// f[i] = float32(s[i])
-// }
-
-// out := c.FromFloats(f, shape...)
-// out.(*testTensor).dtype = ml.DTypeI32
-
-// return out
-// }
-
-// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
-// s := make([]float32, 0, int((stop-start)/step))
-// for i := start; i < stop; i += step {
-// s = append(s, i)
-// }
-
-// out := c.FromFloats(s, len(s))
-// out.(*testTensor).dtype = dtype
-// return out
-// }
-
-// func (c *testContext) Input() ml.Context { return c }
-// func (c *testContext) Layer(int) ml.Context { return c }
-
-// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
-
-// func (c *testContext) Compute(...ml.Tensor) {}
-
-// func (c *testContext) Reserve() {}
-
-// func (c *testContext) MaxGraphNodes() int {
-// return 10
-// }
-
-// func (c *testContext) Close() {}
-
-// type testTensor struct {
-// ml.Tensor
-
-// dtype ml.DType
-// elementSize int
-// data []float32
-// shape []int
-// }
-
-// func (t *testTensor) Dim(n int) int {
-// return t.shape[n]
-// }
-
-// func (t *testTensor) Stride(n int) int {
-// stride := t.elementSize
-// for i := range n {
-// stride *= t.shape[i]
-// }
-
-// return stride
-// }
-
-// func (t *testTensor) Shape() []int {
-// return t.shape
-// }
-
-// func (t *testTensor) DType() ml.DType {
-// return t.dtype
-// }
-
-// func (t *testTensor) Floats() []float32 {
-// out := make([]float32, len(t.data))
-// copy(out, t.data)
-// return out
-// }
-
-// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
-// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
-// for i := range out.data {
-// out.data[i] = -t.data[i]
-// }
-// return out
-// }
-
-// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
-
-// for i := range out.data {
-// out.data[i] = t.data[i] + t2.(*testTensor).data[i]
-// }
-
-// return out
-// }
-
-// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
-// return &testTensor{
-// dtype: t.dtype,
-// elementSize: t.elementSize,
-// data: t.data,
-// shape: shape,
-// }
-// }
-
-// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
-// offset /= t.elementSize
-
-// var s []int
-
-// switch len(shape) {
-// case 1:
-// s = []int{shape[0]}
-// case 3:
-// s = []int{shape[0], shape[2]}
-// case 5:
-// s = []int{shape[0], shape[2], shape[4]}
-// default:
-// panic("unsupported number of dimensions")
-// }
-
-// context := &testContext{}
-
-// view := context.Empty(t.dtype, s...).(*testTensor)
-// view.data = t.data[offset : offset+len(view.data)]
-
-// return view
-// }
-
-// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
-// if len(t.shape) > 4 || len(order) > 4 {
-// panic("permute only supports up to 4 dimensions")
-// }
-
-// if len(order) != len(t.shape) && len(order) != 4 {
-// panic("invalid number of dimensions for permute")
-// }
-
-// // ggml_permute expects 4 axes, so fill in any missing dimensions.
-// orderFull := append(make([]int, 0, 4), order...)
-// for len(orderFull) < 4 {
-// orderFull = append(orderFull, len(orderFull))
-// }
-
-// seen := [4]bool{}
-
-// shape4 := [4]int{1, 1, 1, 1}
-// for i := 0; i < len(t.shape) && i < 4; i++ {
-// shape4[i] = t.shape[i]
-// }
-
-// newShape4 := [4]int{1, 1, 1, 1}
-// for axis := range 4 {
-// dst := orderFull[axis]
-// if dst < 0 || dst >= 4 {
-// panic("invalid axis for permute")
-// }
-// if seen[dst] {
-// panic("duplicate axis for permute")
-// }
-// seen[dst] = true
-// newShape4[dst] = shape4[axis]
-// }
-
-// total := len(t.data)
-// newData := make([]float32, total)
-
-// if total > 0 {
-// oldDims := shape4
-// newDims := newShape4
-
-// oldStride := [4]int{1, 1, 1, 1}
-// newStride := [4]int{1, 1, 1, 1}
-// for i := 1; i < 4; i++ {
-// oldStride[i] = oldStride[i-1] * oldDims[i-1]
-// newStride[i] = newStride[i-1] * newDims[i-1]
-// }
-
-// var coords [4]int
-// var newCoords [4]int
-
-// for idx := range total {
-// remainder := idx
-// for axis := range 4 {
-// dim := oldDims[axis]
-// if dim == 0 {
-// coords[axis] = 0
-// continue
-// }
-// coords[axis] = remainder % dim
-// remainder /= dim
-// }
-
-// for axis := range 4 {
-// newCoords[orderFull[axis]] = coords[axis]
-// }
-
-// newIndex := 0
-// for axis := range 4 {
-// if newDims[axis] == 0 {
-// continue
-// }
-// newIndex += newCoords[axis] * newStride[axis]
-// }
-
-// newData[newIndex] = t.data[idx]
-// }
-// }
-
-// numDims := 4
-// for numDims > 1 && newShape4[numDims-1] <= 1 {
-// numDims--
-// }
-
-// newShape := make([]int, numDims)
-// copy(newShape, newShape4[:numDims])
-
-// return &testTensor{
-// dtype: t.dtype,
-// elementSize: t.elementSize,
-// data: newData,
-// shape: newShape,
-// }
-// }
-
-// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
-// dst := t
-// srcTensor := src.(*testTensor)
-// idxTensor := idxs.(*testTensor)
-
-// shapeTo4D := func(shape []int) [4]int {
-// out := [4]int{1, 1, 1, 1}
-// for i := 0; i < len(shape) && i < 4; i++ {
-// out[i] = shape[i]
-// }
-// return out
-// }
-
-// computeStrides := func(shape [4]int) [4]int {
-// out := [4]int{1, 1, 1, 1}
-// for i := 1; i < 4; i++ {
-// out[i] = out[i-1] * shape[i-1]
-// }
-// return out
-// }
-
-// dstShape4D := shapeTo4D(dst.shape)
-// srcShape4D := shapeTo4D(srcTensor.shape)
-// idxShape4D := shapeTo4D(idxTensor.shape)
-
-// if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
-// panic("SetRows requires matching tensor shapes")
-// }
-
-// if srcShape4D[1] != idxShape4D[0] {
-// panic("SetRows rows/index mismatch")
-// }
-
-// if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
-// panic("SetRows cannot broadcast indices")
-// }
-
-// if idxShape4D[3] != 1 {
-// panic("SetRows expects 1D or 2D index tensors")
-// }
-
-// dstStride := computeStrides(dstShape4D)
-// srcStride := computeStrides(srcShape4D)
-// idxStride := computeStrides(idxShape4D)
-
-// numColumns := srcShape4D[0]
-// numRows := srcShape4D[1]
-
-// for dim3Index := range dstShape4D[3] {
-// for dim2Index := range dstShape4D[2] {
-// idxDim2 := 0
-// idxDim3 := 0
-// if idxShape4D[1] > 0 {
-// idxDim2 = dim2Index % idxShape4D[1]
-// }
-// if idxShape4D[2] > 0 {
-// idxDim3 = dim3Index % idxShape4D[2]
-// }
-
-// idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
-// srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
-// dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
-
-// for row := range numRows {
-// idx := int(idxTensor.data[idxBase+row*idxStride[0]])
-// if idx < 0 || idx >= dstShape4D[1] {
-// panic("SetRows index out of range")
-// }
-
-// srcOffset := srcBase + row*srcStride[1]
-// dstOffset := dstBase + idx*dstStride[1]
-
-// copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
-// }
-// }
-// }
-
-// return dst
-// }
-
-// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-// copy(t2.(*testTensor).data, t.data)
-// return nil
-// }
diff --git a/x/kvcache/mlx.go b/x/kvcache/mlx.go
deleted file mode 100644
index fa3865104..000000000
--- a/x/kvcache/mlx.go
+++ /dev/null
@@ -1,144 +0,0 @@
-//go:build mlx
-
-package kvcache
-
-import (
- "github.com/ollama/ollama/x/ml"
- "github.com/ollama/ollama/x/model/input"
-)
-
-// Causal cache stores K and V tensors according to their position in the
-// sequence. Returns the history and a mask for attending to past tokens
-type MLXCausal struct {
- DType ml.DType
-
- // locations for data storage for this batch
- curLocPut ml.Tensor
-
- // locations for data storage for this batch
- curLocGet ml.Tensor
-
- // the active layer for Get and Put
- curLayer int
-
- capacity int
-
- offset int
-
- backend ml.Backend
- ctxs map[int]ml.Context
- keys, values map[int]ml.Tensor
-
- // TODO is this needed per layer, or will it always be consistent?
- kHeadDims, vHeadDims, numKVHeads map[int]int
-}
-
-func NewMLXCausalCache() *MLXCausal {
- return &MLXCausal{
- ctxs: make(map[int]ml.Context),
- keys: make(map[int]ml.Tensor),
- values: make(map[int]ml.Tensor),
- kHeadDims: make(map[int]int),
- vHeadDims: make(map[int]int),
- numKVHeads: make(map[int]int),
- }
-}
-
-func (c *MLXCausal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
- c.DType = dtype
- c.capacity = capacity
- c.backend = backend
-}
-
-func (c *MLXCausal) SetConfig(config ml.CacheConfig) {}
-
-func (c *MLXCausal) SetLayer(layer int) {
- c.curLayer = layer
-}
-
-func (c *MLXCausal) Close() {
- // slog.Info("XXX MLXCausal.Close called", "number of contexts", len(c.ctxs))
- for _, ctx := range c.ctxs {
- ctx.Close()
- }
-}
-
-func (c *MLXCausal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
- locsPut := make([]int32, len(batch.Positions))
- for i := c.offset; i < len(batch.Positions); i++ {
- locsPut[i-c.offset] = int32(i)
- }
- c.offset += len(batch.Positions)
- locsGet := make([]int32, c.offset)
- for i := range c.offset {
- locsGet[i] = int32(i)
- }
- c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
- c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
- // slog.Info("XXX MLXCausal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
-
- return nil
-}
-func (c *MLXCausal) Put(ctx ml.Context, key, value ml.Tensor) {
- kHeadDim := key.Dim(3)
- vHeadDim := value.Dim(3)
- numKVHeads := key.Dim(1)
- batchSize := key.Dim(2)
- kCellSize := kHeadDim * numKVHeads
- vCellSize := vHeadDim * numKVHeads
- // slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
-
- if _, ok := c.ctxs[c.curLayer]; !ok {
- // slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
- c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
- }
-
- if _, ok := c.keys[c.curLayer]; !ok {
- // slog.Info("XXX MLXCausal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
- c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
- c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
- c.kHeadDims[c.curLayer] = kHeadDim
- c.vHeadDims[c.curLayer] = vHeadDim
- c.numKVHeads[c.curLayer] = numKVHeads
- }
- key = key.Reshape(ctx, batchSize, 1, kCellSize)
-
- // slog.Info("XXX MLXCausal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
- // slog.Info("XXX MLXCausal.Put ", "c.curLocPut", c.curLocPut)
- // slog.Info("XXX MLXCausal.Put ", "key", key)
- ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
- value = value.Reshape(ctx, batchSize, 1, vCellSize)
- ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
-
-}
-
-func (c *MLXCausal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
- key := c.keys[c.curLayer]
- value := c.values[c.curLayer]
-
- kHeadDim := c.kHeadDims[c.curLayer]
- vHeadDim := c.vHeadDims[c.curLayer]
- numKVHeads := c.numKVHeads[c.curLayer]
- // rowSize := numKVHeads * c.curBatchSize
- // cachedSize := c.curMask.Dim(1)
- cachedSize := c.curLocGet.Dim(0)
- // kCellSize := kHeadDim * numKVHeads
- // vCellSize := vHeadDim * numKVHeads
- // slog.Info("XXX MLXCausal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
-
- key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
- value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
- return key, value, nil
-}
-
-func (c *MLXCausal) CopyPrefix(srcSeq, dstSeq int, len int32) {
- panic("not implemented")
-}
-
-func (c *MLXCausal) CanResume(seq int, pos int32) bool {
- panic("not implemented")
-}
-
-func (c *MLXCausal) Remove(seq int, beginIndex, endIndex int32) error {
- panic("not implemented")
-}
diff --git a/x/mlxrunner/imagegen.go b/x/mlxrunner/imagegen.go
new file mode 100644
index 000000000..b1cdd91df
--- /dev/null
+++ b/x/mlxrunner/imagegen.go
@@ -0,0 +1,134 @@
+//go:build mlx
+
+package mlxrunner
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/x/imagegen/mlx"
+ "github.com/ollama/ollama/x/imagegen/models/flux2"
+ "github.com/ollama/ollama/x/imagegen/models/zimage"
+)
+
+// ImageModel is the interface for image generation models; loadImageModel stores a *flux2.Model or *zimage.Model behind it.
+type ImageModel interface {
+	GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
+}
+
+var imageGenMu sync.Mutex // held by handleImageCompletion for the duration of each request
+
+// loadImageModel loads the image generation model named by s.modelName into s.imageModel.
+// It refuses to load when the manifest's tensor size exceeds the MLX memory limit.
+func (s *server) loadImageModel() error {
+	// Best-effort memory check: skipped when the manifest cannot be read or no limit is reported.
+	var requiredMemory uint64
+	if manifest, err := imagegen.LoadManifest(s.modelName); err == nil {
+		requiredMemory = uint64(manifest.TotalTensorSize())
+	}
+	availableMemory := mlx.GetMemoryLimit()
+	if availableMemory > 0 && requiredMemory > 0 && availableMemory < requiredMemory {
+		// Print fractional values: whole-number division could show the same figure on both sides.
+		const gb = float64(1024 * 1024 * 1024)
+		return fmt.Errorf("insufficient memory for image generation: need %.1f GB, have %.1f GB",
+			float64(requiredMemory)/gb, float64(availableMemory)/gb)
+	}
+
+	// Detect model type and load appropriate model
+	modelType := imagegen.DetectModelType(s.modelName)
+	slog.Info("detected image model type", "type", modelType)
+
+	switch modelType {
+	case "Flux2KleinPipeline":
+		m := &flux2.Model{}
+		if err := m.Load(s.modelName); err != nil {
+			return fmt.Errorf("failed to load flux2 model: %w", err)
+		}
+		s.imageModel = m
+	default:
+		// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
+		m := &zimage.Model{}
+		if err := m.Load(s.modelName); err != nil {
+			return fmt.Errorf("failed to load zimage model: %w", err)
+		}
+		s.imageModel = m
+	}
+	return nil
+}
+
+// handleImageCompletion streams an image generation request as newline-delimited JSON.
+// Progress messages carry Step/Total; the last message has Done set and holds either the
+// base64 image in Image or an "error..." description in Content.
+func (s *server) handleImageCompletion(w http.ResponseWriter, r *http.Request, req Request) {
+	if s.imageModel == nil {
+		http.Error(w, "image model not loaded", http.StatusInternalServerError)
+		return
+	}
+	// Verify streaming support before any streaming header is set.
+	flusher, ok := w.(http.Flusher)
+	if !ok {
+		http.Error(w, "streaming not supported", http.StatusInternalServerError)
+		return
+	}
+
+	// Serialize generation requests - MLX model may not handle concurrent generation
+	imageGenMu.Lock()
+	defer imageGenMu.Unlock()
+
+	// Set seed if not provided
+	if req.Seed <= 0 {
+		req.Seed = time.Now().UnixNano()
+	}
+
+	w.Header().Set("Content-Type", "application/x-ndjson")
+	w.Header().Set("Transfer-Encoding", "chunked")
+
+	ctx := r.Context()
+	enc := json.NewEncoder(w)
+	// send writes one NDJSON message; Encode already appends the newline separator.
+	send := func(resp Response) {
+		if err := enc.Encode(resp); err != nil {
+			slog.Debug("failed to write image response", "error", err)
+			return
+		}
+		flusher.Flush()
+	}
+
+	// Progress callback streams step updates
+	progress := func(step, total int) {
+		send(Response{Step: step, Total: total})
+	}
+
+	// Generate image
+	img, err := s.imageModel.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
+	if err != nil {
+		// Don't send error for cancellation
+		if ctx.Err() != nil {
+			return
+		}
+		send(Response{Content: fmt.Sprintf("error: %v", err), Done: true})
+		return
+	}
+
+	// Encode image as base64 PNG
+	imageData, err := imagegen.EncodeImageBase64(img)
+
+	// Free the generated image array and clean up MLX state, even when encoding failed
+	img.Free()
+	mlx.ClearCache()
+	mlx.MetalResetPeakMemory()
+
+	if err != nil {
+		send(Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true})
+		return
+	}
+
+	// Send final response with image data
+	send(Response{Image: imageData, Done: true})
+}
diff --git a/x/mlxrunner/llm.go b/x/mlxrunner/llm.go
new file mode 100644
index 000000000..865750573
--- /dev/null
+++ b/x/mlxrunner/llm.go
@@ -0,0 +1,420 @@
+//go:build mlx
+
+package mlxrunner
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/x/imagegen/cache"
+ "github.com/ollama/ollama/x/imagegen/mlx"
+ "github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
+ "github.com/ollama/ollama/x/imagegen/tokenizer"
+)
+
+// TextModel is the interface for LLM text generation models: a forward pass over per-layer caches plus tokenizer and size metadata.
+type TextModel interface {
+	Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
+	NewCache(maxSeqLen int32) []cache.Cache
+	Tokenizer() *tokenizer.Tokenizer
+	VocabSize() int32
+	MaxContextLength() int32
+	NumLayers() int
+}
+
+// llmState holds the state for LLM generation; loadLLMModel fills it with the loaded model.
+type llmState struct {
+	model TextModel
+}
+
+var llmMu sync.Mutex // held by handleLLMCompletion for the whole request, so one generation runs at a time
+
+// Dedicated stream for generation (like mlx-lm's generation_stream); created lazily by withStream
+var generationStream *mlx.Stream
+
+// withStream runs fn with the generation stream as default, restoring the previous default afterwards
+func withStream(fn func()) {
+	// Lazy initialization of generationStream (unsynchronized; the generation path calls this under llmMu)
+	if generationStream == nil {
+		generationStream = mlx.NewStream()
+	}
+	orig := mlx.GetDefaultStream()
+	mlx.SetDefaultStream(generationStream)
+	defer mlx.SetDefaultStream(orig) // restore the previous default even if fn panics
+	fn()
+}
+
+// Decoder wraps model + cache for autoregressive generation: prefill consumes the prompt, step yields one token per call.
+// This matches the pattern from cmd/engine/generate.go
+type Decoder struct {
+	model         TextModel
+	caches        []cache.Cache
+	vocabSize     int32
+	temp          float32
+	token         *mlx.Array   // Most recently sampled token; step returns its value on the next call
+	oldCacheState []*mlx.Array // Preallocated slice step reuses to snapshot cache state before each forward pass
+}
+
+// NewDecoder builds a Decoder around m with freshly allocated caches.
+// temp is the sampling temperature later handed to sample.
+func NewDecoder(m TextModel, temp float32) *Decoder {
+	d := &Decoder{model: m, temp: temp}
+	d.caches = m.NewCache(0)
+	d.vocabSize = m.VocabSize()
+	// Capacity of two entries per cache; step reuses this slice when snapshotting cache state.
+	d.oldCacheState = make([]*mlx.Array, 0, 2*len(d.caches))
+	return d
+}
+
+// prefill feeds the prompt tokens through the model to populate the caches, processing all
+// but the last token in chunks of up to 2048, then samples the first output token into
+// d.token from the final step. It returns the number of prompt tokens consumed.
+func (d *Decoder) prefill(inputIDs []int32) int {
+	processed := 0
+
+	// Track old cache state to free after each chunk
+	var oldCacheState []*mlx.Array
+
+	// Process all-but-1 tokens in chunks, eval cache state for memory management
+	for len(inputIDs) > 1 {
+		chunkSize := min(2048, len(inputIDs)-1) // always >= 1 because len(inputIDs) > 1
+		chunk := inputIDs[:chunkSize]
+
+		// Save old cache state before forward
+		oldCacheState = oldCacheState[:0]
+		for _, c := range d.caches {
+			oldCacheState = append(oldCacheState, c.State()...)
+		}
+
+		var cacheState []*mlx.Array
+		withStream(func() {
+			x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
+			d.model.Forward(x, d.caches)
+			for _, c := range d.caches {
+				cacheState = append(cacheState, c.State()...)
+			}
+		})
+		mlx.Eval(cacheState...)
+
+		// Free old cache state
+		for _, arr := range oldCacheState {
+			if arr != nil {
+				arr.Free()
+			}
+		}
+
+		inputIDs = inputIDs[chunkSize:]
+		processed += chunkSize
+	}
+
+	// Save old cache state before final step
+	oldCacheState = oldCacheState[:0]
+	for _, c := range d.caches {
+		oldCacheState = append(oldCacheState, c.State()...)
+	}
+
+	// Final token + sampling
+	withStream(func() {
+		x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
+		mlx.Eval(x) // Materialize before any other evals
+		logits := d.model.Forward(x, d.caches)
+		d.token = sample(logits, d.temp, d.vocabSize)
+	})
+	// Keep cache state (token auto-kept by AsyncEval)
+	for _, c := range d.caches {
+		mlx.Keep(c.State()...)
+	}
+	mlx.AsyncEval(d.token)
+
+	// Free old cache state from before final step
+	for _, arr := range oldCacheState {
+		if arr != nil {
+			arr.Free()
+		}
+	}
+
+	mlx.ClearCache()
+
+	return processed + len(inputIDs)
+}
+
+// step returns the previously sampled token and queues evaluation of the next one.
+func (d *Decoder) step() int32 {
+	prevToken := d.token
+	// Save old cache state (reuse preallocated slice)
+	d.oldCacheState = d.oldCacheState[:0]
+	for _, c := range d.caches {
+		d.oldCacheState = append(d.oldCacheState, c.State()...)
+	}
+
+	withStream(func() {
+		logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
+		d.token = sample(logits, d.temp, d.vocabSize)
+	})
+	// Keep token and new cache state so they survive cleanup
+	mlx.Keep(d.token)
+	for _, c := range d.caches {
+		mlx.Keep(c.State()...)
+	}
+	mlx.AsyncEval(d.token)
+	// Sync on previous token (GPU already working on next step)
+	val := prevToken.ItemInt32()
+	// Free old token and old cache state; nil entries are skipped, as in prefill
+	prevToken.Free()
+	for _, arr := range d.oldCacheState {
+		if arr != nil {
+			arr.Free()
+		}
+	}
+	return val
+}
+
+// sample samples from logits using temperature scaling
+func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
+ // Get last position logits: [1, L, vocab] -> [vocab]
+ shape := logits.Shape()
+ seqLen := shape[1]
+ lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocabSize})
+ lastLogits = mlx.Reshape(lastLogits, vocabSize)
+
+ if temp <= 0 || temp < 0.01 {
+ // Greedy decoding
+ return mlx.Argmax(lastLogits, -1, false)
+ }
+
+ // Apply temperature scaling
+ scaled := mlx.DivScalar(lastLogits, temp)
+ return mlx.RandomCategorical(scaled, -1, 1)
+}
+
+// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
+// The implementation is chosen from config.json (first "architectures" entry, else
+// "model_type"); on success the model is stored in s.llmModel.
+func (s *server) loadLLMModel() error {
+	// Load the manifest to get model information
+	manifest, err := imagegen.LoadManifest(s.modelName)
+	if err != nil {
+		return fmt.Errorf("failed to load manifest: %w", err)
+	}
+
+	// config.json names the architecture
+	raw, err := manifest.ReadConfig("config.json")
+	if err != nil {
+		return fmt.Errorf("failed to read config.json: %w", err)
+	}
+
+	var cfg struct {
+		Architectures []string `json:"architectures"`
+		ModelType     string   `json:"model_type"`
+	}
+	if err := json.Unmarshal(raw, &cfg); err != nil {
+		return fmt.Errorf("failed to parse config.json: %w", err)
+	}
+
+	// Start from model_type and let a non-empty first architecture override it
+	arch := cfg.ModelType
+	if len(cfg.Architectures) > 0 && cfg.Architectures[0] != "" {
+		arch = cfg.Architectures[0]
+	}
+
+	slog.Info("detected LLM architecture", "architecture", arch, "model_type", cfg.ModelType)
+
+	// Architecture names are matched case-insensitively by substring
+	var model TextModel
+	switch lowered := strings.ToLower(arch); {
+	case strings.Contains(lowered, "glm4moelite"):
+		glm, err := glm4_moe_lite.LoadFromManifest(manifest)
+		if err != nil {
+			return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
+		}
+		slog.Info("loaded glm4-moe-lite model", "vocab_size", glm.VocabSize(), "layers", glm.NumLayers())
+		model = glm
+
+	default:
+		// Anything else has no implementation in this runner yet
+		return fmt.Errorf("LLM architecture %q is not yet supported. "+
+			"Supported architectures: glm4-moe-lite. "+
+			"Please convert your model to GGUF format or use a supported architecture", arch)
+	}
+
+	// Publish the loaded model for the completion handler
+	s.llmModel = &llmState{
+		model: model,
+	}
+
+	return nil
+}
+
+// handleLLMCompletion handles LLM text generation requests, one at a time.
+func (s *server) handleLLMCompletion(w http.ResponseWriter, r *http.Request, req Request) {
+	if s.llmModel == nil {
+		http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
+		return
+	}
+	if _, ok := w.(http.Flusher); !ok { // checked here so the client gets an HTTP error instead of an empty 200
+		http.Error(w, "streaming not supported", http.StatusInternalServerError)
+		return
+	}
+	llmMu.Lock() // Serialize generation requests
+	defer llmMu.Unlock()
+	if err := s.llmGenerate(w, r, req); err != nil {
+		slog.Error("LLM generation failed", "error", err) // the response may already be streaming, so only log
+	}
+}
+
+// llmGenerate runs the generation loop using the Decoder pattern from cmd/engine.
+// It tokenizes req.Prompt, prefills a fresh Decoder, and streams one NDJSON Response per
+// decoded token, followed by a final Done message carrying the stop reason and timings.
+// Generation ends at EOS, at a stop sequence, on request cancellation, or after maxTokens.
+// The only error returned is for a ResponseWriter that cannot stream.
+func (s *server) llmGenerate(w http.ResponseWriter, r *http.Request, req Request) error {
+	state := s.llmModel
+
+	// Set up streaming response
+	w.Header().Set("Content-Type", "application/x-ndjson")
+	w.Header().Set("Transfer-Encoding", "chunked")
+	flusher, ok := w.(http.Flusher)
+	if !ok {
+		return errors.New("streaming not supported")
+	}
+
+	tok := state.model.Tokenizer()
+
+	// The prompt is already formatted by the server using the model's renderer
+	// (see server/prompt.go renderPrompt), so we don't apply FormatPrompt here.
+	prompt := req.Prompt
+
+	// Tokenize the prompt
+	inputIDs := tok.Encode(prompt, true)
+	slog.Debug("tokenized prompt", "num_tokens", len(inputIDs))
+
+	// Generation parameters
+	maxTokens := int(state.model.MaxContextLength())
+	if maxTokens <= 0 {
+		maxTokens = 4096
+	}
+	if req.Options != nil && req.Options.NumPredict > 0 {
+		maxTokens = req.Options.NumPredict
+	}
+
+	temperature := float32(0.7)
+	if req.Options != nil && req.Options.Temperature > 0 {
+		temperature = float32(req.Options.Temperature)
+	}
+
+	// Enable MLX compilation for better performance
+	mlx.EnableCompile()
+
+	// Create decoder with fresh caches
+	dec := NewDecoder(state.model, temperature)
+
+	prefillStart := time.Now()
+	prefillTokens := dec.prefill(inputIDs)
+	// Prefill measurement includes time to first token
+	firstToken := dec.step()
+	prefillDuration := time.Since(prefillStart)
+	promptEvalDuration := prefillDuration
+
+	enc := json.NewEncoder(w)
+	ctx := r.Context()
+	generated := 0
+	stopReason := "max_tokens"
+
+	// Handle first token
+	generated++
+	if tok.IsEOS(firstToken) {
+		resp := Response{
+			Done:               true,
+			StopReason:         fmt.Sprintf("first_token_eos:%d", firstToken),
+			PromptEvalCount:    prefillTokens,
+			PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
+		}
+		enc.Encode(resp)
+		flusher.Flush()
+		mlx.ClearCache() // same cleanup as the normal exit path
+		return nil
+	}
+
+	text := tok.Decode([]int32{firstToken})
+	resp := Response{Content: text}
+	enc.Encode(resp)
+	flusher.Flush()
+
+	genStart := time.Now()
+
+	// Generation loop
+	for n := 1; n < maxTokens; n++ {
+		// Check for cancellation before spending another decode step
+		if ctx.Err() != nil {
+			stopReason = fmt.Sprintf("context_cancelled:%d", generated)
+			break
+		}
+
+		token := dec.step()
+		generated++
+
+		if tok.IsEOS(token) {
+			stopReason = fmt.Sprintf("eos_token:%d", token)
+			break
+		}
+
+		text := tok.Decode([]int32{token})
+
+		// Check for stop sequences within this token's text; empty stop strings are ignored
+		if req.Options != nil && len(req.Options.Stop) > 0 {
+			shouldStop := false
+			var matchedStop string
+			for _, stop := range req.Options.Stop {
+				if before, _, found := strings.Cut(text, stop); found && stop != "" {
+					text = before
+					shouldStop = true
+					matchedStop = stop
+					break
+				}
+			}
+			if shouldStop {
+				if text != "" {
+					resp := Response{Content: text}
+					enc.Encode(resp)
+					flusher.Flush()
+				}
+				stopReason = fmt.Sprintf("stop_sequence:%s", matchedStop)
+				break
+			}
+		}
+
+		resp := Response{Content: text}
+		enc.Encode(resp)
+		flusher.Flush()
+
+		// Periodically clear MLX cache
+		if n%256 == 0 {
+			mlx.ClearCache()
+		}
+	}
+
+	// Clean up
+	mlx.ClearCache()
+
+	// Send final response with stats
+	evalDuration := time.Since(genStart)
+	resp = Response{
+		Done:               true,
+		StopReason:         fmt.Sprintf("%s:generated=%d", stopReason, generated),
+		PromptEvalCount:    prefillTokens,
+		PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
+		EvalCount:          generated,
+		EvalDuration:       int(evalDuration.Nanoseconds()),
+	}
+	enc.Encode(resp)
+	flusher.Flush()
+
+	return nil
+}
diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go
new file mode 100644
index 000000000..df0b7eafa
--- /dev/null
+++ b/x/mlxrunner/runner.go
@@ -0,0 +1,204 @@
+//go:build mlx
+
+// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
+package mlxrunner
+
+import (
+ "context"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "os"
+ "os/signal"
+ "syscall"
+ "time"
+
+ "github.com/ollama/ollama/envconfig"
+ "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/x/imagegen/mlx"
+)
+
+// Execute is the entry point for the unified MLX runner subprocess.
+//
+// args are the command-line arguments for the runner; --model and --port are
+// both required. Execute initializes MLX, loads the model in the mode chosen
+// by detectModelMode, and then serves HTTP on 127.0.0.1:<port> until SIGINT or
+// SIGTERM is received.
+//
+// It returns nil after a signal-triggered shutdown. It returns a non-nil error
+// when a required flag is missing, MLX cannot be initialized, the model fails
+// to load, or the listener stops for any reason other than a shutdown.
+func Execute(args []string) error {
+	// Set up logging with appropriate level from environment
+	slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: envconfig.LogLevel()})))
+
+	// With flag.ExitOnError a malformed flag terminates the process inside
+	// Parse, so the error branch below is defensive only.
+	fs := flag.NewFlagSet("mlx-runner", flag.ExitOnError)
+	modelName := fs.String("model", "", "path to model")
+	port := fs.Int("port", 0, "port to listen on")
+
+	if err := fs.Parse(args); err != nil {
+		return err
+	}
+
+	if *modelName == "" {
+		return fmt.Errorf("--model is required")
+	}
+	if *port == 0 {
+		return fmt.Errorf("--port is required")
+	}
+
+	// Initialize MLX
+	if err := mlx.InitMLX(); err != nil {
+		slog.Error("unable to initialize MLX", "error", err)
+		return err
+	}
+	slog.Info("MLX library initialized")
+
+	// Detect model type from capabilities
+	mode := detectModelMode(*modelName)
+	slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)
+
+	// Create and start server. The model is loaded here, before the listener
+	// exists, so /health cannot answer until loading has finished.
+	server, err := newServer(*modelName, *port, mode)
+	if err != nil {
+		return fmt.Errorf("failed to create server: %w", err)
+	}
+
+	// Set up HTTP handlers
+	mux := http.NewServeMux()
+	mux.HandleFunc("/health", server.healthHandler)
+	mux.HandleFunc("/completion", server.completionHandler)
+
+	// LLM-specific endpoints
+	if mode == ModeLLM {
+		mux.HandleFunc("/tokenize", server.tokenizeHandler)
+		mux.HandleFunc("/embedding", server.embeddingHandler)
+	}
+
+	httpServer := &http.Server{
+		Addr:    fmt.Sprintf("127.0.0.1:%d", *port),
+		Handler: mux,
+	}
+
+	// Register for signals before the listener starts, so a signal that
+	// arrives before the goroutine below is scheduled still triggers a
+	// graceful shutdown instead of the default process termination.
+	sigCh := make(chan os.Signal, 1)
+	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
+
+	// done is closed once the graceful shutdown attempt has finished.
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		<-sigCh
+		slog.Info("shutting down mlx runner")
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		if err := httpServer.Shutdown(ctx); err != nil {
+			// Usually the 5s deadline expired with requests still in flight.
+			slog.Warn("mlx runner shutdown did not complete cleanly", "error", err)
+		}
+	}()
+
+	slog.Info("mlx runner listening", "addr", httpServer.Addr)
+	if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
+		return err
+	}
+
+	// ListenAndServe returns as soon as Shutdown is called; wait for the
+	// shutdown itself to finish before returning.
+	<-done
+	return nil
+}
+
+// detectModelMode reports which mode the runner should use for modelName.
+//
+// The type string returned by imagegen.DetectModelType is matched against the
+// known image generation pipelines. Any other value, including the empty
+// string, selects LLM mode.
+func detectModelMode(modelName string) ModelMode {
+	switch imagegen.DetectModelType(modelName) {
+	case "ZImagePipeline", "FluxPipeline", "Flux2KleinPipeline":
+		// Known image generation model types
+		return ModeImageGen
+	default:
+		// Everything else is treated as a safetensors LLM.
+		return ModeLLM
+	}
+}
+
+// server holds the model and handles HTTP requests.
+//
+// Exactly one of imageModel or llmModel is populated by newServer, depending
+// on mode.
+type server struct {
+ // mode selects which model field is loaded and which completion path
+ // completionHandler dispatches to.
+ mode ModelMode
+ // modelName is the model identifier passed to newServer.
+ modelName string
+ // port is the port number passed to newServer.
+ port int
+
+ // Image generation model (when mode == ModeImageGen)
+ imageModel ImageModel
+
+ // LLM model (when mode == ModeLLM)
+ llmModel *llmState
+}
+
+// newServer creates a new server instance and loads the appropriate model.
+//
+// mode decides which loader runs: loadImageModel for ModeImageGen and
+// loadLLMModel for ModeLLM. A loader failure is wrapped and returned, as is
+// any mode that is not one of the declared ModelMode constants, so a non-nil
+// *server always has a model loaded.
+func newServer(modelName string, port int, mode ModelMode) (*server, error) {
+	s := &server{
+		mode:      mode,
+		modelName: modelName,
+		port:      port,
+	}
+
+	switch mode {
+	case ModeImageGen:
+		if err := s.loadImageModel(); err != nil {
+			return nil, fmt.Errorf("failed to load image model: %w", err)
+		}
+	case ModeLLM:
+		if err := s.loadLLMModel(); err != nil {
+			return nil, fmt.Errorf("failed to load LLM model: %w", err)
+		}
+	default:
+		// Without this branch an unrecognised mode would yield a server with
+		// no model behind it and a nil error.
+		return nil, fmt.Errorf("unsupported model mode: %d", int(mode))
+	}
+
+	return s, nil
+}
+
+// healthHandler replies with a JSON-encoded HealthResponse whose status is
+// "ok". It does not inspect the request.
+func (s *server) healthHandler(w http.ResponseWriter, r *http.Request) {
+	w.Header().Set("Content-Type", "application/json")
+	json.NewEncoder(w).Encode(HealthResponse{Status: "ok"})
+}
+
+// completionHandler serves POST /completion.
+//
+// The JSON body is decoded into a Request and handed to the image or LLM
+// completion path according to s.mode. Non-POST requests get 405, an
+// undecodable body gets 400, and an unrecognised mode gets 500.
+func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodPost {
+		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+		return
+	}
+
+	var req Request
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+
+	switch s.mode {
+	case ModeImageGen:
+		s.handleImageCompletion(w, r, req)
+	case ModeLLM:
+		s.handleLLMCompletion(w, r, req)
+	default:
+		// Without this branch the client would receive an empty 200 response
+		// and have no way to tell that nothing was generated.
+		http.Error(w, "unsupported model mode", http.StatusInternalServerError)
+	}
+}
+
+// tokenizeHandler serves /tokenize.
+//
+// It decodes a JSON body of the form {"content": "..."}, encodes the content
+// with the loaded LLM's tokenizer, and replies with {"tokens": [...]}. It
+// answers 500 when no LLM is loaded and 400 when the body cannot be decoded.
+func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) {
+	if s.llmModel == nil {
+		http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
+		return
+	}
+
+	var payload struct {
+		Content string `json:"content"`
+	}
+	if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+
+	ids := s.llmModel.model.Tokenizer().Encode(payload.Content, false)
+
+	// Widen each token ID to int for the JSON response. The slice is created
+	// non-nil so that empty input is serialised as [] rather than null.
+	out := make([]int, 0, len(ids))
+	for _, id := range ids {
+		out = append(out, int(id))
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	json.NewEncoder(w).Encode(map[string][]int{"tokens": out})
+}
+
+func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) {
+ http.Error(w, "embeddings not yet implemented for MLX models", http.StatusNotImplemented)
+}
diff --git a/x/imagegen/runner/runner_stub.go b/x/mlxrunner/runner_stub.go
similarity index 60%
rename from x/imagegen/runner/runner_stub.go
rename to x/mlxrunner/runner_stub.go
index eafad7bca..3b0f35500 100644
--- a/x/imagegen/runner/runner_stub.go
+++ b/x/mlxrunner/runner_stub.go
@@ -1,10 +1,10 @@
//go:build !mlx
-package runner
+package mlxrunner
import "errors"
// Execute returns an error when not built with MLX support.
func Execute(args []string) error {
- return errors.New("image generation not available: build with mlx tag")
+ return errors.New("MLX runner not available: build with mlx tag")
}
diff --git a/x/imagegen/server.go b/x/mlxrunner/server.go
similarity index 58%
rename from x/imagegen/server.go
rename to x/mlxrunner/server.go
index ca9367694..89dd0bf04 100644
--- a/x/imagegen/server.go
+++ b/x/mlxrunner/server.go
@@ -1,4 +1,4 @@
-package imagegen
+package mlxrunner
import (
"bufio"
@@ -23,19 +23,19 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/x/imagegen"
)
-// Server wraps an image generation subprocess to implement llm.LlamaServer.
+// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
//
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
-// like any other model. The plan is to eventually bring this into the llm/ package
-// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
-// separate allows for independent iteration on image generation support.
+// like any other model. It supports both LLM (safetensors) and image generation models.
type Server struct {
mu sync.Mutex
cmd *exec.Cmd
port int
modelName string
+ mode ModelMode
vramSize uint64
done chan error
client *http.Client
@@ -43,10 +43,10 @@ type Server struct {
lastErrLock sync.Mutex
}
-// NewServer spawns a new image generation subprocess and waits until it's ready.
-func NewServer(modelName string) (*Server, error) {
+// NewServer spawns a new MLX runner subprocess and waits until it's ready.
+func NewServer(modelName string, mode ModelMode) (*Server, error) {
// Validate platform support before attempting to start
- if err := CheckPlatformSupport(); err != nil {
+ if err := imagegen.CheckPlatformSupport(); err != nil {
return nil, err
}
@@ -71,8 +71,8 @@ func NewServer(modelName string) (*Server, error) {
exe = eval
}
- // Spawn subprocess: ollama runner --image-engine --model --port
- cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
+ // Spawn subprocess: ollama runner --mlx-engine --model --port
+ cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
@@ -105,17 +105,21 @@ func NewServer(modelName string) (*Server, error) {
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
- // Get total weight size from manifest
- var weightSize uint64
- if manifest, err := LoadManifest(modelName); err == nil {
- weightSize = uint64(manifest.TotalTensorSize())
+ // Estimate VRAM based on tensor size from manifest
+ var vramSize uint64
+ if manifest, err := imagegen.LoadManifest(modelName); err == nil {
+ vramSize = uint64(manifest.TotalTensorSize())
+ } else {
+ // Fallback: default to 8GB if manifest can't be loaded
+ vramSize = 8 * 1024 * 1024 * 1024
}
s := &Server{
cmd: cmd,
port: port,
modelName: modelName,
- vramSize: weightSize,
+ mode: mode,
+ vramSize: vramSize,
done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute},
}
@@ -126,23 +130,23 @@ func NewServer(modelName string) (*Server, error) {
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
- slog.Info("image-runner", "msg", scanner.Text())
+ slog.Info("mlx-runner", "msg", scanner.Text())
}
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
- slog.Warn("image-runner", "msg", line)
+ slog.Warn("mlx-runner", "msg", line)
s.lastErrLock.Lock()
s.lastErr = line
s.lastErrLock.Unlock()
}
}()
- slog.Info("starting image runner subprocess", "exe", exe, "model", modelName, "port", port)
+ slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port, "mode", mode)
if err := cmd.Start(); err != nil {
- return nil, fmt.Errorf("failed to start image runner: %w", err)
+ return nil, fmt.Errorf("failed to start mlx runner: %w", err)
}
// Reap subprocess when it exits
@@ -165,6 +169,7 @@ func (s *Server) ModelPath() string {
return s.modelName
}
+// Load satisfies the LlamaServer interface. MLX models don't need GPU layer assignment.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil
}
@@ -200,18 +205,18 @@ func (s *Server) waitUntilRunning() error {
// Include recent stderr lines for better error context
errMsg := s.getLastErr()
if errMsg != "" {
- return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
+ return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
}
- return fmt.Errorf("image runner exited unexpectedly: %w", err)
+ return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
case <-timeout:
errMsg := s.getLastErr()
if errMsg != "" {
- return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
+ return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
}
- return errors.New("timeout waiting for image runner to start")
+ return errors.New("timeout waiting for mlx runner to start")
case <-ticker.C:
if err := s.Ping(ctx); err == nil {
- slog.Info("image runner is ready", "port", s.port)
+ slog.Info("mlx runner is ready", "port", s.port)
return nil
}
}
@@ -225,8 +230,12 @@ func (s *Server) getLastErr() string {
return s.lastErr
}
-func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
+// WaitUntilRunning satisfies the LlamaServer interface.
+func (s *Server) WaitUntilRunning(ctx context.Context) error {
+ return nil
+}
+// Completion handles both text and image generation requests.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
seed := req.Seed
if seed == 0 {
@@ -240,22 +249,26 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
}
// Build request for subprocess
- creq := struct {
- Prompt string `json:"prompt"`
- Width int32 `json:"width,omitempty"`
- Height int32 `json:"height,omitempty"`
- Steps int32 `json:"steps,omitempty"`
- Seed int64 `json:"seed,omitempty"`
- Images [][]byte `json:"images,omitempty"`
- }{
+ creq := Request{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
- Steps: req.Steps,
+ Steps: int(req.Steps),
Seed: seed,
Images: images,
}
+ // Pass LLM options if present
+ if req.Options != nil {
+ creq.Options = &RequestOptions{
+ NumPredict: req.Options.NumPredict,
+ Temperature: float64(req.Options.Temperature),
+ TopP: float64(req.Options.TopP),
+ TopK: req.Options.TopK,
+ Stop: req.Options.Stop,
+ }
+ }
+
body, err := json.Marshal(creq)
if err != nil {
return err
@@ -282,25 +295,40 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
- // Parse subprocess response (has singular "image" field)
+ // Parse subprocess response
var raw struct {
- Image string `json:"image,omitempty"`
- Content string `json:"content,omitempty"`
- Done bool `json:"done"`
- Step int `json:"step,omitempty"`
- Total int `json:"total,omitempty"`
+ Image string `json:"image,omitempty"`
+ Content string `json:"content,omitempty"`
+ Done bool `json:"done"`
+ Step int `json:"step,omitempty"`
+ Total int `json:"total,omitempty"`
+ StopReason string `json:"stop_reason,omitempty"`
+ PromptEvalCount int `json:"prompt_eval_count,omitempty"`
+ PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
+ EvalCount int `json:"eval_count,omitempty"`
+ EvalDuration int `json:"eval_duration,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
+ slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue
}
+ // Log stop reason when generation completes
+ if raw.Done && raw.StopReason != "" {
+ slog.Info("mlx generation completed", "stop_reason", raw.StopReason)
+ }
+
// Convert to llm.CompletionResponse
cresp := llm.CompletionResponse{
- Content: raw.Content,
- Done: raw.Done,
- Step: raw.Step,
- TotalSteps: raw.Total,
- Image: raw.Image,
+ Content: raw.Content,
+ Done: raw.Done,
+ Step: raw.Step,
+ TotalSteps: raw.Total,
+ Image: raw.Image,
+ PromptEvalCount: raw.PromptEvalCount,
+ PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
+ EvalCount: raw.EvalCount,
+ EvalDuration: time.Duration(raw.EvalDuration),
}
fn(cresp)
@@ -309,7 +337,20 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
}
}
- return scanner.Err()
+ // Scanner exited without receiving Done - connection was likely closed
+ scanErr := scanner.Err()
+ if scanErr != nil {
+ slog.Error("mlx scanner error", "error", scanErr)
+ } else {
+ slog.Warn("mlx scanner EOF without Done response - subprocess may have crashed")
+ }
+
+ // Check if subprocess is still alive
+ if s.HasExited() {
+ slog.Error("mlx subprocess has exited unexpectedly")
+ }
+
+ return scanErr
}
// Close terminates the subprocess.
@@ -318,7 +359,7 @@ func (s *Server) Close() error {
defer s.mu.Unlock()
if s.cmd != nil && s.cmd.Process != nil {
- slog.Info("stopping image runner subprocess", "pid", s.cmd.Process.Pid)
+ slog.Info("stopping mlx runner subprocess", "pid", s.cmd.Process.Pid)
s.cmd.Process.Signal(os.Interrupt)
// Wait briefly for graceful shutdown
@@ -347,23 +388,56 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize
}
-// Context length is not applicable for image generation.
+// ContextLength returns the context length (not applicable for image generation).
func (s *Server) ContextLength() int {
return 0
}
+// Embedding returns embeddings for the input.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
- return nil, 0, errors.New("not supported")
+ return nil, 0, errors.New("embeddings not supported for MLX models")
}
+// Tokenize tokenizes the input content.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
- return nil, errors.New("not supported")
+ body, err := json.Marshal(map[string]string{"content": content})
+ if err != nil {
+ return nil, err
+ }
+
+ url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port)
+ req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ resp, err := s.client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("tokenize failed: %d", resp.StatusCode)
+ }
+
+ var result struct {
+ Tokens []int `json:"tokens"`
+ }
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return nil, err
+ }
+
+ return result.Tokens, nil
}
+// Detokenize converts tokens back to text.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
- return "", errors.New("not supported")
+ return "", errors.New("detokenization not supported for MLX models")
}
+// Pid returns the process ID of the subprocess.
func (s *Server) Pid() int {
s.mu.Lock()
defer s.mu.Unlock()
@@ -373,9 +447,17 @@ func (s *Server) Pid() int {
return -1
}
-func (s *Server) GetPort() int { return s.port }
-func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
+// GetPort returns the port the subprocess is listening on.
+func (s *Server) GetPort() int {
+ return s.port
+}
+// GetDeviceInfos returns device information.
+func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
+ return nil
+}
+
+// HasExited returns whether the subprocess has exited.
func (s *Server) HasExited() bool {
select {
case <-s.done:
diff --git a/x/mlxrunner/types.go b/x/mlxrunner/types.go
new file mode 100644
index 000000000..cd22d3941
--- /dev/null
+++ b/x/mlxrunner/types.go
@@ -0,0 +1,81 @@
+// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
+//
+// This package handles safetensors models created with `ollama create --experimental`,
+// supporting both text generation (LLM) and image generation (diffusion) models
+// through a single unified interface.
+package mlxrunner
+
+// Request is the request format for completion requests.
+//
+// It is the JSON body decoded by the runner's /completion handler and is
+// shared by both modes. Every field except Prompt is tagged omitempty, so
+// fields left at their zero value are dropped from the encoded JSON.
+type Request struct {
+ // Prompt is the input text; it is the only field that is always encoded.
+ Prompt string `json:"prompt"`
+
+ // LLM-specific fields
+ // Options is nil when the caller supplied no generation options.
+ Options *RequestOptions `json:"options,omitempty"`
+
+ // Image generation fields
+ Width int32 `json:"width,omitempty"`
+ Height int32 `json:"height,omitempty"`
+ Steps int `json:"steps,omitempty"`
+ Seed int64 `json:"seed,omitempty"`
+ Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
+}
+
+// RequestOptions contains LLM-specific generation options.
+//
+// All fields are tagged omitempty, so a zero value is indistinguishable on the
+// wire from the option not being set.
+type RequestOptions struct {
+ NumPredict int `json:"num_predict,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ // Stop lists the stop sequences; generation ends when decoded text
+ // contains one of them.
+ Stop []string `json:"stop,omitempty"`
+}
+
+// Response is streamed back for each progress update.
+//
+// The runner writes one JSON-encoded Response per update. The message that
+// ends a stream has Done set and carries the statistics fields.
+type Response struct {
+ // Text generation response
+ Content string `json:"content,omitempty"`
+
+ // Image generation response
+ Image string `json:"image,omitempty"` // Base64-encoded PNG
+
+ // Common fields
+ Done bool `json:"done"`
+ // NOTE(review): the meaning of DoneReason values is not defined in this
+ // file — confirm against the code that sets it.
+ DoneReason int `json:"done_reason,omitempty"`
+ StopReason string `json:"stop_reason,omitempty"` // Debug: why generation stopped
+
+ // Progress fields
+ Step int `json:"step,omitempty"`
+ Total int `json:"total,omitempty"`
+
+ // Statistics
+ // The two duration fields are whole nanoseconds: the LLM path fills them
+ // from time.Duration.Nanoseconds().
+ PromptEvalCount int `json:"prompt_eval_count,omitempty"`
+ PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
+ EvalCount int `json:"eval_count,omitempty"`
+ EvalDuration int `json:"eval_duration,omitempty"`
+}
+
+// HealthResponse is returned by the health endpoint.
+type HealthResponse struct {
+ // Status is the health state; the runner's health handler sets it to "ok".
+ Status string `json:"status"`
+ // Progress is omitted from the JSON when zero.
+ // NOTE(review): no visible code sets Progress — confirm whether it is used.
+ Progress float32 `json:"progress,omitempty"`
+}
+
+// ModelMode represents the type of model being run.
+//
+// The zero value is ModeLLM, so an uninitialised ModelMode selects text
+// generation.
+type ModelMode int
+
+const (
+ // ModeLLM indicates a text generation model.
+ ModeLLM ModelMode = iota
+ // ModeImageGen indicates an image generation model.
+ ModeImageGen
+)
+
+// String returns the lowercase name of the mode, "llm" or "imagegen", or
+// "unknown" for any value that is not a declared ModelMode constant.
+func (m ModelMode) String() string {
+	if m == ModeLLM {
+		return "llm"
+	}
+	if m == ModeImageGen {
+		return "imagegen"
+	}
+	return "unknown"
+}
diff --git a/x/model/models/gemma3/model.go b/x/model/models/gemma3/model.go
index 23f78f207..4072122c6 100644
--- a/x/model/models/gemma3/model.go
+++ b/x/model/models/gemma3/model.go
@@ -87,7 +87,7 @@ func New(c fs.Config) (model.Model, error) {
// m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
// TODO need to implement sliding window...
- m.Cache = kvcache.NewMLXCausalCache()
+ m.Cache = kvcache.NewCausalCache()
return &m, nil
}
diff --git a/x/server/show.go b/x/server/show.go
index dd95774d3..652293e77 100644
--- a/x/server/show.go
+++ b/x/server/show.go
@@ -163,9 +163,18 @@ func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
// getTensorInfoFromManifest extracts tensor info from a manifest.
// This is separated for testability.
+// For quantized models, groups weight/scale/qbias into single entries with detected quantization type.
func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
var tensors []api.Tensor
+ // First pass: collect all tensor info and identify scale tensors
+ type tensorData struct {
+ info *safetensorsTensorInfo
+ digest string
+ }
+ tensorMap := make(map[string]*tensorData)
+ scaleMap := make(map[string]*tensorData) // base name -> scale tensor info
+
for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
@@ -178,28 +187,96 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
}
info, err := readSafetensorsHeader(blobPath)
if err != nil {
- // Skip tensors we can't read
continue
}
- // Convert shape from int to uint64
- shape := make([]uint64, len(info.Shape))
- for i, s := range info.Shape {
- shape[i] = uint64(s)
+ td := &tensorData{info: info, digest: layer.Digest}
+
+ if strings.HasSuffix(layer.Name, "_scale") {
+ baseName := strings.TrimSuffix(layer.Name, "_scale")
+ scaleMap[baseName] = td
+ } else if strings.HasSuffix(layer.Name, "_qbias") {
+ // Skip qbias tensors - they're included with the quantized weight
+ continue
+ } else {
+ tensorMap[layer.Name] = td
+ }
+ }
+
+ // Second pass: build tensor list with quantization info
+ for _, layer := range mf.Layers {
+ if layer.MediaType != manifest.MediaTypeImageTensor {
+ continue
}
- tensors = append(tensors, api.Tensor{
- Name: layer.Name,
- Type: info.Dtype,
- Shape: shape,
- })
+ // Skip scale and qbias tensors
+ if strings.HasSuffix(layer.Name, "_scale") || strings.HasSuffix(layer.Name, "_qbias") {
+ continue
+ }
+
+ td := tensorMap[layer.Name]
+ if td == nil {
+ continue
+ }
+
+ // Check if this tensor has a corresponding scale tensor (quantized)
+ scaleTd := scaleMap[layer.Name]
+ if scaleTd != nil && len(td.info.Shape) >= 2 && len(scaleTd.info.Shape) >= 2 {
+ // Quantized tensor - detect bits from shapes
+ weightCols := td.info.Shape[len(td.info.Shape)-1]
+ scaleCols := scaleTd.info.Shape[len(scaleTd.info.Shape)-1]
+
+ // Detect quantization: Q4 has pack_factor=8, Q8 has pack_factor=4
+ // Q4 uses group_size=32: weightCols * 8 / scaleCols = 32
+ // Q8 uses group_size=64: weightCols * 4 / scaleCols = 64
+ var bits int
+ var quantType string
+ if weightCols*8/scaleCols == 32 {
+ bits = 4
+ quantType = "Q4"
+ } else if weightCols*4/scaleCols == 64 {
+ bits = 8
+ quantType = "Q8"
+ } else {
+ // Unknown quantization, show raw
+ quantType = td.info.Dtype
+ }
+
+ // Calculate unpacked shape
+ shape := make([]uint64, len(td.info.Shape))
+ for i, s := range td.info.Shape {
+ shape[i] = uint64(s)
+ }
+ if bits > 0 {
+ packFactor := int64(32 / bits)
+ shape[len(shape)-1] = uint64(td.info.Shape[len(td.info.Shape)-1] * packFactor)
+ }
+
+ tensors = append(tensors, api.Tensor{
+ Name: layer.Name,
+ Type: quantType,
+ Shape: shape,
+ })
+ } else {
+ // Non-quantized tensor
+ shape := make([]uint64, len(td.info.Shape))
+ for i, s := range td.info.Shape {
+ shape[i] = uint64(s)
+ }
+
+ tensors = append(tensors, api.Tensor{
+ Name: layer.Name,
+ Type: td.info.Dtype,
+ Shape: shape,
+ })
+ }
}
return tensors, nil
}
// GetSafetensorsDtype returns the quantization type for a safetensors model.
-// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
+// Reads from model_index.json first, falls back to detection from tensor names.
// Otherwise returns the torch_dtype from config.json.
func GetSafetensorsDtype(name model.Name) (string, error) {
mf, err := manifest.ParseNamedManifest(name)
@@ -207,16 +284,38 @@ func GetSafetensorsDtype(name model.Name) (string, error) {
return "", fmt.Errorf("failed to load manifest: %w", err)
}
- // Check if model is quantized by looking for _scale tensors
+ // First try to read quantization from model_index.json
+ var modelIndex struct {
+ Quantization string `json:"quantization"`
+ }
+ if err := mf.ReadConfigJSON("model_index.json", &modelIndex); err == nil && modelIndex.Quantization != "" {
+ return modelIndex.Quantization, nil
+ }
+
+ // Fallback: detect from tensor names
+ hasScales := false
+ hasQBias := false
for _, layer := range mf.Layers {
if layer.MediaType == manifest.MediaTypeImageTensor {
if strings.HasSuffix(layer.Name, "_scale") {
- // Model is quantized - return FP8 (affine quantization)
- return "FP8", nil
+ hasScales = true
+ }
+ if strings.HasSuffix(layer.Name, "_qbias") {
+ hasQBias = true
}
}
}
+ if hasScales {
+ if hasQBias {
+ // Affine mode (has scale + qbias) - could be Q4 or Q8
+ // Default to Q4 as it's more common
+ return "Q4", nil
+ }
+ // No qbias = NVFP4
+ return "NVFP4", nil
+ }
+
// Not quantized - return torch_dtype from config.json
var cfg struct {
TorchDtype string `json:"torch_dtype"`