diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go index f5799fc1e..82bfb291e 100755 --- a/anthropic/anthropic.go +++ b/anthropic/anthropic.go @@ -518,24 +518,26 @@ func mapStopReason(reason string, hasToolCalls bool) string { // StreamConverter manages state for converting Ollama streaming responses to Anthropic format type StreamConverter struct { - ID string - Model string - firstWrite bool - contentIndex int - inputTokens int - outputTokens int - thinkingStarted bool - thinkingDone bool - textStarted bool - toolCallsSent map[string]bool + ID string + Model string + firstWrite bool + contentIndex int + inputTokens int + outputTokens int + estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0) + thinkingStarted bool + thinkingDone bool + textStarted bool + toolCallsSent map[string]bool } -func NewStreamConverter(id, model string) *StreamConverter { +func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter { return &StreamConverter{ - ID: id, - Model: model, - firstWrite: true, - toolCallsSent: make(map[string]bool), + ID: id, + Model: model, + firstWrite: true, + estimatedInputTokens: estimatedInputTokens, + toolCallsSent: make(map[string]bool), } } @@ -551,7 +553,11 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent { if c.firstWrite { c.firstWrite = false + // Use actual metrics if available, otherwise use estimate c.inputTokens = r.Metrics.PromptEvalCount + if c.inputTokens == 0 && c.estimatedInputTokens > 0 { + c.inputTokens = c.estimatedInputTokens + } events = append(events, StreamEvent{ Event: "message_start", @@ -779,3 +785,123 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments { } return args } + +// CountTokensRequest represents an Anthropic count_tokens request +type CountTokensRequest struct { + Model string `json:"model"` + Messages []MessageParam `json:"messages"` + System any `json:"system,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Thinking *ThinkingConfig `json:"thinking,omitempty"` +} + +// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic) +func EstimateInputTokens(req MessagesRequest) int { + return estimateTokens(CountTokensRequest{ + Model: req.Model, + Messages: req.Messages, + System: req.System, + Tools: req.Tools, + Thinking: req.Thinking, + }) +} + +// CountTokensResponse represents an Anthropic count_tokens response +type CountTokensResponse struct { + InputTokens int `json:"input_tokens"` +} + +// estimateTokens returns a rough estimate of tokens (len/4). +// TODO: Replace with actual tokenization via Tokenize API for accuracy. +// Current len/4 heuristic is a rough approximation (~4 chars/token average). +func estimateTokens(req CountTokensRequest) int { + var totalLen int + + // Count system prompt + if req.System != nil { + totalLen += countAnyContent(req.System) + } + + // Count messages + for _, msg := range req.Messages { + // Count role (always present) + totalLen += len(msg.Role) + // Count content + contentLen := countAnyContent(msg.Content) + totalLen += contentLen + } + + for _, tool := range req.Tools { + totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema) + } + + // Return len/4 as rough token estimate, minimum 1 if there's any content + tokens := totalLen / 4 + if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) { + tokens = 1 + } + return tokens +} + +func countAnyContent(content any) int { + if content == nil { + return 0 + } + + switch c := content.(type) { + case string: + return len(c) + case []any: + total := 0 + for _, block := range c { + total += countContentBlock(block) + } + return total + default: + if data, err := json.Marshal(content); err == nil { + return len(data) + } + return 0 + } +} + +func countContentBlock(block any) int { + blockMap, ok := block.(map[string]any) + if !ok { + if s, ok := block.(string); ok { + return len(s) + } + return 0 + } + + total := 0 + blockType, _ := blockMap["type"].(string) + + if text, ok := blockMap["text"].(string); ok { + total += len(text) + } + + if thinking, ok := blockMap["thinking"].(string); ok { + total += len(thinking) + } + + if blockType == "tool_use" { + if data, err := json.Marshal(blockMap); err == nil { + total += len(data) + } + } + + if blockType == "tool_result" { + if data, err := json.Marshal(blockMap); err == nil { + total += len(data) + } + } + + if source, ok := blockMap["source"].(map[string]any); ok { + if data, ok := source["data"].(string); ok { + total += len(data) + } + } + + return total +} diff --git a/anthropic/anthropic_test.go b/anthropic/anthropic_test.go index 1c2a4a868..2f2717bf0 100755 --- a/anthropic/anthropic_test.go +++ b/anthropic/anthropic_test.go @@ -321,8 +321,6 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) { } } -// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only -// a thinking block (no text, images, or tool calls) are preserved and not dropped. func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) { req := MessagesRequest{ Model: "test-model", @@ -605,7 +603,7 @@ func TestGenerateMessageID(t *testing.T) { } func TestStreamConverter_Basic(t *testing.T) { - conv := NewStreamConverter("msg_123", "test-model") + conv := NewStreamConverter("msg_123", "test-model", 0) // First chunk resp1 := api.ChatResponse{ @@ -678,7 +676,7 @@ func TestStreamConverter_Basic(t *testing.T) { } func TestStreamConverter_WithToolCalls(t *testing.T) { - conv := NewStreamConverter("msg_123", "test-model") + conv := NewStreamConverter("msg_123", "test-model", 0) resp := api.ChatResponse{ Model: "test-model", @@ -731,7 +729,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) { func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) { // Test that unmarshalable arguments (like channels) are handled gracefully // and don't cause a panic or corrupt stream - conv := NewStreamConverter("msg_123", "test-model") + conv := NewStreamConverter("msg_123", "test-model", 0) // Create a channel which cannot be JSON marshaled unmarshalable := make(chan int) @@ -778,7 +776,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) { func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) { // Test that valid tool calls still work when mixed with invalid ones - conv := NewStreamConverter("msg_123", "test-model") + conv := NewStreamConverter("msg_123", "test-model", 0) unmarshalable := make(chan int) badArgs := api.NewToolCallFunctionArguments() @@ -842,10 +840,6 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) { } } -// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields -// are serialized in JSON output. The Anthropic SDK requires these fields to be present -// (even when empty) in content_block_start events to properly accumulate streaming deltas. -// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'" func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) { tests := []struct { name string @@ -899,11 +893,9 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) { } } -// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start -// events include the required empty fields for SDK compatibility. func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) { t.Run("text block start includes empty text", func(t *testing.T) { - conv := NewStreamConverter("msg_123", "test-model") + conv := NewStreamConverter("msg_123", "test-model", 0) resp := api.ChatResponse{ Model: "test-model", @@ -937,7 +929,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) { }) t.Run("thinking block start includes empty thinking", func(t *testing.T) { - conv := NewStreamConverter("msg_123", "test-model") + conv := NewStreamConverter("msg_123", "test-model", 0) resp := api.ChatResponse{ Model: "test-model", @@ -969,3 +961,105 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) { } }) } + +func TestEstimateTokens_SimpleMessage(t *testing.T) { + req := CountTokensRequest{ + Model: "test-model", + Messages: []MessageParam{ + {Role: "user", Content: "Hello, world!"}, + }, + } + + tokens := estimateTokens(req) + + // "user" (4) + "Hello, world!" (13) = 17 chars / 4 = 4 tokens + if tokens < 1 { + t.Errorf("expected at least 1 token, got %d", tokens) + } + // Sanity check: shouldn't be wildly off + if tokens > 10 { + t.Errorf("expected fewer than 10 tokens for short message, got %d", tokens) + } +} + +func TestEstimateTokens_WithSystemPrompt(t *testing.T) { + req := CountTokensRequest{ + Model: "test-model", + System: "You are a helpful assistant.", + Messages: []MessageParam{ + {Role: "user", Content: "Hello"}, + }, + } + + tokens := estimateTokens(req) + + // System prompt adds to count + if tokens < 5 { + t.Errorf("expected at least 5 tokens with system prompt, got %d", tokens) + } +} + +func TestEstimateTokens_WithTools(t *testing.T) { + req := CountTokensRequest{ + Model: "test-model", + Messages: []MessageParam{ + {Role: "user", Content: "What's the weather?"}, + }, + Tools: []Tool{ + { + Name: "get_weather", + Description: "Get the current weather for a location", + InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`), + }, + }, + } + + tokens := estimateTokens(req) + + // Tools add significant content + if tokens < 10 { + t.Errorf("expected at least 10 tokens with tools, got %d", tokens) + } +} + +func TestEstimateTokens_WithThinking(t *testing.T) { + req := CountTokensRequest{ + Model: "test-model", + Messages: []MessageParam{ + {Role: "user", Content: "Hello"}, + { + Role: "assistant", + Content: []any{ + map[string]any{ + "type": "thinking", + "thinking": "Let me think about this carefully...", + }, + map[string]any{ + "type": "text", + "text": "Here is my response.", + }, + }, + }, + }, + } + + tokens := estimateTokens(req) + + // Thinking content should be counted + if tokens < 10 { + t.Errorf("expected at least 10 tokens with thinking content, got %d", tokens) + } +} + +func TestEstimateTokens_EmptyContent(t *testing.T) { + req := CountTokensRequest{ + Model: "test-model", + Messages: []MessageParam{}, + } + + tokens := estimateTokens(req) + + if tokens != 0 { + t.Errorf("expected 0 tokens for empty content, got %d", tokens) + } +} diff --git a/api/client.go b/api/client.go index d70672a6b..a09aa33bc 100644 --- a/api/client.go +++ b/api/client.go @@ -466,3 +466,25 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) { } return &resp, nil } + +// AliasRequest is the request body for creating or updating a model alias. +type AliasRequest struct { + Alias string `json:"alias"` + Target string `json:"target"` + PrefixMatching bool `json:"prefix_matching,omitempty"` +} + +// SetAliasExperimental creates or updates a model alias via the experimental aliases API. +func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error { + return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil) +} + +// AliasDeleteRequest is the request body for deleting a model alias. +type AliasDeleteRequest struct { + Alias string `json:"alias"` +} + +// DeleteAliasExperimental deletes a model alias via the experimental aliases API. +func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error { + return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil) +} diff --git a/cmd/cmd.go b/cmd/cmd.go index 00611a3d1..5497d4e71 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1763,7 +1763,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { return err } if err := startApp(cmd.Context(), client); err != nil { - return fmt.Errorf("ollama server not responding - %w", err) + return err } } return nil diff --git a/cmd/config/claude.go b/cmd/config/claude.go index 80a72f564..1cb0ec907 100644 --- a/cmd/config/claude.go +++ b/cmd/config/claude.go @@ -1,18 +1,23 @@ package config import ( + "context" "fmt" "os" "os/exec" "path/filepath" "runtime" + "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" ) -// Claude implements Runner for Claude Code integration +// Claude implements Runner and AliasConfigurer for Claude Code integration type Claude struct{} +// Compile-time check that Claude implements AliasConfigurer +var _ AliasConfigurer = (*Claude)(nil) + func (c *Claude) String() string { return "Claude Code" } func (c *Claude) args(model string, extra []string) []string { @@ -60,3 +65,104 @@ func (c *Claude) Run(model string, args []string) error { ) return cmd.Run() } + +// ConfigureAliases sets up model aliases for Claude Code. +// model: the model to use (if empty, user will be prompted to select) +// aliases: existing alias configuration to preserve/update +// Cloud-only: subagent routing (fast model) is gated to cloud models only until +// there is a better strategy for prompt caching on local models. +func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) { + aliases := make(map[string]string) + for k, v := range existingAliases { + aliases[k] = v + } + + if model != "" { + aliases["primary"] = model + } + + if !force && aliases["primary"] != "" { + client, _ := api.ClientFromEnvironment() + if isCloudModel(ctx, client, aliases["primary"]) { + if isCloudModel(ctx, client, aliases["fast"]) { + return aliases, false, nil + } + } else { + delete(aliases, "fast") + return aliases, false, nil + } + } + + items, existingModels, cloudModels, client, err := listModels(ctx) + if err != nil { + return nil, false, err + } + + fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset) + + if aliases["primary"] == "" || force { + primary, err := selectPrompt("Select model:", items) + fmt.Fprintf(os.Stderr, "\033[3A\033[J") + if err != nil { + return nil, false, err + } + if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil { + return nil, false, err + } + if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil { + return nil, false, err + } + aliases["primary"] = primary + } + + if isCloudModel(ctx, client, aliases["primary"]) { + if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) { + aliases["fast"] = aliases["primary"] + } + } else { + delete(aliases, "fast") + } + + return aliases, true, nil +} + +// SetAliases syncs the configured aliases to the Ollama server using prefix matching. +// Cloud-only: for local models (fast is empty), we delete any existing aliases to +// prevent stale routing to a previous cloud model. +func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + prefixes := []string{"claude-sonnet-", "claude-haiku-"} + + if aliases["fast"] == "" { + for _, prefix := range prefixes { + _ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix}) + } + return nil + } + + prefixAliases := map[string]string{ + "claude-sonnet-": aliases["primary"], + "claude-haiku-": aliases["fast"], + } + + var errs []string + for prefix, target := range prefixAliases { + req := &api.AliasRequest{ + Alias: prefix, + Target: target, + PrefixMatching: true, + } + if err := client.SetAliasExperimental(ctx, req); err != nil { + errs = append(errs, prefix) + } + } + + if len(errs) > 0 { + return fmt.Errorf("failed to set aliases: %v", errs) + } + return nil +} diff --git a/cmd/config/config.go b/cmd/config/config.go index 5f98bd5ed..6e8845031 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -13,7 +13,8 @@ import ( ) type integration struct { - Models []string `json:"models"` + Models []string `json:"models"` + Aliases map[string]string `json:"aliases,omitempty"` } type config struct { @@ -133,8 +134,16 @@ func saveIntegration(appName string, models []string) error { return err } - cfg.Integrations[strings.ToLower(appName)] = &integration{ - Models: models, + key := strings.ToLower(appName) + existing := cfg.Integrations[key] + var aliases map[string]string + if existing != nil && existing.Aliases != nil { + aliases = existing.Aliases + } + + cfg.Integrations[key] = &integration{ + Models: models, + Aliases: aliases, } return save(cfg) @@ -154,6 +163,29 @@ func loadIntegration(appName string) (*integration, error) { return ic, nil } +func saveAliases(appName string, aliases map[string]string) error { + if appName == "" { + return errors.New("app name cannot be empty") + } + + cfg, err := load() + if err != nil { + return err + } + + key := strings.ToLower(appName) + existing := cfg.Integrations[key] + if existing == nil { + existing = &integration{} + } + + // Replace aliases entirely (not merge) so deletions are persisted + existing.Aliases = aliases + + cfg.Integrations[key] = existing + return save(cfg) +} + func listIntegrations() ([]integration, error) { cfg, err := load() if err != nil { diff --git a/cmd/config/config_cloud_test.go b/cmd/config/config_cloud_test.go new file mode 100644 index 000000000..b1002a54c --- /dev/null +++ b/cmd/config/config_cloud_test.go @@ -0,0 +1,677 @@ +package config + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" +) + +func TestSetAliases_CloudModel(t *testing.T) { + // Test the SetAliases logic by checking the alias map behavior + aliases := map[string]string{ + "primary": "kimi-k2.5:cloud", + "fast": "kimi-k2.5:cloud", + } + + // Verify fast is set (cloud model behavior) + if aliases["fast"] == "" { + t.Error("cloud model should have fast alias set") + } + if aliases["fast"] != aliases["primary"] { + t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"]) + } +} + +func TestSetAliases_LocalModel(t *testing.T) { + aliases := map[string]string{ + "primary": "llama3.2:latest", + } + // Simulate local model behavior: fast should be empty + delete(aliases, "fast") + + if aliases["fast"] != "" { + t.Error("local model should have empty fast alias") + } +} + +func TestSaveAliases_ReplacesNotMerges(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // First save with both primary and fast + initial := map[string]string{ + "primary": "cloud-model", + "fast": "cloud-model", + } + if err := saveAliases("claude", initial); err != nil { + t.Fatalf("failed to save initial aliases: %v", err) + } + + // Verify both are saved + loaded, err := loadIntegration("claude") + if err != nil { + t.Fatalf("failed to load: %v", err) + } + if loaded.Aliases["fast"] != "cloud-model" { + t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"]) + } + + // Now save without fast (simulating switch to local model) + updated := map[string]string{ + "primary": "local-model", + // fast intentionally missing + } + if err := saveAliases("claude", updated); err != nil { + t.Fatalf("failed to save updated aliases: %v", err) + } + + // Verify fast is GONE (not merged/preserved) + loaded, err = loadIntegration("claude") + if err != nil { + t.Fatalf("failed to load after update: %v", err) + } + if loaded.Aliases["fast"] != "" { + t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"]) + } + if loaded.Aliases["primary"] != "local-model" { + t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"]) + } +} + +func TestSaveAliases_PreservesModels(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // First save integration with models + if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil { + t.Fatalf("failed to save integration: %v", err) + } + + // Then update aliases + aliases := map[string]string{"primary": "new-model"} + if err := saveAliases("claude", aliases); err != nil { + t.Fatalf("failed to save aliases: %v", err) + } + + // Verify models are preserved + loaded, err := loadIntegration("claude") + if err != nil { + t.Fatalf("failed to load: %v", err) + } + if len(loaded.Models) != 2 || loaded.Models[0] != "model1" { + t.Errorf("models should be preserved, got %v", loaded.Models) + } +} + +// TestSaveAliases_EmptyMap clears all aliases +func TestSaveAliases_EmptyMap(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Save with aliases + if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil { + t.Fatalf("failed to save: %v", err) + } + + // Save empty map + if err := saveAliases("claude", map[string]string{}); err != nil { + t.Fatalf("failed to save empty: %v", err) + } + + loaded, err := loadIntegration("claude") + if err != nil { + t.Fatalf("failed to load: %v", err) + } + if len(loaded.Aliases) != 0 { + t.Errorf("aliases should be empty, got %v", loaded.Aliases) + } +} + +// TestSaveAliases_NilMap handles nil gracefully +func TestSaveAliases_NilMap(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Save with aliases first + if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil { + t.Fatalf("failed to save: %v", err) + } + + // Save nil map - should clear aliases + if err := saveAliases("claude", nil); err != nil { + t.Fatalf("failed to save nil: %v", err) + } + + loaded, err := loadIntegration("claude") + if err != nil { + t.Fatalf("failed to load: %v", err) + } + if len(loaded.Aliases) > 0 { + t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases) + } +} + +// TestSaveAliases_EmptyAppName returns error +func TestSaveAliases_EmptyAppName(t *testing.T) { + err := saveAliases("", map[string]string{"primary": "model"}) + if err == nil { + t.Error("expected error for empty app name") + } +} + +func TestSaveAliases_CaseInsensitive(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil { + t.Fatalf("failed to save: %v", err) + } + + // Load with different case + loaded, err := loadIntegration("claude") + if err != nil { + t.Fatalf("failed to load: %v", err) + } + if loaded.Aliases["primary"] != "model1" { + t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"]) + } + + // Update with different case + if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil { + t.Fatalf("failed to update: %v", err) + } + + loaded, err = loadIntegration("claude") + if err != nil { + t.Fatalf("failed to load after update: %v", err) + } + if loaded.Aliases["primary"] != "model2" { + t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"]) + } +} + +// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist +func TestSaveAliases_CreatesIntegration(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Save aliases for non-existent integration + if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil { + t.Fatalf("failed to save: %v", err) + } + + loaded, err := loadIntegration("newintegration") + if err != nil { + t.Fatalf("failed to load: %v", err) + } + if loaded.Aliases["primary"] != "model" { + t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"]) + } +} + +func TestConfigureAliases_AliasMap(t *testing.T) { + t.Run("cloud model auto-sets fast to primary", func(t *testing.T) { + aliases := make(map[string]string) + aliases["primary"] = "cloud-model" + + // Simulate cloud model behavior + isCloud := true + if isCloud { + if aliases["fast"] == "" { + aliases["fast"] = aliases["primary"] + } + } + + if aliases["fast"] != "cloud-model" { + t.Errorf("expected fast=cloud-model, got %q", aliases["fast"]) + } + }) + + t.Run("cloud model preserves custom fast", func(t *testing.T) { + aliases := map[string]string{ + "primary": "cloud-model", + "fast": "custom-fast-model", + } + + // Simulate cloud model behavior - should preserve existing fast + isCloud := true + if isCloud { + if aliases["fast"] == "" { + aliases["fast"] = aliases["primary"] + } + } + + if aliases["fast"] != "custom-fast-model" { + t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"]) + } + }) + + t.Run("local model clears fast", func(t *testing.T) { + aliases := map[string]string{ + "primary": "local-model", + "fast": "should-be-cleared", + } + + // Simulate local model behavior + isCloud := false + if !isCloud { + delete(aliases, "fast") + } + + if aliases["fast"] != "" { + t.Errorf("expected fast to be cleared, got %q", aliases["fast"]) + } + }) + + t.Run("switching cloud to local clears fast", func(t *testing.T) { + // Start with cloud config + aliases := map[string]string{ + "primary": "cloud-model", + "fast": "cloud-model", + } + + // Switch to local + aliases["primary"] = "local-model" + isCloud := false + if !isCloud { + delete(aliases, "fast") + } + + if aliases["fast"] != "" { + t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"]) + } + if aliases["primary"] != "local-model" { + t.Errorf("primary should be updated, got %q", aliases["primary"]) + } + }) + + t.Run("switching local to cloud sets fast", func(t *testing.T) { + // Start with local config (no fast) + aliases := map[string]string{ + "primary": "local-model", + } + + // Switch to cloud + aliases["primary"] = "cloud-model" + isCloud := true + if isCloud { + if aliases["fast"] == "" { + aliases["fast"] = aliases["primary"] + } + } + + if aliases["fast"] != "cloud-model" { + t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"]) + } + }) +} + +func TestSetAliases_PrefixMapping(t *testing.T) { + // This tests the expected mapping without needing a real client + aliases := map[string]string{ + "primary": "my-cloud-model", + "fast": "my-fast-model", + } + + expectedMappings := map[string]string{ + "claude-sonnet-": aliases["primary"], + "claude-haiku-": aliases["fast"], + } + + if expectedMappings["claude-sonnet-"] != "my-cloud-model" { + t.Errorf("claude-sonnet- should map to primary") + } + if expectedMappings["claude-haiku-"] != "my-fast-model" { + t.Errorf("claude-haiku- should map to fast") + } +} + +func TestSetAliases_LocalDeletesPrefixes(t *testing.T) { + aliases := map[string]string{ + "primary": "local-model", + // fast is empty/missing - indicates local model + } + + prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"} + + // Verify the logic: when fast is empty, we should delete + if aliases["fast"] != "" { + t.Error("fast should be empty for local model") + } + + // Verify we have the right prefixes to delete + if len(prefixesToDelete) != 2 { + t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete)) + } +} + +// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior +func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Simulate: server fails, config should NOT be saved + serverErr := errors.New("server unavailable") + + if serverErr == nil { + t.Error("config should NOT be saved when server fails") + } +} + +// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update +func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Simulate: server succeeds, config should be saved + var serverErr error + if serverErr != nil { + t.Fatal("server should succeed") + } + + if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil { + t.Fatalf("saveAliases failed: %v", err) + } + + // Verify it was actually saved + loaded, err := loadIntegration("claude") + if err != nil { + t.Fatalf("failed to load: %v", err) + } + if loaded.Aliases["primary"] != "model" { + t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"]) + } +} + +func TestConfigFile_PreservesUnknownFields(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Write config with extra fields + configPath := filepath.Join(tmpDir, ".ollama", "config.json") + os.MkdirAll(filepath.Dir(configPath), 0o755) + + // Note: Our config struct only has Integrations, so top-level unknown fields + // won't be preserved by our current implementation. This test documents that. + initialConfig := `{ + "integrations": { + "claude": { + "models": ["model1"], + "aliases": {"primary": "model1"}, + "unknownField": "should be lost" + } + }, + "topLevelUnknown": "will be lost" +}` + os.WriteFile(configPath, []byte(initialConfig), 0o644) + + // Update aliases + if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil { + t.Fatalf("failed to save: %v", err) + } + + // Read raw file to check + data, _ := os.ReadFile(configPath) + content := string(data) + + // models should be preserved + if !contains(content, "model1") { + t.Error("models should be preserved") + } + + // primary should be updated + if !contains(content, "model2") { + t.Error("primary should be updated to model2") + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +func TestClaudeImplementsAliasConfigurer(t *testing.T) { + c := &Claude{} + var _ AliasConfigurer = c // Compile-time check +} + +func TestModelNameEdgeCases(t *testing.T) { + testCases := []struct { + name string + model string + }{ + {"simple", "llama3.2"}, + {"with tag", "llama3.2:latest"}, + {"with cloud tag", "kimi-k2.5:cloud"}, + {"with namespace", "library/llama3.2"}, + {"with dots", "glm-4.7-flash"}, + {"with numbers", "qwen3:8b"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + aliases := map[string]string{"primary": tc.model} + if err := saveAliases("claude", aliases); err != nil { + t.Fatalf("failed to save model %q: %v", tc.model, err) + } + + loaded, err := loadIntegration("claude") + if err != nil { + t.Fatalf("failed to load: %v", err) + } + if loaded.Aliases["primary"] != tc.model { + t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"]) + } + }) + } +} + +func TestSwitchingScenarios(t *testing.T) { + t.Run("cloud to local removes fast", func(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Initial cloud config + if err := saveAliases("claude", map[string]string{ + "primary": "cloud-model", + "fast": "cloud-model", + }); err != nil { + t.Fatal(err) + } + + // Switch to local (no fast) + if err := saveAliases("claude", map[string]string{ + "primary": "local-model", + }); err != nil { + t.Fatal(err) + } + + loaded, _ := loadIntegration("claude") + if loaded.Aliases["fast"] != "" { + t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"]) + } + if loaded.Aliases["primary"] != "local-model" { + t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"]) + } + }) + + t.Run("local to cloud adds fast", func(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Initial local config + if err := saveAliases("claude", map[string]string{ + "primary": "local-model", + }); err != nil { + t.Fatal(err) + } + + // Switch to cloud (with fast) + if err := saveAliases("claude", map[string]string{ + "primary": "cloud-model", + "fast": "cloud-model", + }); err != nil { + t.Fatal(err) + } + + loaded, _ := loadIntegration("claude") + if loaded.Aliases["fast"] != "cloud-model" { + t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"]) + } + }) + + t.Run("cloud to different cloud updates both", func(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Initial cloud config + if err := saveAliases("claude", map[string]string{ + "primary": "cloud-model-1", + "fast": "cloud-model-1", + }); err != nil { + t.Fatal(err) + } + + // Switch to different cloud + if err := saveAliases("claude", map[string]string{ + "primary": "cloud-model-2", + "fast": "cloud-model-2", + }); err != nil { + t.Fatal(err) + } + + loaded, _ := loadIntegration("claude") + if loaded.Aliases["primary"] != "cloud-model-2" { + t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"]) + } + if loaded.Aliases["fast"] != "cloud-model-2" { + t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"]) + } + }) +} + +func TestToolCapabilityFiltering(t *testing.T) { + t.Run("all models checked for tool capability", func(t *testing.T) { + // Both cloud and local models are checked for tool capability via Show API + // Only models with "tools" in capabilities are included + m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true} + if !m.ToolCapable { + t.Error("tool capable model should be marked as such") + } + }) + + t.Run("modelInfo includes ToolCapable field", func(t *testing.T) { + m := modelInfo{Name: "test", Remote: true, ToolCapable: true} + if !m.ToolCapable { + t.Error("ToolCapable field should be accessible") + } + }) +} + +func TestIsCloudModel_RequiresClient(t *testing.T) { + t.Run("nil client always returns false", func(t *testing.T) { + // isCloudModel now only uses Show API, no suffix detection + if isCloudModel(context.Background(), nil, "model:cloud") { + t.Error("nil client should return false regardless of suffix") + } + if isCloudModel(context.Background(), nil, "local-model") { + t.Error("nil client should return false") + } + }) +} + +func TestModelsAndAliasesMustStayInSync(t *testing.T) { + t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Save aliases with one model + if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil { + t.Fatal(err) + } + + // Save integration with same model (this is the pattern we use) + if err := saveIntegration("claude", []string{"model-a"}); err != nil { + t.Fatal(err) + } + + loaded, _ := loadIntegration("claude") + if loaded.Aliases["primary"] != loaded.Models[0] { + t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0]) + } + }) + + t.Run("out of sync config is detectable", func(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Simulate out-of-sync state (like manual edit or bug) + if err := saveIntegration("claude", []string{"old-model"}); err != nil { + t.Fatal(err) + } + if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil { + t.Fatal(err) + } + + loaded, _ := loadIntegration("claude") + + // They should be different (this is the bug state) + if loaded.Models[0] == loaded.Aliases["primary"] { + t.Error("expected out-of-sync state for this test") + } + + // The fix: when updating aliases, also update models + if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil { + t.Fatal(err) + } + + loaded, _ = loadIntegration("claude") + if loaded.Models[0] != loaded.Aliases["primary"] { + t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)", + loaded.Models[0], loaded.Aliases["primary"]) + } + }) + + t.Run("updating primary alias updates models too", func(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Initial state + if err := saveIntegration("claude", []string{"initial-model"}); err != nil { + t.Fatal(err) + } + if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil { + t.Fatal(err) + } + + // Update aliases AND models together + newAliases := map[string]string{"primary": "updated-model"} + if err := saveAliases("claude", newAliases); err != nil { + t.Fatal(err) + } + if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil { + t.Fatal(err) + } + + loaded, _ := loadIntegration("claude") + if loaded.Models[0] != "updated-model" { + t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0]) + } + if loaded.Aliases["primary"] != "updated-model" { + t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"]) + } + }) +} diff --git a/cmd/config/config_test.go b/cmd/config/config_test.go index ae87c6a40..a491a276f 100644 --- a/cmd/config/config_test.go +++ b/cmd/config/config_test.go @@ -46,6 +46,53 @@ func TestIntegrationConfig(t *testing.T) { } }) + t.Run("save and load aliases", func(t *testing.T) { + models := []string{"llama3.2"} + if err := saveIntegration("claude", models); err != nil { + t.Fatal(err) + } + aliases := map[string]string{ + "primary": "llama3.2:70b", + "fast": "llama3.2:8b", + } + if err := saveAliases("claude", aliases); err != nil { + t.Fatal(err) + } + + config, err := loadIntegration("claude") + if err != nil { + t.Fatal(err) + } + if config.Aliases == nil { + t.Fatal("expected aliases to be saved") + } + for k, v := range aliases { + if config.Aliases[k] != v { + t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k]) + } + } + }) + + t.Run("saveIntegration preserves aliases", func(t *testing.T) { + if err := saveIntegration("claude", []string{"model-a"}); err != nil { + t.Fatal(err) + } + if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil { + t.Fatal(err) + } + + if err := saveIntegration("claude", []string{"model-b"}); err != nil { + t.Fatal(err) + } + config, err := loadIntegration("claude") + if err != nil { + t.Fatal(err) + } + if config.Aliases["primary"] != "model-a" { + t.Errorf("expected aliases to be preserved, got %v", config.Aliases) + } + }) + t.Run("defaultModel returns first model", func(t *testing.T) { saveIntegration("codex", []string{"model-a", "model-b"}) diff --git a/cmd/config/integrations.go b/cmd/config/integrations.go index 714eae625..6be2f1dc2 100644 --- a/cmd/config/integrations.go +++ b/cmd/config/integrations.go @@ -39,6 +39,15 @@ type Editor interface { Models() []string } +// AliasConfigurer can configure model aliases (e.g., for subagent routing). +// Integrations like Claude and Codex use this to route model requests to local models. +type AliasConfigurer interface { + // ConfigureAliases prompts the user to configure aliases and returns the updated map. + ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error) + // SetAliases syncs the configured aliases to the server + SetAliases(ctx context.Context, aliases map[string]string) error +} + // integrations is the registry of available integrations. var integrations = map[string]Runner{ "claude": &Claude{}, @@ -129,7 +138,11 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) { return nil, err } } else { - model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items) + prompt := fmt.Sprintf("Select model for %s:", r) + if _, ok := r.(AliasConfigurer); ok { + prompt = fmt.Sprintf("Select Primary model for %s:", r) + } + model, err := selectPrompt(prompt, items) if err != nil { return nil, err } @@ -157,73 +170,123 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) { } } + if err := ensureAuth(ctx, client, cloudModels, selected); err != nil { + return nil, err + } + + return selected, nil +} + +func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error { + if existingModels[model] { + return nil + } + msg := fmt.Sprintf("Download %s?", model) + if ok, err := confirmPrompt(msg); err != nil { + return err + } else if !ok { + return errCancelled + } + fmt.Fprintf(os.Stderr, "\n") + if err := pullModel(ctx, client, model); err != nil { + return fmt.Errorf("failed to pull %s: %w", model, err) + } + return nil +} + +func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) { + client, err := api.ClientFromEnvironment() + if err != nil { + return nil, nil, nil, nil, err + } + + models, err := client.List(ctx) + if err != nil { + return nil, nil, nil, nil, err + } + + var existing []modelInfo + for _, m := range models.Models { + existing = append(existing, modelInfo{ + Name: m.Name, + Remote: m.RemoteModel != "", + }) + } + + items, _, existingModels, cloudModels := buildModelList(existing, nil, "") + + if len(items) == 0 { + return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull ' first") + } + + return items, existingModels, cloudModels, client, nil +} + +func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error { var selectedCloudModels []string for _, m := range selected { if cloudModels[m] { selectedCloudModels = append(selectedCloudModels, m) } } - if len(selectedCloudModels) > 0 { - // ensure user is signed in - user, err := client.Whoami(ctx) - if err == nil && user != nil && user.Name != "" { - return selected, nil - } + if len(selectedCloudModels) == 0 { + return nil + } - var aErr api.AuthorizationError - if !errors.As(err, &aErr) || aErr.SigninURL == "" { - return nil, err - } + user, err := client.Whoami(ctx) + if err == nil && user != nil && user.Name != "" { + return nil + } - modelList := strings.Join(selectedCloudModels, ", ") - yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList)) - if err != nil || !yes { - return nil, fmt.Errorf("%s requires sign in", modelList) - } + var aErr api.AuthorizationError + if !errors.As(err, &aErr) || aErr.SigninURL == "" { + return err + } - fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL) + modelList := strings.Join(selectedCloudModels, ", ") + yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList)) + if err != nil || !yes { + return fmt.Errorf("%s requires sign in", modelList) + } - // TODO(parthsareen): extract into auth package for cmd - // Auto-open browser (best effort, fail silently) - switch runtime.GOOS { - case "darwin": - _ = exec.Command("open", aErr.SigninURL).Start() - case "linux": - _ = exec.Command("xdg-open", aErr.SigninURL).Start() - case "windows": - _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start() - } + fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL) - spinnerFrames := []string{"|", "/", "-", "\\"} - frame := 0 + switch runtime.GOOS { + case "darwin": + _ = exec.Command("open", aErr.SigninURL).Start() + case "linux": + _ = exec.Command("xdg-open", aErr.SigninURL).Start() + case "windows": + _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start() + } - fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0]) + spinnerFrames := []string{"|", "/", "-", "\\"} + frame := 0 - ticker := time.NewTicker(200 * time.Millisecond) - defer ticker.Stop() + fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0]) - for { - select { - case <-ctx.Done(): - fmt.Fprintf(os.Stderr, "\r\033[K") - return nil, ctx.Err() - case <-ticker.C: - frame++ - fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)]) + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() - // poll every 10th frame (~2 seconds) - if frame%10 == 0 { - u, err := client.Whoami(ctx) - if err == nil && u != nil && u.Name != "" { - fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name) - return selected, nil - } + for { + select { + case <-ctx.Done(): + fmt.Fprintf(os.Stderr, "\r\033[K") + return ctx.Err() + case <-ticker.C: + frame++ + fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)]) + + // poll every 10th frame (~2 seconds) + if frame%10 == 0 { + u, err := client.Whoami(ctx) + if err == nil && u != nil && u.Name != "" { + fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name) + return nil } } } } - - return selected, nil } func runIntegration(name, modelName string, args []string) error { @@ -231,10 +294,33 @@ func runIntegration(name, modelName string, args []string) error { if !ok { return fmt.Errorf("unknown integration: %s", name) } + fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName) return r.Run(modelName, args) } +// syncAliases syncs aliases to server and saves locally for an AliasConfigurer. +func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, name, model string, existing map[string]string) error { + aliases := make(map[string]string) + for k, v := range existing { + aliases[k] = v + } + aliases["primary"] = model + + if isCloudModel(ctx, client, model) { + if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) { + aliases["fast"] = model + } + } else { + delete(aliases, "fast") + } + + if err := ac.SetAliases(ctx, aliases); err != nil { + return err + } + return saveAliases(name, aliases) +} + // LaunchCmd returns the cobra command for launching integrations. func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command { var modelFlag string @@ -302,9 +388,87 @@ Examples: return fmt.Errorf("unknown integration: %s", name) } - if !configFlag && modelFlag == "" { - if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 { - return runIntegration(name, config.Models[0], passArgs) + // Handle AliasConfigurer integrations (claude, codex) + if ac, ok := r.(AliasConfigurer); ok { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + // Validate --model flag if provided + if modelFlag != "" { + if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil { + return fmt.Errorf("model %q not found", modelFlag) + } + } + + var model string + var existingAliases map[string]string + + // Load saved config + if cfg, err := loadIntegration(name); err == nil { + existingAliases = cfg.Aliases + if len(cfg.Models) > 0 { + model = cfg.Models[0] + // AliasConfigurer integrations use single model; sanitize if multiple + if len(cfg.Models) > 1 { + _ = saveIntegration(name, []string{model}) + } + } + } + + // --model flag overrides saved model + if modelFlag != "" { + model = modelFlag + } + + // Validate saved model still exists + if model != "" && modelFlag == "" { + if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}); err != nil { + fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset) + model = "" + } + } + + // If no valid model or --config flag, show picker + if model == "" || configFlag { + aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag) + if errors.Is(err, errCancelled) { + return nil + } + if err != nil { + return err + } + model = aliases["primary"] + existingAliases = aliases + } + + // Sync aliases and save + if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil { + fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset) + } + if err := saveIntegration(name, []string{model}); err != nil { + return fmt.Errorf("failed to save: %w", err) + } + + // Launch (unless --config without confirmation) + if configFlag { + if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch { + return runIntegration(name, model, passArgs) + } + return nil + } + return runIntegration(name, model, passArgs) + } + + // Validate --model flag for non-AliasConfigurer integrations + if modelFlag != "" { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil { + return fmt.Errorf("model %q not found", modelFlag) } } @@ -380,8 +544,9 @@ Examples: } type modelInfo struct { - Name string - Remote bool + Name string + Remote bool + ToolCapable bool } // buildModelList merges existing models with recommendations, sorts them, and returns @@ -418,7 +583,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) ( continue } items = append(items, rec) - if isCloudModel(rec.Name) { + if strings.HasSuffix(rec.Name, ":cloud") { cloudModels[rec.Name] = true } } @@ -478,8 +643,16 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) ( return items, preChecked, existingModels, cloudModels } -func isCloudModel(name string) bool { - return strings.HasSuffix(name, ":cloud") +// isCloudModel checks if a model is a cloud model using the Show API. +func isCloudModel(ctx context.Context, client *api.Client, name string) bool { + if client == nil { + return false + } + resp, err := client.Show(ctx, &api.ShowRequest{Name: name}) + if err != nil { + return false + } + return resp.RemoteModel != "" } func pullModel(ctx context.Context, client *api.Client, model string) error { diff --git a/cmd/config/integrations_test.go b/cmd/config/integrations_test.go index dd2056e98..b14906db5 100644 --- a/cmd/config/integrations_test.go +++ b/cmd/config/integrations_test.go @@ -1,6 +1,7 @@ package config import ( + "context" "fmt" "slices" "strings" @@ -297,24 +298,15 @@ func TestParseArgs(t *testing.T) { } func TestIsCloudModel(t *testing.T) { - tests := []struct { - name string - want bool - }{ - {"glm-4.7:cloud", true}, - {"kimi-k2.5:cloud", true}, - {"glm-4.7-flash", false}, - {"glm-4.7-flash:latest", false}, - {"cloud-model", false}, - {"model:cloudish", false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := isCloudModel(tt.name); got != tt.want { - t.Errorf("isCloudModel(%q) = %v, want %v", tt.name, got, tt.want) + // isCloudModel now only uses Show API, so nil client always returns false + t.Run("nil client returns false", func(t *testing.T) { + models := []string{"glm-4.7:cloud", "kimi-k2.5:cloud", "local-model"} + for _, model := range models { + if isCloudModel(context.Background(), nil, model) { + t.Errorf("isCloudModel(%q) with nil client should return false", model) } - }) - } + } + }) } func names(items []selectItem) []string { @@ -509,3 +501,19 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) { t.Error("llama3.2 should not be in cloudModels") } } + +func TestAliasConfigurerInterface(t *testing.T) { + t.Run("claude implements AliasConfigurer", func(t *testing.T) { + claude := &Claude{} + if _, ok := interface{}(claude).(AliasConfigurer); !ok { + t.Error("Claude should implement AliasConfigurer") + } + }) + + t.Run("codex does not implement AliasConfigurer", func(t *testing.T) { + codex := &Codex{} + if _, ok := interface{}(codex).(AliasConfigurer); ok { + t.Error("Codex should not implement AliasConfigurer") + } + }) +} diff --git a/cmd/config/openclaw.go b/cmd/config/openclaw.go index a1e4a537d..3cf025a56 100644 --- a/cmd/config/openclaw.go +++ b/cmd/config/openclaw.go @@ -17,8 +17,6 @@ type Openclaw struct{} func (c *Openclaw) String() string { return "OpenClaw" } -const ansiGreen = "\033[32m" - func (c *Openclaw) Run(model string, args []string) error { bin := "openclaw" if _, err := exec.LookPath(bin); err != nil { diff --git a/cmd/config/selector.go b/cmd/config/selector.go index 956e1f1ea..4bad85948 100644 --- a/cmd/config/selector.go +++ b/cmd/config/selector.go @@ -17,6 +17,7 @@ const ( ansiBold = "\033[1m" ansiReset = "\033[0m" ansiGray = "\033[37m" + ansiGreen = "\033[32m" ansiClearDown = "\033[J" ) diff --git a/cmd/config/selector_test.go b/cmd/config/selector_test.go index 74e8796ee..39557a535 100644 --- a/cmd/config/selector_test.go +++ b/cmd/config/selector_test.go @@ -96,6 +96,14 @@ func TestSelectState(t *testing.T) { } }) + t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) { + s := newSelectState([]selectItem{}) + done, result, err := s.handleInput(eventEnter, 0) + if done || result != "" || err != nil { + t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err) + } + }) + t.Run("Escape_ReturnsCancelledError", func(t *testing.T) { s := newSelectState(items) done, result, err := s.handleInput(eventEscape, 0) @@ -574,8 +582,19 @@ func TestRenderSelect(t *testing.T) { var buf bytes.Buffer renderSelect(&buf, "Select:", s) + output := buf.String() + if !strings.Contains(output, "no matches") { + t.Errorf("expected 'no matches' message, got: %s", output) + } + }) + + t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) { + s := newSelectState([]selectItem{}) + var buf bytes.Buffer + renderSelect(&buf, "Select:", s) + if !strings.Contains(buf.String(), "no matches") { - t.Error("expected 'no matches' message") + t.Error("expected 'no matches' message for empty list with no filter") } }) diff --git a/cmd/start_darwin.go b/cmd/start_darwin.go index 05a4551e1..008adf15e 100644 --- a/cmd/start_darwin.go +++ b/cmd/start_darwin.go @@ -10,19 +10,21 @@ import ( "github.com/ollama/ollama/api" ) +var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it") + func startApp(ctx context.Context, client *api.Client) error { exe, err := os.Executable() if err != nil { - return err + return errNotRunning } link, err := os.Readlink(exe) if err != nil { - return err + return errNotRunning } r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`) m := r.FindStringSubmatch(link) if len(m) != 1 { - return errors.New("could not find ollama app") + return errNotRunning } if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil { return err diff --git a/middleware/anthropic.go b/middleware/anthropic.go index ff55b6ebf..5df87a84a 100644 --- a/middleware/anthropic.go +++ b/middleware/anthropic.go @@ -131,12 +131,15 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc { messageID := anthropic.GenerateMessageID() + // Estimate input tokens for streaming (actual count not available until generation completes) + estimatedTokens := anthropic.EstimateInputTokens(req) + w := &AnthropicWriter{ BaseWriter: BaseWriter{ResponseWriter: c.Writer}, stream: req.Stream, id: messageID, model: req.Model, - converter: anthropic.NewStreamConverter(messageID, req.Model), + converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens), } if req.Stream { diff --git a/server/aliases.go b/server/aliases.go new file mode 100644 index 000000000..727b3f2f6 --- /dev/null +++ b/server/aliases.go @@ -0,0 +1,422 @@ +package server + +import ( + "encoding/json" + "errors" + "fmt" + "log/slog" + "os" + "path/filepath" + "sort" + "strings" + "sync" + + "github.com/ollama/ollama/manifest" + "github.com/ollama/ollama/types/model" +) + +const ( + serverConfigFilename = "server.json" + serverConfigVersion = 1 +) + +var errAliasCycle = errors.New("alias cycle detected") + +type aliasEntry struct { + Alias string `json:"alias"` + Target string `json:"target"` + PrefixMatching bool `json:"prefix_matching,omitempty"` +} + +type serverConfig struct { + Version int `json:"version"` + Aliases []aliasEntry `json:"aliases"` +} + +type store struct { + mu sync.RWMutex + path string + entries map[string]aliasEntry // normalized alias -> entry (exact matches) + prefixEntries []aliasEntry // prefix matches, sorted longest-first +} + +func createStore(path string) (*store, error) { + store := &store{ + path: path, + entries: make(map[string]aliasEntry), + } + if err := store.load(); err != nil { + return nil, err + } + return store, nil +} + +func (s *store) load() error { + data, err := os.ReadFile(s.path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + + var cfg serverConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return err + } + + if cfg.Version != 0 && cfg.Version != serverConfigVersion { + return fmt.Errorf("unsupported router config version %d", cfg.Version) + } + + for _, entry := range cfg.Aliases { + targetName := model.ParseName(entry.Target) + if !targetName.IsValid() { + slog.Warn("invalid alias target in router config", "target", entry.Target) + continue + } + canonicalTarget := displayAliasName(targetName) + + if entry.PrefixMatching { + // Prefix aliases don't need to be valid model names + alias := strings.TrimSpace(entry.Alias) + if alias == "" { + slog.Warn("empty prefix alias in router config") + continue + } + s.prefixEntries = append(s.prefixEntries, aliasEntry{ + Alias: alias, + Target: canonicalTarget, + PrefixMatching: true, + }) + } else { + aliasName := model.ParseName(entry.Alias) + if !aliasName.IsValid() { + slog.Warn("invalid alias name in router config", "alias", entry.Alias) + continue + } + canonicalAlias := displayAliasName(aliasName) + s.entries[normalizeAliasKey(aliasName)] = aliasEntry{ + Alias: canonicalAlias, + Target: canonicalTarget, + } + } + } + + // Sort prefix entries by alias length descending (longest prefix wins) + s.sortPrefixEntriesLocked() + + return nil +} + +func (s *store) saveLocked() error { + dir := filepath.Dir(s.path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + + // Combine exact and prefix entries + entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries)) + for _, entry := range s.entries { + entries = append(entries, entry) + } + entries = append(entries, s.prefixEntries...) + + sort.Slice(entries, func(i, j int) bool { + return strings.Compare(entries[i].Alias, entries[j].Alias) < 0 + }) + + cfg := serverConfig{ + Version: serverConfigVersion, + Aliases: entries, + } + + f, err := os.CreateTemp(dir, "router-*.json") + if err != nil { + return err + } + + enc := json.NewEncoder(f) + enc.SetIndent("", " ") + if err := enc.Encode(cfg); err != nil { + _ = f.Close() + _ = os.Remove(f.Name()) + return err + } + + if err := f.Close(); err != nil { + _ = os.Remove(f.Name()) + return err + } + + if err := os.Chmod(f.Name(), 0o644); err != nil { + _ = os.Remove(f.Name()) + return err + } + + return os.Rename(f.Name(), s.path) +} + +func (s *store) ResolveName(name model.Name) (model.Name, bool, error) { + // If a local model exists, do not allow alias shadowing (highest priority). + exists, err := localModelExists(name) + if err != nil { + return name, false, err + } + if exists { + return name, false, nil + } + + key := normalizeAliasKey(name) + + s.mu.RLock() + entry, exactMatch := s.entries[key] + var prefixMatch *aliasEntry + if !exactMatch { + // Try prefix matching - prefixEntries is sorted longest-first + nameStr := strings.ToLower(displayAliasName(name)) + for i := range s.prefixEntries { + prefix := strings.ToLower(s.prefixEntries[i].Alias) + if strings.HasPrefix(nameStr, prefix) { + prefixMatch = &s.prefixEntries[i] + break // First match is longest due to sorting + } + } + } + s.mu.RUnlock() + + if !exactMatch && prefixMatch == nil { + return name, false, nil + } + + var current string + var visited map[string]struct{} + + if exactMatch { + visited = map[string]struct{}{key: {}} + current = entry.Target + } else { + // For prefix match, use the target as-is + visited = map[string]struct{}{} + current = prefixMatch.Target + } + + targetKey := normalizeAliasKeyString(current) + + for { + targetName := model.ParseName(current) + if !targetName.IsValid() { + return name, false, fmt.Errorf("alias target %q is invalid", current) + } + + if _, seen := visited[targetKey]; seen { + return name, false, errAliasCycle + } + visited[targetKey] = struct{}{} + + s.mu.RLock() + next, ok := s.entries[targetKey] + s.mu.RUnlock() + if !ok { + return targetName, true, nil + } + + current = next.Target + targetKey = normalizeAliasKeyString(current) + } +} + +func (s *store) Set(alias, target model.Name, prefixMatching bool) error { + targetKey := normalizeAliasKey(target) + + s.mu.Lock() + defer s.mu.Unlock() + + if prefixMatching { + // For prefix aliases, we skip cycle detection since prefix matching + // works differently and the target is a specific model + aliasStr := displayAliasName(alias) + + // Remove any existing prefix entry with the same alias + for i, e := range s.prefixEntries { + if strings.EqualFold(e.Alias, aliasStr) { + s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...) + break + } + } + + s.prefixEntries = append(s.prefixEntries, aliasEntry{ + Alias: aliasStr, + Target: displayAliasName(target), + PrefixMatching: true, + }) + s.sortPrefixEntriesLocked() + return s.saveLocked() + } + + aliasKey := normalizeAliasKey(alias) + + if aliasKey == targetKey { + return fmt.Errorf("alias cannot point to itself") + } + + visited := map[string]struct{}{aliasKey: {}} + currentKey := targetKey + for { + if _, seen := visited[currentKey]; seen { + return errAliasCycle + } + visited[currentKey] = struct{}{} + + next, ok := s.entries[currentKey] + if !ok { + break + } + currentKey = normalizeAliasKeyString(next.Target) + } + + s.entries[aliasKey] = aliasEntry{ + Alias: displayAliasName(alias), + Target: displayAliasName(target), + } + + return s.saveLocked() +} + +func (s *store) Delete(alias model.Name) (bool, error) { + aliasKey := normalizeAliasKey(alias) + + s.mu.Lock() + defer s.mu.Unlock() + + // Try exact match first + if _, ok := s.entries[aliasKey]; ok { + delete(s.entries, aliasKey) + return true, s.saveLocked() + } + + // Try prefix entries + aliasStr := displayAliasName(alias) + for i, e := range s.prefixEntries { + if strings.EqualFold(e.Alias, aliasStr) { + s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...) + return true, s.saveLocked() + } + } + + return false, nil +} + +// DeleteByString deletes an alias by its raw string value, useful for prefix +// aliases that may not be valid model names. +func (s *store) DeleteByString(alias string) (bool, error) { + alias = strings.TrimSpace(alias) + aliasLower := strings.ToLower(alias) + + s.mu.Lock() + defer s.mu.Unlock() + + // Try prefix entries first (since this is mainly for prefix aliases) + for i, e := range s.prefixEntries { + if strings.EqualFold(e.Alias, alias) { + s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...) + return true, s.saveLocked() + } + } + + // Also check exact entries by normalized key + if _, ok := s.entries[aliasLower]; ok { + delete(s.entries, aliasLower) + return true, s.saveLocked() + } + + return false, nil +} + +func (s *store) List() []aliasEntry { + s.mu.RLock() + defer s.mu.RUnlock() + + entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries)) + for _, entry := range s.entries { + entries = append(entries, entry) + } + entries = append(entries, s.prefixEntries...) + + sort.Slice(entries, func(i, j int) bool { + return strings.Compare(entries[i].Alias, entries[j].Alias) < 0 + }) + return entries +} + +func normalizeAliasKey(name model.Name) string { + return strings.ToLower(displayAliasName(name)) +} + +func (s *store) sortPrefixEntriesLocked() { + sort.Slice(s.prefixEntries, func(i, j int) bool { + // Sort by length descending (longest prefix first) + return len(s.prefixEntries[i].Alias) > len(s.prefixEntries[j].Alias) + }) +} + +func normalizeAliasKeyString(value string) string { + n := model.ParseName(value) + if !n.IsValid() { + return strings.ToLower(strings.TrimSpace(value)) + } + return normalizeAliasKey(n) +} + +func displayAliasName(n model.Name) string { + display := n.DisplayShortest() + if strings.EqualFold(n.Tag, "latest") { + if idx := strings.LastIndex(display, ":"); idx != -1 { + return display[:idx] + } + } + return display +} + +func localModelExists(name model.Name) (bool, error) { + manifests, err := manifest.Manifests(true) + if err != nil { + return false, err + } + needle := name.String() + for existing := range manifests { + if strings.EqualFold(existing.String(), needle) { + return true, nil + } + } + return false, nil +} + +func serverConfigPath() string { + home, err := os.UserHomeDir() + if err != nil { + return filepath.Join(".ollama", serverConfigFilename) + } + return filepath.Join(home, ".ollama", serverConfigFilename) +} + +func (s *Server) aliasStore() (*store, error) { + s.aliasesOnce.Do(func() { + s.aliases, s.aliasesErr = createStore(serverConfigPath()) + }) + + return s.aliases, s.aliasesErr +} + +func (s *Server) resolveAlias(name model.Name) (model.Name, bool, error) { + store, err := s.aliasStore() + if err != nil { + return name, false, err + } + + if store == nil { + return name, false, nil + } + + return store.ResolveName(name) +} diff --git a/server/routes.go b/server/routes.go index 910b8e954..ffc3be015 100644 --- a/server/routes.go +++ b/server/routes.go @@ -22,6 +22,7 @@ import ( "os/signal" "slices" "strings" + "sync" "sync/atomic" "syscall" "time" @@ -81,6 +82,9 @@ type Server struct { addr net.Addr sched *Scheduler defaultNumCtx int + aliasesOnce sync.Once + aliases *store + aliasesErr error } func init() { @@ -191,9 +195,16 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + resolvedName, _, err := s.resolveAlias(name) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + name = resolvedName + // We cannot currently consolidate this into GetModel because all we'll // induce infinite recursion given the current code structure. - name, err := getExistingName(name) + name, err = getExistingName(name) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) return @@ -1580,6 +1591,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/blobs/:digest", s.CreateBlobHandler) r.HEAD("/api/blobs/:digest", s.HeadBlobHandler) r.POST("/api/copy", s.CopyHandler) + r.GET("/api/experimental/aliases", s.ListAliasesHandler) + r.POST("/api/experimental/aliases", s.CreateAliasHandler) + r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler) // Inference r.GET("/api/ps", s.PsHandler) @@ -1950,13 +1964,20 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - name, err := getExistingName(name) + resolvedName, _, err := s.resolveAlias(name) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + name = resolvedName + + name, err = getExistingName(name) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) return } - m, err := GetModel(req.Model) + m, err := GetModel(name.String()) if err != nil { switch { case os.IsNotExist(err): diff --git a/server/routes_aliases.go b/server/routes_aliases.go new file mode 100644 index 000000000..d68514e9c --- /dev/null +++ b/server/routes_aliases.go @@ -0,0 +1,159 @@ +package server + +import ( + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + + "github.com/ollama/ollama/types/model" +) + +type aliasListResponse struct { + Aliases []aliasEntry `json:"aliases"` +} + +type aliasDeleteRequest struct { + Alias string `json:"alias"` +} + +func (s *Server) ListAliasesHandler(c *gin.Context) { + store, err := s.aliasStore() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + var aliases []aliasEntry + if store != nil { + aliases = store.List() + } + + c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases}) +} + +func (s *Server) CreateAliasHandler(c *gin.Context) { + var req aliasEntry + if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + } else if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + req.Alias = strings.TrimSpace(req.Alias) + req.Target = strings.TrimSpace(req.Target) + if req.Alias == "" || req.Target == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"}) + return + } + + // Target must always be a valid model name + targetName := model.ParseName(req.Target) + if !targetName.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)}) + return + } + + var aliasName model.Name + if req.PrefixMatching { + // For prefix aliases, we still parse the alias to normalize it, + // but we allow any non-empty string since prefix patterns may not be valid model names + aliasName = model.ParseName(req.Alias) + // Even if not valid as a model name, we accept it for prefix matching + } else { + aliasName = model.ParseName(req.Alias) + if !aliasName.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)}) + return + } + + if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"}) + return + } + + exists, err := localModelExists(aliasName) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if exists { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)}) + return + } + } + + store, err := s.aliasStore() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if err := store.Set(aliasName, targetName, req.PrefixMatching); err != nil { + status := http.StatusInternalServerError + if errors.Is(err, errAliasCycle) { + status = http.StatusBadRequest + } + c.AbortWithStatusJSON(status, gin.H{"error": err.Error()}) + return + } + + resp := aliasEntry{ + Alias: displayAliasName(aliasName), + Target: displayAliasName(targetName), + PrefixMatching: req.PrefixMatching, + } + if req.PrefixMatching && !aliasName.IsValid() { + // For prefix aliases that aren't valid model names, use the raw alias + resp.Alias = req.Alias + } + c.JSON(http.StatusOK, resp) +} + +func (s *Server) DeleteAliasHandler(c *gin.Context) { + var req aliasDeleteRequest + if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + } else if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + req.Alias = strings.TrimSpace(req.Alias) + if req.Alias == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"}) + return + } + + store, err := s.aliasStore() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + aliasName := model.ParseName(req.Alias) + var deleted bool + if aliasName.IsValid() { + deleted, err = store.Delete(aliasName) + } else { + // For invalid model names (like prefix aliases), try deleting by raw string + deleted, err = store.DeleteByString(req.Alias) + } + + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if !deleted { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)}) + return + } + + c.JSON(http.StatusOK, gin.H{"deleted": true}) +} diff --git a/server/routes_aliases_test.go b/server/routes_aliases_test.go new file mode 100644 index 000000000..e31529996 --- /dev/null +++ b/server/routes_aliases_test.go @@ -0,0 +1,426 @@ +package server + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "path/filepath" + "testing" + + "github.com/gin-gonic/gin" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/types/model" +) + +func TestAliasShadowingRejected(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("HOME", t.TempDir()) + + s := Server{} + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "shadowed-model", + RemoteHost: "example.com", + From: "test", + Info: map[string]any{ + "capabilities": []string{"completion"}, + }, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"}) + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestAliasResolvesForChatRemote(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("HOME", t.TempDir()) + + var remoteModel string + rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req api.ChatRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatal(err) + } + remoteModel = req.Model + + w.Header().Set("Content-Type", "application/json") + resp := api.ChatResponse{ + Model: req.Model, + Done: true, + DoneReason: "load", + } + if err := json.NewEncoder(w).Encode(&resp); err != nil { + t.Fatal(err) + } + })) + defer rs.Close() + + p, err := url.Parse(rs.URL) + if err != nil { + t.Fatal(err) + } + + t.Setenv("OLLAMA_REMOTES", p.Hostname()) + + s := Server{} + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "target-model", + RemoteHost: rs.URL, + From: "test", + Info: map[string]any{ + "capabilities": []string{"completion"}, + }, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"}) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + w = createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "alias-model", + Messages: []api.Message{{Role: "user", Content: "hi"}}, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + var resp api.ChatResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + + if resp.Model != "alias-model" { + t.Fatalf("expected response model to be alias-model, got %q", resp.Model) + } + + if remoteModel != "test" { + t.Fatalf("expected remote model to be 'test', got %q", remoteModel) + } +} + +func TestPrefixAliasBasicMatching(t *testing.T) { + tmpDir := t.TempDir() + store, err := createStore(filepath.Join(tmpDir, "server.json")) + if err != nil { + t.Fatal(err) + } + + // Create a prefix alias: "myprefix-" -> "targetmodel" + targetName := model.ParseName("targetmodel") + + // Set a prefix alias (using "myprefix-" as the pattern) + store.mu.Lock() + store.prefixEntries = append(store.prefixEntries, aliasEntry{ + Alias: "myprefix-", + Target: "targetmodel", + PrefixMatching: true, + }) + store.mu.Unlock() + + // Test that "myprefix-foo" resolves to "targetmodel" + testName := model.ParseName("myprefix-foo") + resolved, wasResolved, err := store.ResolveName(testName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved") + } + if resolved.DisplayShortest() != targetName.DisplayShortest() { + t.Fatalf("expected resolved name to be %q, got %q", targetName.DisplayShortest(), resolved.DisplayShortest()) + } + + // Test that "otherprefix-foo" does not resolve + otherName := model.ParseName("otherprefix-foo") + _, wasResolved, err = store.ResolveName(otherName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if wasResolved { + t.Fatal("expected name not to be resolved") + } + + // Test that exact alias takes precedence + exactAlias := model.ParseName("myprefix-exact") + exactTarget := model.ParseName("exacttarget") + if err := store.Set(exactAlias, exactTarget, false); err != nil { + t.Fatal(err) + } + + resolved, wasResolved, err = store.ResolveName(exactAlias) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved") + } + if resolved.DisplayShortest() != exactTarget.DisplayShortest() { + t.Fatalf("expected resolved name to be %q (exact match), got %q", exactTarget.DisplayShortest(), resolved.DisplayShortest()) + } +} + +func TestPrefixAliasLongestMatchWins(t *testing.T) { + tmpDir := t.TempDir() + store, err := createStore(filepath.Join(tmpDir, "server.json")) + if err != nil { + t.Fatal(err) + } + + // Add two prefix aliases with overlapping patterns + store.mu.Lock() + store.prefixEntries = []aliasEntry{ + {Alias: "abc-", Target: "short-target", PrefixMatching: true}, + {Alias: "abc-def-", Target: "long-target", PrefixMatching: true}, + } + store.sortPrefixEntriesLocked() + store.mu.Unlock() + + // "abc-def-ghi" should match the longer prefix "abc-def-" + testName := model.ParseName("abc-def-ghi") + resolved, wasResolved, err := store.ResolveName(testName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved") + } + expectedLongTarget := model.ParseName("long-target") + if resolved.DisplayShortest() != expectedLongTarget.DisplayShortest() { + t.Fatalf("expected resolved name to be %q (longest prefix match), got %q", expectedLongTarget.DisplayShortest(), resolved.DisplayShortest()) + } + + // "abc-xyz" should match the shorter prefix "abc-" + testName2 := model.ParseName("abc-xyz") + resolved, wasResolved, err = store.ResolveName(testName2) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved") + } + expectedShortTarget := model.ParseName("short-target") + if resolved.DisplayShortest() != expectedShortTarget.DisplayShortest() { + t.Fatalf("expected resolved name to be %q, got %q", expectedShortTarget.DisplayShortest(), resolved.DisplayShortest()) + } +} + +func TestPrefixAliasChain(t *testing.T) { + tmpDir := t.TempDir() + store, err := createStore(filepath.Join(tmpDir, "server.json")) + if err != nil { + t.Fatal(err) + } + + // Create a chain: prefix "test-" -> "intermediate" -> "final" + intermediate := model.ParseName("intermediate") + final := model.ParseName("final") + + // Add prefix alias + store.mu.Lock() + store.prefixEntries = []aliasEntry{ + {Alias: "test-", Target: "intermediate", PrefixMatching: true}, + } + store.mu.Unlock() + + // Add exact alias for the intermediate step + if err := store.Set(intermediate, final, false); err != nil { + t.Fatal(err) + } + + // "test-foo" should resolve through the chain to "final" + testName := model.ParseName("test-foo") + resolved, wasResolved, err := store.ResolveName(testName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved") + } + if resolved.DisplayShortest() != final.DisplayShortest() { + t.Fatalf("expected resolved name to be %q, got %q", final.DisplayShortest(), resolved.DisplayShortest()) + } +} + +func TestPrefixAliasCRUD(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("HOME", t.TempDir()) + + s := Server{} + + // Create a prefix alias via API + w := createRequest(t, s.CreateAliasHandler, aliasEntry{ + Alias: "myprefix-", + Target: "llama2", + PrefixMatching: true, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var createResp aliasEntry + if err := json.NewDecoder(w.Body).Decode(&createResp); err != nil { + t.Fatal(err) + } + if !createResp.PrefixMatching { + t.Fatal("expected prefix_matching to be true in response") + } + + // List aliases and verify the prefix alias is included + w = createRequest(t, s.ListAliasesHandler, nil) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + var listResp aliasListResponse + if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil { + t.Fatal(err) + } + + found := false + for _, a := range listResp.Aliases { + if a.PrefixMatching && a.Target == "llama2" { + found = true + break + } + } + if !found { + t.Fatal("expected to find prefix alias in list") + } + + // Delete the prefix alias + w = createRequest(t, s.DeleteAliasHandler, aliasDeleteRequest{Alias: "myprefix-"}) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify it's deleted + w = createRequest(t, s.ListAliasesHandler, nil) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil { + t.Fatal(err) + } + + for _, a := range listResp.Aliases { + if a.PrefixMatching { + t.Fatal("expected prefix alias to be deleted") + } + } +} + +func TestPrefixAliasCaseInsensitive(t *testing.T) { + tmpDir := t.TempDir() + store, err := createStore(filepath.Join(tmpDir, "server.json")) + if err != nil { + t.Fatal(err) + } + + // Add a prefix alias with mixed case + store.mu.Lock() + store.prefixEntries = []aliasEntry{ + {Alias: "MyPrefix-", Target: "targetmodel", PrefixMatching: true}, + } + store.mu.Unlock() + + // Test that matching is case-insensitive + testName := model.ParseName("myprefix-foo") + resolved, wasResolved, err := store.ResolveName(testName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved (case-insensitive)") + } + expectedTarget := model.ParseName("targetmodel") + if resolved.DisplayShortest() != expectedTarget.DisplayShortest() { + t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest()) + } + + // Test uppercase request + testName2 := model.ParseName("MYPREFIX-BAR") + _, wasResolved, err = store.ResolveName(testName2) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected name to be resolved (uppercase)") + } +} + +func TestPrefixAliasLocalModelPrecedence(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("HOME", t.TempDir()) + + s := Server{} + + // Create a local model that would match a prefix alias + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "myprefix-localmodel", + RemoteHost: "example.com", + From: "test", + Info: map[string]any{ + "capabilities": []string{"completion"}, + }, + Stream: &stream, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Create a prefix alias that would match the local model name + w = createRequest(t, s.CreateAliasHandler, aliasEntry{ + Alias: "myprefix-", + Target: "someothermodel", + PrefixMatching: true, + }) + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify that resolving "myprefix-localmodel" returns the local model, not the alias target + store, err := s.aliasStore() + if err != nil { + t.Fatal(err) + } + + localModelName := model.ParseName("myprefix-localmodel") + resolved, wasResolved, err := store.ResolveName(localModelName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if wasResolved { + t.Fatalf("expected local model to take precedence (wasResolved should be false), but got resolved to %q", resolved.DisplayShortest()) + } + if resolved.DisplayShortest() != localModelName.DisplayShortest() { + t.Fatalf("expected resolved name to be local model %q, got %q", localModelName.DisplayShortest(), resolved.DisplayShortest()) + } + + // Also verify that a non-local model matching the prefix DOES resolve to the alias target + nonLocalName := model.ParseName("myprefix-nonexistent") + resolved, wasResolved, err = store.ResolveName(nonLocalName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !wasResolved { + t.Fatal("expected non-local model to resolve via prefix alias") + } + expectedTarget := model.ParseName("someothermodel") + if resolved.DisplayShortest() != expectedTarget.DisplayShortest() { + t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest()) + } +}