cmd: claude launch improvements (#14064)

This commit is contained in:
Parth Sareen
2026-02-03 22:33:58 -05:00
committed by GitHub
parent b1fccabb34
commit ee25219edd
15 changed files with 1609 additions and 81 deletions

View File

@@ -1,18 +1,23 @@
package config
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner for Claude Code integration
// Claude implements Runner and AliasConfigurer for Claude Code integration
type Claude struct{}
// Compile-time check that Claude implements AliasConfigurer
var _ AliasConfigurer = (*Claude)(nil)
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string, extra []string) []string {
@@ -60,3 +65,96 @@ func (c *Claude) Run(model string, args []string) error {
)
return cmd.Run()
}
// ConfigureAliases sets up Primary and Fast model aliases for Claude Code.
func (c *Claude) ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error) {
aliases := make(map[string]string)
for k, v := range existing {
aliases[k] = v
}
if primaryModel != "" {
aliases["primary"] = primaryModel
}
if !force && aliases["primary"] != "" && aliases["fast"] != "" {
return aliases, false, nil
}
items, existingModels, cloudModels, client, err := listModels(ctx)
if err != nil {
return nil, false, err
}
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, "%sClaude Code uses multiple models for various tasks%s\n\n", ansiGray, ansiReset)
fmt.Fprintf(os.Stderr, "%sPrimary%s\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, "%sHandles complex reasoning: planning, code generation, debugging.%s\n\n", ansiGray, ansiReset)
if aliases["primary"] == "" || force {
primary, err := selectPrompt("Select Primary model:", items)
if err != nil {
return nil, false, err
}
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
return nil, false, err
}
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
return nil, false, err
}
aliases["primary"] = primary
} else {
fmt.Fprintf(os.Stderr, " %s\n\n", aliases["primary"])
}
fmt.Fprintf(os.Stderr, "%sFast%s\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, "%sHandles quick operations: file searches, simple edits, status checks.%s\n", ansiGray, ansiReset)
fmt.Fprintf(os.Stderr, "%sSmaller models work well and respond faster.%s\n\n", ansiGray, ansiReset)
if aliases["fast"] == "" || force {
fast, err := selectPrompt("Select Fast model:", items)
if err != nil {
return nil, false, err
}
if err := pullIfNeeded(ctx, client, existingModels, fast); err != nil {
return nil, false, err
}
if err := ensureAuth(ctx, client, cloudModels, []string{fast}); err != nil {
return nil, false, err
}
aliases["fast"] = fast
}
return aliases, true, nil
}
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
prefixAliases := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
var errs []string
for prefix, target := range prefixAliases {
req := &api.AliasRequest{
Alias: prefix,
Target: target,
PrefixMatching: true,
}
if err := client.SetAliasExperimental(ctx, req); err != nil {
errs = append(errs, prefix)
}
}
if len(errs) > 0 {
return fmt.Errorf("failed to set aliases: %v", errs)
}
return nil
}

View File

