mirror of
https://github.com/ollama/ollama.git
synced 2026-03-11 17:34:04 -05:00
chore: move x/mlxrunner into x/imagegen (#14100)
This commit is contained in:
@@ -3,7 +3,7 @@ package runner
|
||||
import (
|
||||
"github.com/ollama/ollama/runner/llamarunner"
|
||||
"github.com/ollama/ollama/runner/ollamarunner"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
func Execute(args []string) error {
|
||||
@@ -11,22 +11,13 @@ func Execute(args []string) error {
|
||||
args = args[1:]
|
||||
}
|
||||
|
||||
var newRunner bool
|
||||
var mlxRunner bool
|
||||
if len(args) > 0 && args[0] == "--ollama-engine" {
|
||||
args = args[1:]
|
||||
newRunner = true
|
||||
}
|
||||
if len(args) > 0 && args[0] == "--mlx-engine" {
|
||||
args = args[1:]
|
||||
mlxRunner = true
|
||||
}
|
||||
|
||||
if mlxRunner {
|
||||
return mlxrunner.Execute(args)
|
||||
} else if newRunner {
|
||||
return ollamarunner.Execute(args)
|
||||
} else {
|
||||
return llamarunner.Execute(args)
|
||||
if len(args) > 0 {
|
||||
switch args[0] {
|
||||
case "--ollama-engine":
|
||||
return ollamarunner.Execute(args[1:])
|
||||
case "--imagegen-engine":
|
||||
return imagegen.Execute(args[1:])
|
||||
}
|
||||
}
|
||||
return llamarunner.Execute(args)
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ import (
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenmanifest "github.com/ollama/ollama/x/imagegen/manifest"
|
||||
xserver "github.com/ollama/ollama/x/server"
|
||||
)
|
||||
|
||||
@@ -1106,7 +1106,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
|
||||
// For image generation models, populate details from imagegen package
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
|
||||
if info, err := imagegenmanifest.GetModelInfo(name.String()); err == nil {
|
||||
modelDetails.Family = info.Architecture
|
||||
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
|
||||
modelDetails.QuantizationLevel = info.Quantization
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
type LlmRequest struct {
|
||||
@@ -567,16 +567,16 @@ iGPUScan:
|
||||
// This supports both LLM (completion) and image generation models.
|
||||
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
// Determine mode based on capabilities
|
||||
var mode mlxrunner.ModelMode
|
||||
var mode imagegen.ModelMode
|
||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||
mode = mlxrunner.ModeImageGen
|
||||
mode = imagegen.ModeImageGen
|
||||
} else {
|
||||
mode = mlxrunner.ModeLLM
|
||||
mode = imagegen.ModeLLM
|
||||
}
|
||||
|
||||
// Use model name for MLX (it resolves manifests by name, not file path)
|
||||
modelName := req.model.ShortName
|
||||
server, err := mlxrunner.NewServer(modelName, mode)
|
||||
server, err := imagegen.NewServer(modelName, mode)
|
||||
if err != nil {
|
||||
req.errCh <- err
|
||||
return true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
@@ -28,8 +28,8 @@ var imageGenMu sync.Mutex
|
||||
func (s *server) loadImageModel() error {
|
||||
// Check memory requirements before loading
|
||||
var requiredMemory uint64
|
||||
if manifest, err := imagegen.LoadManifest(s.modelName); err == nil {
|
||||
requiredMemory = uint64(manifest.TotalTensorSize())
|
||||
if modelManifest, err := manifest.LoadManifest(s.modelName); err == nil {
|
||||
requiredMemory = uint64(modelManifest.TotalTensorSize())
|
||||
}
|
||||
availableMemory := mlx.GetMemoryLimit()
|
||||
if availableMemory > 0 && requiredMemory > 0 && availableMemory < requiredMemory {
|
||||
@@ -38,7 +38,7 @@ func (s *server) loadImageModel() error {
|
||||
}
|
||||
|
||||
// Detect model type and load appropriate model
|
||||
modelType := imagegen.DetectModelType(s.modelName)
|
||||
modelType := DetectModelType(s.modelName)
|
||||
slog.Info("detected image model type", "type", modelType)
|
||||
|
||||
var model ImageModel
|
||||
@@ -108,7 +108,7 @@ func (s *server) handleImageCompletion(w http.ResponseWriter, r *http.Request, r
|
||||
}
|
||||
|
||||
// Encode image as base64 PNG
|
||||
imageData, err := imagegen.EncodeImageBase64(img)
|
||||
imageData, err := EncodeImageBase64(img)
|
||||
if err != nil {
|
||||
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
|
||||
data, _ := json.Marshal(resp)
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
@@ -197,13 +197,13 @@ func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
|
||||
// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
|
||||
func (s *server) loadLLMModel() error {
|
||||
// Load the manifest to get model information
|
||||
manifest, err := imagegen.LoadManifest(s.modelName)
|
||||
modelManifest, err := manifest.LoadManifest(s.modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Detect model architecture from config.json
|
||||
configData, err := manifest.ReadConfig("config.json")
|
||||
configData, err := modelManifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
@@ -232,7 +232,7 @@ func (s *server) loadLLMModel() error {
|
||||
|
||||
switch {
|
||||
case strings.Contains(archLower, "glm4moelite"):
|
||||
m, err := glm4_moe_lite.LoadFromManifest(manifest)
|
||||
m, err := glm4_moe_lite.LoadFromManifest(modelManifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package imagegen
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -1,4 +1,4 @@
|
||||
package imagegen
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build mlx
|
||||
|
||||
package imagegen
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -15,9 +15,9 @@ import (
|
||||
type ManifestWeights struct {
|
||||
manifest *ModelManifest
|
||||
component string
|
||||
tensors map[string]ManifestLayer // name -> layer
|
||||
cache map[string]*mlx.Array // name -> loaded array
|
||||
nativeCache []*mlx.SafetensorsFile // keep native handles alive
|
||||
tensors map[string]ManifestLayer // name -> layer
|
||||
cache map[string]*mlx.Array // name -> loaded array
|
||||
nativeCache []*mlx.SafetensorsFile // keep native handles alive
|
||||
}
|
||||
|
||||
// LoadWeightsFromManifest creates a weight loader from manifest storage.
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// SupportedBackends lists the backends that support image generation.
|
||||
@@ -41,8 +43,8 @@ func CheckPlatformSupport() error {
|
||||
// ResolveModelName checks if a model name is a known image generation model.
|
||||
// Returns the normalized model name if found, empty string otherwise.
|
||||
func ResolveModelName(modelName string) string {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err == nil && manifest.HasTensorLayers() {
|
||||
modelManifest, err := manifest.LoadManifest(modelName)
|
||||
if err == nil && modelManifest.HasTensorLayers() {
|
||||
return modelName
|
||||
}
|
||||
return ""
|
||||
@@ -52,12 +54,12 @@ func ResolveModelName(modelName string) string {
|
||||
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
|
||||
// Returns empty string if detection fails.
|
||||
func DetectModelType(modelName string) string {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
modelManifest, err := manifest.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
data, err := manifest.ReadConfig("model_index.json")
|
||||
data, err := modelManifest.ReadConfig("model_index.json")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen3"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
@@ -61,7 +61,7 @@ func (m *Model) Load(modelName string) error {
|
||||
m.ModelName = modelName
|
||||
|
||||
// Load manifest
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
manifest, err := manifest.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load manifest: %w", err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
@@ -14,19 +14,19 @@ import (
|
||||
|
||||
// TransformerConfig holds Flux2 transformer configuration
|
||||
type TransformerConfig struct {
|
||||
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
|
||||
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
|
||||
Eps float32 `json:"eps"` // 1e-6
|
||||
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
|
||||
InChannels int32 `json:"in_channels"` // 128
|
||||
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
|
||||
MLPRatio float32 `json:"mlp_ratio"` // 3.0
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
|
||||
NumLayers int32 `json:"num_layers"` // 5
|
||||
NumSingleLayers int32 `json:"num_single_layers"` // 20
|
||||
PatchSize int32 `json:"patch_size"` // 1
|
||||
RopeTheta int32 `json:"rope_theta"` // 2000
|
||||
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
|
||||
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
|
||||
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
|
||||
Eps float32 `json:"eps"` // 1e-6
|
||||
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
|
||||
InChannels int32 `json:"in_channels"` // 128
|
||||
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
|
||||
MLPRatio float32 `json:"mlp_ratio"` // 3.0
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
|
||||
NumLayers int32 `json:"num_layers"` // 5
|
||||
NumSingleLayers int32 `json:"num_single_layers"` // 20
|
||||
PatchSize int32 `json:"patch_size"` // 1
|
||||
RopeTheta int32 `json:"rope_theta"` // 2000
|
||||
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
|
||||
}
|
||||
|
||||
// Computed dimensions
|
||||
@@ -392,12 +392,12 @@ type Flux2Transformer2DModel struct {
|
||||
}
|
||||
|
||||
// Load loads the Flux2 transformer from ollama blob storage.
|
||||
func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
|
||||
func (m *Flux2Transformer2DModel) Load(modelManifest *manifest.ModelManifest) error {
|
||||
fmt.Print(" Loading transformer... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg TransformerConfig
|
||||
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
|
||||
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.TransformerConfig = &cfg
|
||||
@@ -412,7 +412,7 @@ func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
|
||||
}
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
|
||||
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
@@ -15,21 +15,21 @@ import (
|
||||
|
||||
// VAEConfig holds AutoencoderKLFlux2 configuration
|
||||
type VAEConfig struct {
|
||||
ActFn string `json:"act_fn"` // "silu"
|
||||
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
|
||||
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
|
||||
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
|
||||
ForceUpcast bool `json:"force_upcast"` // true
|
||||
InChannels int32 `json:"in_channels"` // 3
|
||||
LatentChannels int32 `json:"latent_channels"` // 32
|
||||
LayersPerBlock int32 `json:"layers_per_block"` // 2
|
||||
ActFn string `json:"act_fn"` // "silu"
|
||||
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
|
||||
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
|
||||
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
|
||||
ForceUpcast bool `json:"force_upcast"` // true
|
||||
InChannels int32 `json:"in_channels"` // 3
|
||||
LatentChannels int32 `json:"latent_channels"` // 32
|
||||
LayersPerBlock int32 `json:"layers_per_block"` // 2
|
||||
MidBlockAddAttn bool `json:"mid_block_add_attention"` // true
|
||||
NormNumGroups int32 `json:"norm_num_groups"` // 32
|
||||
OutChannels int32 `json:"out_channels"` // 3
|
||||
PatchSize []int32 `json:"patch_size"` // [2, 2]
|
||||
SampleSize int32 `json:"sample_size"` // 1024
|
||||
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
|
||||
UseQuantConv bool `json:"use_quant_conv"` // true
|
||||
NormNumGroups int32 `json:"norm_num_groups"` // 32
|
||||
OutChannels int32 `json:"out_channels"` // 3
|
||||
PatchSize []int32 `json:"patch_size"` // [2, 2]
|
||||
SampleSize int32 `json:"sample_size"` // 1024
|
||||
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
|
||||
UseQuantConv bool `json:"use_quant_conv"` // true
|
||||
}
|
||||
|
||||
// BatchNorm2D implements 2D batch normalization with running statistics
|
||||
@@ -356,18 +356,18 @@ func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
}
|
||||
|
||||
// Load loads the Flux2 VAE from ollama blob storage.
|
||||
func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
|
||||
func (m *AutoencoderKLFlux2) Load(modelManifest *manifest.ModelManifest) error {
|
||||
fmt.Print(" Loading VAE... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg VAEConfig
|
||||
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
|
||||
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
@@ -38,11 +38,11 @@ type Config struct {
|
||||
AttentionBias bool `json:"attention_bias"`
|
||||
|
||||
// MLA (Multi-head Latent Attention) parameters
|
||||
QLoraRank int32 `json:"q_lora_rank"`
|
||||
KVLoraRank int32 `json:"kv_lora_rank"`
|
||||
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
|
||||
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
|
||||
VHeadDim int32 `json:"v_head_dim"`
|
||||
QLoraRank int32 `json:"q_lora_rank"`
|
||||
KVLoraRank int32 `json:"kv_lora_rank"`
|
||||
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
|
||||
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
|
||||
VHeadDim int32 `json:"v_head_dim"`
|
||||
|
||||
// MoE parameters
|
||||
NRoutedExperts int32 `json:"n_routed_experts"`
|
||||
@@ -82,7 +82,7 @@ type MLAAttention struct {
|
||||
// Absorbed MLA projections (derived from kv_b_proj)
|
||||
// EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
|
||||
// UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
|
||||
EmbedQ *nn.MultiLinear `weight:"-"`
|
||||
EmbedQ *nn.MultiLinear `weight:"-"`
|
||||
UnembedOut *nn.MultiLinear `weight:"-"`
|
||||
|
||||
// Output projection
|
||||
@@ -194,8 +194,8 @@ func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// MoEGate implements the expert gating mechanism
|
||||
type MoEGate struct {
|
||||
Gate nn.LinearLayer `weight:"mlp.gate"`
|
||||
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
|
||||
Gate nn.LinearLayer `weight:"mlp.gate"`
|
||||
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
|
||||
}
|
||||
|
||||
// Forward computes expert selection indices and scores
|
||||
@@ -617,9 +617,9 @@ func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numE
|
||||
}
|
||||
|
||||
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
|
||||
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
|
||||
func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
|
||||
// Read config from manifest
|
||||
configData, err := manifest.ReadConfig("config.json")
|
||||
configData, err := modelManifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
@@ -634,7 +634,7 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
|
||||
cfg.Scale = computeScale(&cfg)
|
||||
|
||||
// Load weights from manifest blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "")
|
||||
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
@@ -653,7 +653,7 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
|
||||
}
|
||||
|
||||
// Load tokenizer from manifest with config files for EOS token detection
|
||||
tokData, err := manifest.ReadConfig("tokenizer.json")
|
||||
tokData, err := modelManifest.ReadConfig("tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||
}
|
||||
@@ -664,12 +664,12 @@ func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
|
||||
}
|
||||
|
||||
// Try to load generation_config.json if available (preferred source for EOS)
|
||||
if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
|
||||
if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = genConfigData
|
||||
}
|
||||
|
||||
// Try to load tokenizer_config.json if available
|
||||
if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
@@ -181,19 +181,19 @@ type TextEncoder struct {
|
||||
}
|
||||
|
||||
// Load loads the Qwen3 text encoder from ollama blob storage.
|
||||
func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
|
||||
func (m *TextEncoder) Load(modelManifest *manifest.ModelManifest, configPath string) error {
|
||||
fmt.Print(" Loading text encoder... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg Config
|
||||
if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
|
||||
if err := modelManifest.ReadConfigJSON(configPath, &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
m.Layers = make([]*Block, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
|
||||
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "text_encoder")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
@@ -38,7 +38,7 @@ type TransformerConfig struct {
|
||||
type TimestepEmbedder struct {
|
||||
Linear1 nn.LinearLayer `weight:"mlp.0"`
|
||||
Linear2 nn.LinearLayer `weight:"mlp.2"`
|
||||
FreqEmbedSize int32 // 256 (computed)
|
||||
FreqEmbedSize int32 // 256 (computed)
|
||||
}
|
||||
|
||||
// Forward computes timestep embeddings -> [B, 256]
|
||||
@@ -85,9 +85,9 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// CapEmbedder projects caption features to model dimension
|
||||
type CapEmbedder struct {
|
||||
Norm *nn.RMSNorm `weight:"0"`
|
||||
Linear nn.LinearLayer `weight:"1"`
|
||||
PadToken *mlx.Array // loaded separately at root level
|
||||
Norm *nn.RMSNorm `weight:"0"`
|
||||
Linear nn.LinearLayer `weight:"1"`
|
||||
PadToken *mlx.Array // loaded separately at root level
|
||||
}
|
||||
|
||||
// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
|
||||
@@ -103,10 +103,9 @@ type FeedForward struct {
|
||||
W1 nn.LinearLayer `weight:"w1"` // gate projection
|
||||
W2 nn.LinearLayer `weight:"w2"` // down projection
|
||||
W3 nn.LinearLayer `weight:"w3"` // up projection
|
||||
OutDim int32 // computed from W2
|
||||
OutDim int32 // computed from W2
|
||||
}
|
||||
|
||||
|
||||
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
|
||||
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
@@ -132,11 +131,11 @@ type Attention struct {
|
||||
ToK nn.LinearLayer `weight:"to_k"`
|
||||
ToV nn.LinearLayer `weight:"to_v"`
|
||||
ToOut nn.LinearLayer `weight:"to_out.0"`
|
||||
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
|
||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
|
||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||
// Fused QKV (computed at init time for efficiency, not loaded from weights)
|
||||
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
|
||||
Fused bool `weight:"-"` // Whether to use fused QKV path
|
||||
Fused bool `weight:"-"` // Whether to use fused QKV path
|
||||
// Computed fields (not loaded from weights)
|
||||
NHeads int32 `weight:"-"`
|
||||
HeadDim int32 `weight:"-"`
|
||||
@@ -288,13 +287,13 @@ func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
||||
|
||||
// TransformerBlock is a single transformer block with optional AdaLN modulation
|
||||
type TransformerBlock struct {
|
||||
Attention *Attention `weight:"attention"`
|
||||
FeedForward *FeedForward `weight:"feed_forward"`
|
||||
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
|
||||
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
|
||||
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
|
||||
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
|
||||
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
|
||||
Attention *Attention `weight:"attention"`
|
||||
FeedForward *FeedForward `weight:"feed_forward"`
|
||||
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
|
||||
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
|
||||
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
|
||||
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
|
||||
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
|
||||
// Computed fields
|
||||
HasModulation bool
|
||||
Dim int32
|
||||
@@ -350,7 +349,7 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
|
||||
type FinalLayer struct {
|
||||
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
|
||||
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
|
||||
OutDim int32 // computed from Output
|
||||
OutDim int32 // computed from Output
|
||||
}
|
||||
|
||||
// Forward computes final output
|
||||
@@ -401,12 +400,12 @@ type Transformer struct {
|
||||
}
|
||||
|
||||
// Load loads the Z-Image transformer from ollama blob storage.
|
||||
func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
|
||||
func (m *Transformer) Load(modelManifest *manifest.ModelManifest) error {
|
||||
fmt.Print(" Loading transformer... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg TransformerConfig
|
||||
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
|
||||
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
if len(cfg.AllPatchSize) > 0 {
|
||||
@@ -417,7 +416,7 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
|
||||
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
|
||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
|
||||
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/vae"
|
||||
@@ -562,7 +562,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
if ub.Upsample != nil {
|
||||
// Stage 1: Upsample2x (nearest neighbor)
|
||||
{
|
||||
prev := x
|
||||
prev := x
|
||||
x = Upsample2x(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
@@ -570,7 +570,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// Stage 2: Upsample conv
|
||||
{
|
||||
prev := x
|
||||
prev := x
|
||||
x = ub.Upsample.Forward(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
@@ -643,16 +643,16 @@ type VAEDecoder struct {
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from ollama blob storage.
|
||||
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
func (m *VAEDecoder) Load(modelManifest *manifest.ModelManifest) error {
|
||||
// Load config from blob
|
||||
var cfg VAEConfig
|
||||
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
|
||||
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/imagegen/vae"
|
||||
@@ -18,14 +18,14 @@ import (
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 9 for turbo)
|
||||
Seed int64 // Random seed
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 9 for turbo)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
CapturePath string // GPU capture path (debug)
|
||||
|
||||
// TeaCache options (timestep embedding aware caching)
|
||||
TeaCache bool // TeaCache is always enabled for faster inference
|
||||
@@ -58,7 +58,7 @@ func (m *Model) Load(modelName string) error {
|
||||
m.ModelName = modelName
|
||||
|
||||
// Load manifest
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
manifest, err := manifest.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load manifest: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
|
||||
package mlxrunner
|
||||
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
@@ -98,7 +97,7 @@ func Execute(args []string) error {
|
||||
// detectModelMode determines whether a model is an LLM or image generation model.
|
||||
func detectModelMode(modelName string) ModelMode {
|
||||
// Check for image generation model by looking at model_index.json
|
||||
modelType := imagegen.DetectModelType(modelName)
|
||||
modelType := DetectModelType(modelName)
|
||||
if modelType != "" {
|
||||
// Known image generation model types
|
||||
switch modelType {
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !mlx
|
||||
|
||||
package mlxrunner
|
||||
package imagegen
|
||||
|
||||
import "errors"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package mlxrunner
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
|
||||
@@ -46,7 +46,7 @@ type Server struct {
|
||||
// NewServer spawns a new MLX runner subprocess and waits until it's ready.
|
||||
func NewServer(modelName string, mode ModelMode) (*Server, error) {
|
||||
// Validate platform support before attempting to start
|
||||
if err := imagegen.CheckPlatformSupport(); err != nil {
|
||||
if err := CheckPlatformSupport(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -71,8 +71,8 @@ func NewServer(modelName string, mode ModelMode) (*Server, error) {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
// Spawn subprocess: ollama runner --mlx-engine --model <path> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
|
||||
@@ -107,8 +107,8 @@ func NewServer(modelName string, mode ModelMode) (*Server, error) {
|
||||
|
||||
// Estimate VRAM based on tensor size from manifest
|
||||
var vramSize uint64
|
||||
if manifest, err := imagegen.LoadManifest(modelName); err == nil {
|
||||
vramSize = uint64(manifest.TotalTensorSize())
|
||||
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
|
||||
vramSize = uint64(modelManifest.TotalTensorSize())
|
||||
} else {
|
||||
// Fallback: default to 8GB if manifest can't be loaded
|
||||
vramSize = 8 * 1024 * 1024 * 1024
|
||||
@@ -1,9 +1,9 @@
|
||||
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
|
||||
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
|
||||
//
|
||||
// This package handles safetensors models created with `ollama create --experimental`,
|
||||
// supporting both text generation (LLM) and image generation (diffusion) models
|
||||
// through a single unified interface.
|
||||
package mlxrunner
|
||||
package imagegen
|
||||
|
||||
// Request is the request format for completion requests.
|
||||
type Request struct {
|
||||
Reference in New Issue
Block a user