chore: move x/mlxrunner into x/imagegen (#14100)

This commit is contained in:
Michael Yang
2026-02-05 18:25:56 -08:00
committed by GitHub
parent f1373193dc
commit 6ddd8862cd
21 changed files with 143 additions and 152 deletions

View File

@@ -3,7 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
"github.com/ollama/ollama/x/mlxrunner"
"github.com/ollama/ollama/x/imagegen"
)
func Execute(args []string) error {
@@ -11,22 +11,13 @@ func Execute(args []string) error {
args = args[1:]
}
var newRunner bool
var mlxRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--mlx-engine" {
args = args[1:]
mlxRunner = true
}
if mlxRunner {
return mlxrunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)
if len(args) > 0 {
switch args[0] {
case "--ollama-engine":
return ollamarunner.Execute(args[1:])
case "--imagegen-engine":
return imagegen.Execute(args[1:])
}
}
return llamarunner.Execute(args)
}

View File

@@ -52,7 +52,7 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenmanifest "github.com/ollama/ollama/x/imagegen/manifest"
xserver "github.com/ollama/ollama/x/server"
)
@@ -1106,7 +1106,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
if info, err := imagegenmanifest.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization

View File

@@ -21,7 +21,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner"
"github.com/ollama/ollama/x/imagegen"
)
type LlmRequest struct {
@@ -567,16 +567,16 @@ iGPUScan:
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode mlxrunner.ModelMode
var mode imagegen.ModelMode
if slices.Contains(req.model.Config.Capabilities, "image") {
mode = mlxrunner.ModeImageGen
mode = imagegen.ModeImageGen
} else {
mode = mlxrunner.ModeLLM
mode = imagegen.ModeLLM
}
// Use model name for MLX (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := mlxrunner.NewServer(modelName, mode)
server, err := imagegen.NewServer(modelName, mode)
if err != nil {
req.errCh <- err
return true

View File

@@ -1,6 +1,6 @@
//go:build mlx
package mlxrunner
package imagegen
import (
"context"
@@ -11,7 +11,7 @@ import (
"sync"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/zimage"
@@ -28,8 +28,8 @@ var imageGenMu sync.Mutex
func (s *server) loadImageModel() error {
// Check memory requirements before loading
var requiredMemory uint64
if manifest, err := imagegen.LoadManifest(s.modelName); err == nil {
requiredMemory = uint64(manifest.TotalTensorSize())
if modelManifest, err := manifest.LoadManifest(s.modelName); err == nil {
requiredMemory = uint64(modelManifest.TotalTensorSize())
}
availableMemory := mlx.GetMemoryLimit()
if availableMemory > 0 && requiredMemory > 0 && availableMemory < requiredMemory {
@@ -38,7 +38,7 @@ func (s *server) loadImageModel() error {
}
// Detect model type and load appropriate model
modelType := imagegen.DetectModelType(s.modelName)
modelType := DetectModelType(s.modelName)
slog.Info("detected image model type", "type", modelType)
var model ImageModel
@@ -108,7 +108,7 @@ func (s *server) handleImageCompletion(w http.ResponseWriter, r *http.Request, r
}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
imageData, err := EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
data, _ := json.Marshal(resp)

View File

@@ -1,6 +1,6 @@
//go:build mlx
package mlxrunner
package imagegen
import (
"encoding/json"
@@ -12,8 +12,8 @@ import (
"sync"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -197,13 +197,13 @@ func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
func (s *server) loadLLMModel() error {
// Load the manifest to get model information
manifest, err := imagegen.LoadManifest(s.modelName)
modelManifest, err := manifest.LoadManifest(s.modelName)
if err != nil {
return fmt.Errorf("failed to load manifest: %w", err)
}
// Detect model architecture from config.json
configData, err := manifest.ReadConfig("config.json")
configData, err := modelManifest.ReadConfig("config.json")
if err != nil {
return fmt.Errorf("failed to read config.json: %w", err)
}
@@ -232,7 +232,7 @@ func (s *server) loadLLMModel() error {
switch {
case strings.Contains(archLower, "glm4moelite"):
m, err := glm4_moe_lite.LoadFromManifest(manifest)
m, err := glm4_moe_lite.LoadFromManifest(modelManifest)
if err != nil {
return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
}

View File

@@ -1,4 +1,4 @@
package imagegen
package manifest
import (
"encoding/json"

View File

@@ -1,4 +1,4 @@
package imagegen
package manifest
import (
"path/filepath"

View File

@@ -1,6 +1,6 @@
//go:build mlx
package imagegen
package manifest
import (
"fmt"
@@ -15,9 +15,9 @@ import (
type ManifestWeights struct {
manifest *ModelManifest
component string
tensors map[string]ManifestLayer // name -> layer
cache map[string]*mlx.Array // name -> loaded array
nativeCache []*mlx.SafetensorsFile // keep native handles alive
tensors map[string]ManifestLayer // name -> layer
cache map[string]*mlx.Array // name -> loaded array
nativeCache []*mlx.SafetensorsFile // keep native handles alive
}
// LoadWeightsFromManifest creates a weight loader from manifest storage.

View File

@@ -14,6 +14,8 @@ import (
"encoding/json"
"fmt"
"runtime"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// SupportedBackends lists the backends that support image generation.
@@ -41,8 +43,8 @@ func CheckPlatformSupport() error {
// ResolveModelName checks if a model name is a known image generation model.
// Returns the normalized model name if found, empty string otherwise.
func ResolveModelName(modelName string) string {
manifest, err := LoadManifest(modelName)
if err == nil && manifest.HasTensorLayers() {
modelManifest, err := manifest.LoadManifest(modelName)
if err == nil && modelManifest.HasTensorLayers() {
return modelName
}
return ""
@@ -52,12 +54,12 @@ func ResolveModelName(modelName string) string {
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
// Returns empty string if detection fails.
func DetectModelType(modelName string) string {
manifest, err := LoadManifest(modelName)
modelManifest, err := manifest.LoadManifest(modelName)
if err != nil {
return ""
}
data, err := manifest.ReadConfig("model_index.json")
data, err := modelManifest.ReadConfig("model_index.json")
if err != nil {
return ""
}

View File

@@ -12,7 +12,7 @@ import (
"math"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen3"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -61,7 +61,7 @@ func (m *Model) Load(modelName string) error {
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
manifest, err := manifest.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -14,19 +14,19 @@ import (
// TransformerConfig holds Flux2 transformer configuration
type TransformerConfig struct {
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
Eps float32 `json:"eps"` // 1e-6
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
InChannels int32 `json:"in_channels"` // 128
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
MLPRatio float32 `json:"mlp_ratio"` // 3.0
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
NumLayers int32 `json:"num_layers"` // 5
NumSingleLayers int32 `json:"num_single_layers"` // 20
PatchSize int32 `json:"patch_size"` // 1
RopeTheta int32 `json:"rope_theta"` // 2000
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
Eps float32 `json:"eps"` // 1e-6
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
InChannels int32 `json:"in_channels"` // 128
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
MLPRatio float32 `json:"mlp_ratio"` // 3.0
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
NumLayers int32 `json:"num_layers"` // 5
NumSingleLayers int32 `json:"num_single_layers"` // 20
PatchSize int32 `json:"patch_size"` // 1
RopeTheta int32 `json:"rope_theta"` // 2000
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
}
// Computed dimensions
@@ -392,12 +392,12 @@ type Flux2Transformer2DModel struct {
}
// Load loads the Flux2 transformer from ollama blob storage.
func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
func (m *Flux2Transformer2DModel) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.TransformerConfig = &cfg
@@ -412,7 +412,7 @@ func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
}
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -15,21 +15,21 @@ import (
// VAEConfig holds AutoencoderKLFlux2 configuration
type VAEConfig struct {
ActFn string `json:"act_fn"` // "silu"
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
ForceUpcast bool `json:"force_upcast"` // true
InChannels int32 `json:"in_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 32
LayersPerBlock int32 `json:"layers_per_block"` // 2
ActFn string `json:"act_fn"` // "silu"
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
ForceUpcast bool `json:"force_upcast"` // true
InChannels int32 `json:"in_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 32
LayersPerBlock int32 `json:"layers_per_block"` // 2
MidBlockAddAttn bool `json:"mid_block_add_attention"` // true
NormNumGroups int32 `json:"norm_num_groups"` // 32
OutChannels int32 `json:"out_channels"` // 3
PatchSize []int32 `json:"patch_size"` // [2, 2]
SampleSize int32 `json:"sample_size"` // 1024
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
UseQuantConv bool `json:"use_quant_conv"` // true
NormNumGroups int32 `json:"norm_num_groups"` // 32
OutChannels int32 `json:"out_channels"` // 3
PatchSize []int32 `json:"patch_size"` // [2, 2]
SampleSize int32 `json:"sample_size"` // 1024
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
UseQuantConv bool `json:"use_quant_conv"` // true
}
// BatchNorm2D implements 2D batch normalization with running statistics
@@ -356,18 +356,18 @@ func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
}
// Load loads the Flux2 VAE from ollama blob storage.
func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
func (m *AutoencoderKLFlux2) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading VAE... ")
// Load config from blob
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -9,8 +9,8 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -38,11 +38,11 @@ type Config struct {
AttentionBias bool `json:"attention_bias"`
// MLA (Multi-head Latent Attention) parameters
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
// MoE parameters
NRoutedExperts int32 `json:"n_routed_experts"`
@@ -82,7 +82,7 @@ type MLAAttention struct {
// Absorbed MLA projections (derived from kv_b_proj)
// EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
// UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
EmbedQ *nn.MultiLinear `weight:"-"`
EmbedQ *nn.MultiLinear `weight:"-"`
UnembedOut *nn.MultiLinear `weight:"-"`
// Output projection
@@ -194,8 +194,8 @@ func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
// MoEGate implements the expert gating mechanism
type MoEGate struct {
Gate nn.LinearLayer `weight:"mlp.gate"`
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
Gate nn.LinearLayer `weight:"mlp.gate"`
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
}
// Forward computes expert selection indices and scores
@@ -617,9 +617,9 @@ func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numE
}
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
// Read config from manifest
configData, err := manifest.ReadConfig("config.json")
configData, err := modelManifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
@@ -634,7 +634,7 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
cfg.Scale = computeScale(&cfg)
// Load weights from manifest blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "")
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
@@ -653,7 +653,7 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
}
// Load tokenizer from manifest with config files for EOS token detection
tokData, err := manifest.ReadConfig("tokenizer.json")
tokData, err := modelManifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
@@ -664,12 +664,12 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
}
// Try to load generation_config.json if available (preferred source for EOS)
if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
// Try to load tokenizer_config.json if available
if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}

View File

@@ -7,7 +7,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -181,19 +181,19 @@ type TextEncoder struct {
}
// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
func (m *TextEncoder) Load(modelManifest *manifest.ModelManifest, configPath string) error {
fmt.Print(" Loading text encoder... ")
// Load config from blob
var cfg Config
if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
if err := modelManifest.ReadConfigJSON(configPath, &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
m.Layers = make([]*Block, cfg.NumHiddenLayers)
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -7,8 +7,8 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -38,7 +38,7 @@ type TransformerConfig struct {
type TimestepEmbedder struct {
Linear1 nn.LinearLayer `weight:"mlp.0"`
Linear2 nn.LinearLayer `weight:"mlp.2"`
FreqEmbedSize int32 // 256 (computed)
FreqEmbedSize int32 // 256 (computed)
}
// Forward computes timestep embeddings -> [B, 256]
@@ -85,9 +85,9 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
Norm *nn.RMSNorm `weight:"0"`
Linear nn.LinearLayer `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
Norm *nn.RMSNorm `weight:"0"`
Linear nn.LinearLayer `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
}
// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
@@ -103,10 +103,9 @@ type FeedForward struct {
W1 nn.LinearLayer `weight:"w1"` // gate projection
W2 nn.LinearLayer `weight:"w2"` // down projection
W3 nn.LinearLayer `weight:"w3"` // up projection
OutDim int32 // computed from W2
OutDim int32 // computed from W2
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
@@ -132,11 +131,11 @@ type Attention struct {
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
// Fused QKV (computed at init time for efficiency, not loaded from weights)
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
Fused bool `weight:"-"` // Whether to use fused QKV path
Fused bool `weight:"-"` // Whether to use fused QKV path
// Computed fields (not loaded from weights)
NHeads int32 `weight:"-"`
HeadDim int32 `weight:"-"`
@@ -288,13 +287,13 @@ func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
// TransformerBlock is a single transformer block with optional AdaLN modulation
type TransformerBlock struct {
Attention *Attention `weight:"attention"`
FeedForward *FeedForward `weight:"feed_forward"`
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
Attention *Attention `weight:"attention"`
FeedForward *FeedForward `weight:"feed_forward"`
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
@@ -350,7 +349,7 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
type FinalLayer struct {
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
OutDim int32 // computed from Output
OutDim int32 // computed from Output
}
// Forward computes final output
@@ -401,12 +400,12 @@ type Transformer struct {
}
// Load loads the Z-Image transformer from ollama blob storage.
func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
func (m *Transformer) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
if len(cfg.AllPatchSize) > 0 {
@@ -417,7 +416,7 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/vae"
@@ -562,7 +562,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
if ub.Upsample != nil {
// Stage 1: Upsample2x (nearest neighbor)
{
prev := x
prev := x
x = Upsample2x(x)
prev.Free()
mlx.Eval(x)
@@ -570,7 +570,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
// Stage 2: Upsample conv
{
prev := x
prev := x
x = ub.Upsample.Forward(x)
prev.Free()
mlx.Eval(x)
@@ -643,16 +643,16 @@ type VAEDecoder struct {
}
// Load loads the VAE decoder from ollama blob storage.
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
func (m *VAEDecoder) Load(modelManifest *manifest.ModelManifest) error {
// Load config from blob
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}

View File

@@ -8,8 +8,8 @@ import (
"fmt"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/imagegen/vae"
@@ -18,14 +18,14 @@ import (
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
CapturePath string // GPU capture path (debug)
CapturePath string // GPU capture path (debug)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
@@ -58,7 +58,7 @@ func (m *Model) Load(modelName string) error {
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
manifest, err := manifest.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}

View File

@@ -1,7 +1,7 @@
//go:build mlx
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
package mlxrunner
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
package imagegen
import (
"context"
@@ -16,7 +16,6 @@ import (
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
)
@@ -98,7 +97,7 @@ func Execute(args []string) error {
// detectModelMode determines whether a model is an LLM or image generation model.
func detectModelMode(modelName string) ModelMode {
// Check for image generation model by looking at model_index.json
modelType := imagegen.DetectModelType(modelName)
modelType := DetectModelType(modelName)
if modelType != "" {
// Known image generation model types
switch modelType {

View File

@@ -1,6 +1,6 @@
//go:build !mlx
package mlxrunner
package imagegen
import "errors"

View File

@@ -1,4 +1,4 @@
package mlxrunner
package imagegen
import (
"bufio"
@@ -23,7 +23,7 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
@@ -46,7 +46,7 @@ type Server struct {
// NewServer spawns a new MLX runner subprocess and waits until it's ready.
func NewServer(modelName string, mode ModelMode) (*Server, error) {
// Validate platform support before attempting to start
if err := imagegen.CheckPlatformSupport(); err != nil {
if err := CheckPlatformSupport(); err != nil {
return nil, err
}
@@ -71,8 +71,8 @@ func NewServer(modelName string, mode ModelMode) (*Server, error) {
exe = eval
}
// Spawn subprocess: ollama runner --mlx-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
@@ -107,8 +107,8 @@ func NewServer(modelName string, mode ModelMode) (*Server, error) {
// Estimate VRAM based on tensor size from manifest
var vramSize uint64
if manifest, err := imagegen.LoadManifest(modelName); err == nil {
vramSize = uint64(manifest.TotalTensorSize())
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
vramSize = uint64(modelManifest.TotalTensorSize())
} else {
// Fallback: default to 8GB if manifest can't be loaded
vramSize = 8 * 1024 * 1024 * 1024

View File

@@ -1,9 +1,9 @@
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
//
// This package handles safetensors models created with `ollama create --experimental`,
// supporting both text generation (LLM) and image generation (diffusion) models
// through a single unified interface.
package mlxrunner
package imagegen
// Request is the request format for completion requests.
type Request struct {