@@ -13,7 +13,8 @@ import (
)
type integration struct {
Models []string `json:"models"`
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
}
type config struct {
@@ -133,8 +134,16 @@ func saveIntegration(appName string, models []string) error {
return err
}
cfg.Integrations[strings.ToLower(appName)] = &integration{
Models: models,
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
var aliases map[string]string
if existing != nil && existing.Aliases != nil {
aliases = existing.Aliases
}
cfg.Integrations[key] = &integration{
Models: models,
Aliases: aliases,
}
return save(cfg)
@@ -154,6 +163,33 @@ func loadIntegration(appName string) (*integration, error) {
return ic, nil
}
func saveAliases(appName string, aliases map[string]string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}
cfg, err := load()
if err != nil {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
if existing == nil {
existing = &integration{}
}
if existing.Aliases == nil {
existing.Aliases = make(map[string]string)
}
for k, v := range aliases {
existing.Aliases[k] = v
}
cfg.Integrations[key] = existing
return save(cfg)
}
func listIntegrations() ([]integration, error) {
cfg, err := load()
if err != nil {

View File

@@ -46,6 +46,53 @@ func TestIntegrationConfig(t *testing.T) {
}
})
t.Run("save and load aliases", func(t *testing.T) {
models := []string{"llama3.2"}
if err := saveIntegration("claude", models); err != nil {
t.Fatal(err)
}
aliases := map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
}
if err := saveAliases("claude", aliases); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases == nil {
t.Fatal("expected aliases to be saved")
}
for k, v := range aliases {
if config.Aliases[k] != v {
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
}
}
})
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
t.Fatal(err)
}
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases["primary"] != "model-a" {
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
}
})
t.Run("defaultModel returns first model", func(t *testing.T) {
saveIntegration("codex", []string{"model-a", "model-b"})

View File

@@ -39,6 +39,15 @@ type Editor interface {
Models() []string
}
// AliasConfigurer can configure model aliases (e.g., for subagent routing).
// Integrations like Claude and Codex use this to route model requests to local models.
type AliasConfigurer interface {
// ConfigureAliases prompts the user to configure aliases and returns the updated map.
ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
// SetAliases syncs the configured aliases to the server
SetAliases(ctx context.Context, aliases map[string]string) error
}
// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
@@ -129,7 +138,11 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
return nil, err
}
} else {
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
prompt := fmt.Sprintf("Select model for %s:", r)
if _, ok := r.(AliasConfigurer); ok {
prompt = fmt.Sprintf("Select Primary model for %s:", r)
}
model, err := selectPrompt(prompt, items)
if err != nil {
return nil, err
}
@@ -157,73 +170,146 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
}
}
if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
return nil, err
}
return selected, nil
}
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
if existingModels[model] {
return nil
}
msg := fmt.Sprintf("Download %s?", model)
if ok, err := confirmPrompt(msg); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
if err := pullModel(ctx, client, model); err != nil {
return fmt.Errorf("failed to pull %s: %w", model, err)
}
return nil
}
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, nil, nil, nil, err
}
models, err := client.List(ctx)
if err != nil {
return nil, nil, nil, nil, err
}
var existing []modelInfo
for _, m := range models.Models {
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
}
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
if len(items) == 0 {
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
return items, existingModels, cloudModels, client, nil
}
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
selectedCloudModels = append(selectedCloudModels, m)
}
}
if len(selectedCloudModels) > 0 {
// ensure user is signed in
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return selected, nil
}
if len(selectedCloudModels) == 0 {
return nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return nil, err
}
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return nil
}
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return nil, fmt.Errorf("%s requires sign in", modelList)
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return err
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
modelList := strings.Join(selectedCloudModels, ", ")
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if err != nil || !yes {
return fmt.Errorf("%s requires sign in", modelList)
}
// TODO(parthsareen): extract into auth package for cmd
// Auto-open browser (best effort, fail silently)
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", aErr.SigninURL).Start()
case "linux":
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
}
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return nil, ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return selected, nil
}
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
// poll every 10th frame (~2 seconds)
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return nil
}
}
}
}
}
return selected, nil
func ensureAliases(ctx context.Context, r Runner, name string, primaryModel string, existing map[string]string, force bool) (bool, error) {
ac, ok := r.(AliasConfigurer)
if !ok {
return false, nil
}
aliases, updated, err := ac.ConfigureAliases(ctx, primaryModel, existing, force)
if err != nil {
return false, err
}
if !updated {
return false, nil
}
if err := saveAliases(name, aliases); err != nil {
return false, err
}
if err := ac.SetAliases(ctx, aliases); err != nil {
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases to server: %v%s\n", ansiGray, err, ansiReset)
fmt.Fprintf(os.Stderr, "%sAliases saved locally. Server sync will retry on next launch.%s\n\n", ansiGray, ansiReset)
}
return true, nil
}
func runIntegration(name, modelName string, args []string) error {
@@ -231,6 +317,17 @@ func runIntegration(name, modelName string, args []string) error {
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
if _, ok := r.(AliasConfigurer); ok {
if config, err := loadIntegration(name); err == nil && config.Aliases != nil {
primary, fast := config.Aliases["primary"], config.Aliases["fast"]
if primary != "" && fast != "" {
fmt.Fprintf(os.Stderr, "\nLaunching %s with Primary: %s, Fast: %s...\n", r, primary, fast)
return r.Run(modelName, args)
}
}
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName, args)
}
@@ -304,10 +401,50 @@ Examples:
if !configFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
if _, err := ensureAliases(cmd.Context(), r, name, config.Models[0], config.Aliases, false); errors.Is(err, errCancelled) {
return nil
} else if err != nil {
return err
}
return runIntegration(name, config.Models[0], passArgs)
}
}
if ac, ok := r.(AliasConfigurer); ok {
var existingAliases map[string]string
if existing, err := loadIntegration(name); err == nil {
existingAliases = existing.Aliases
}
aliases, updated, err := ac.ConfigureAliases(cmd.Context(), "", existingAliases, configFlag)
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if updated {
if err := saveAliases(name, aliases); err != nil {
return err
}
if err := ac.SetAliases(cmd.Context(), aliases); err != nil {
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases to server: %v%s\n", ansiGray, err, ansiReset)
}
fmt.Fprintf(os.Stderr, "\n%sConfiguration Complete%s\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, "Primary: %s\n", aliases["primary"])
fmt.Fprintf(os.Stderr, "Fast: %s\n\n", aliases["fast"])
}
if err := saveIntegration(name, []string{aliases["primary"]}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
return runIntegration(name, aliases["primary"], passArgs)
}
return nil
}
return runIntegration(name, aliases["primary"], passArgs)
}
var models []string
if modelFlag != "" {
models = []string{modelFlag}

View File

@@ -509,3 +509,19 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
t.Error("llama3.2 should not be in cloudModels")
}
}
func TestAliasConfigurerInterface(t *testing.T) {
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
claude := &Claude{}
if _, ok := interface{}(claude).(AliasConfigurer); !ok {
t.Error("Claude should implement AliasConfigurer")
}
})
t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
codex := &Codex{}
if _, ok := interface{}(codex).(AliasConfigurer); ok {
t.Error("Codex should not implement AliasConfigurer")
}
})
}

View File

@@ -65,6 +65,10 @@ func (s *selectState) handleInput(event inputEvent, char byte) (done bool, resul
if len(filtered) > 0 && s.selected < len(filtered) {
return true, filtered[s.selected].Name, nil
}
// No matches but user typed something - return filter for pull prompt
if len(filtered) == 0 && s.filter != "" {
return true, s.filter, nil
}
case eventEscape:
return true, "", errCancelled
case eventBackspace:
@@ -283,7 +287,11 @@ func renderSelect(w io.Writer, prompt string, s *selectState) int {
lineCount := 1
if len(filtered) == 0 {
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
if s.filter != "" {
fmt.Fprintf(w, " %s→ Download model: '%s'? Press Enter%s\r\n", ansiGray, s.filter, ansiReset)
} else {
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
}
lineCount++
} else {
displayCount := min(len(filtered), maxDisplayedItems)

View File

@@ -87,10 +87,18 @@ func TestSelectState(t *testing.T) {
}
})
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
t.Run("Enter_EmptyFilteredList_ReturnsFilter", func(t *testing.T) {
s := newSelectState(items)
s.filter = "nonexistent"
done, result, err := s.handleInput(eventEnter, 0)
if !done || result != "nonexistent" || err != nil {
t.Errorf("expected (true, 'nonexistent', nil), got (%v, %v, %v)", done, result, err)
}
})
t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) {
s := newSelectState([]selectItem{})
done, result, err := s.handleInput(eventEnter, 0)
if done || result != "" || err != nil {
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
}
@@ -568,14 +576,25 @@ func TestRenderSelect(t *testing.T) {
}
})
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
t.Run("EmptyFilteredList_ShowsPullPrompt", func(t *testing.T) {
s := newSelectState(items)
s.filter = "xyz"
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
output := buf.String()
if !strings.Contains(output, "Download model: 'xyz'?") {
t.Errorf("expected 'Download model: xyz?' message, got: %s", output)
}
})
t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) {
s := newSelectState([]selectItem{})
var buf bytes.Buffer
renderSelect(&buf, "Select:", s)
if !strings.Contains(buf.String(), "no matches") {
t.Error("expected 'no matches' message")
t.Error("expected 'no matches' message for empty list with no filter")
}
})