From 97d2f05a6d74e843eb8c49ec537394b12bb80be9 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Tue, 3 Mar 2026 12:51:23 -0800 Subject: [PATCH] Revert "don't require pulling stubs for cloud models (#14574)" (#14596) This reverts commit 8207e55ec7eb3a2cf4cb20917518514d981a6a01. --- cmd/cmd.go | 44 +- cmd/cmd_test.go | 174 +---- cmd/config/claude.go | 17 +- cmd/config/config.go | 3 - cmd/config/droid.go | 4 +- cmd/config/integrations.go | 65 +- cmd/config/integrations_test.go | 139 +--- cmd/config/opencode.go | 13 +- cmd/tui/tui.go | 11 +- internal/modelref/modelref.go | 115 ---- internal/modelref/modelref_test.go | 268 -------- middleware/anthropic.go | 3 +- server/cloud_proxy.go | 460 -------------- server/cloud_proxy_test.go | 154 ----- server/create.go | 17 +- server/model_resolver.go | 81 --- server/model_resolver_test.go | 170 ----- server/routes.go | 150 +---- server/routes_cloud_test.go | 988 ----------------------------- server/routes_create_test.go | 37 -- server/routes_delete_test.go | 29 - x/cmd/run.go | 3 +- x/cmd/run_test.go | 18 +- 23 files changed, 114 insertions(+), 2849 deletions(-) delete mode 100644 internal/modelref/modelref.go delete mode 100644 internal/modelref/modelref_test.go delete mode 100644 server/cloud_proxy.go delete mode 100644 server/cloud_proxy_test.go delete mode 100644 server/model_resolver.go delete mode 100644 server/model_resolver_test.go diff --git a/cmd/cmd.go b/cmd/cmd.go index 031b200a8..8c3131593 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -41,7 +41,6 @@ import ( "github.com/ollama/ollama/cmd/tui" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" - "github.com/ollama/ollama/internal/modelref" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/readline" @@ -407,14 +406,12 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error { return err } - requestedCloud := modelref.HasExplicitCloudSource(opts.Model) - if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil { return err - } else if info.RemoteHost != "" || requestedCloud { + } else if info.RemoteHost != "" { // Cloud model, no need to load/unload - isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com") + isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com") // Check if user is signed in for ollama.com cloud models if isCloud { @@ -425,14 +422,10 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error { if opts.ShowConnect { p.StopAndClear() - remoteModel := info.RemoteModel - if remoteModel == "" { - remoteModel = opts.Model - } if isCloud { - fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel) + fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel) } else { - fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost) + fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost) } } @@ -504,20 +497,6 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a return nil } -// TODO(parthsareen): consolidate with TUI signin flow -func handleCloudAuthorizationError(err error) bool { - var authErr api.AuthorizationError - if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized { - fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n") - if authErr.SigninURL != "" { - fmt.Printf(ConnectInstructions, authErr.SigninURL) - } - return true - } - - return false -} - func RunHandler(cmd *cobra.Command, args []string) error { interactive := true @@ -625,16 +604,12 @@ func RunHandler(cmd *cobra.Command, args []string) error { } name := args[0] - requestedCloud := modelref.HasExplicitCloudSource(name) info, err := func() (*api.ShowResponse, error) { showReq := &api.ShowRequest{Name: name} info, err := client.Show(cmd.Context(), showReq) var se api.StatusError if errors.As(err, &se) && se.StatusCode == http.StatusNotFound { - if requestedCloud { - return nil, err - } if err := PullHandler(cmd, []string{name}); err != nil { return nil, err } @@ -643,9 +618,6 @@ func RunHandler(cmd *cobra.Command, args []string) error { return info, err }() if err != nil { - if handleCloudAuthorizationError(err) { - return nil - } return err } @@ -740,13 +712,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generateInteractive(cmd, opts) } - if err := generate(cmd, opts); err != nil { - if handleCloudAuthorizationError(err) { - return nil - } - return err - } - return nil + return generate(cmd, opts) } func SigninHandler(cmd *cobra.Command, args []string) error { diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index dfbd63a85..7217c3d13 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -18,7 +18,6 @@ import ( "github.com/spf13/cobra" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/internal/modelref" "github.com/ollama/ollama/types/model" ) @@ -706,139 +705,6 @@ func TestRunEmbeddingModelNoInput(t *testing.T) { } } -func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) { - var generateCalled bool - - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch { - case r.URL.Path == "/api/show" && r.Method == http.MethodPost: - w.WriteHeader(http.StatusUnauthorized) - if err := json.NewEncoder(w).Encode(map[string]string{ - "error": "unauthorized", - "signin_url": "https://ollama.com/signin", - }); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - return - case r.URL.Path == "/api/generate" && r.Method == http.MethodPost: - generateCalled = true - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - return - default: - http.NotFound(w, r) - } - })) - - t.Setenv("OLLAMA_HOST", mockServer.URL) - t.Cleanup(mockServer.Close) - - cmd := &cobra.Command{} - cmd.SetContext(t.Context()) - cmd.Flags().String("keepalive", "", "") - cmd.Flags().Bool("truncate", false, "") - cmd.Flags().Int("dimensions", 0, "") - cmd.Flags().Bool("verbose", false, "") - cmd.Flags().Bool("insecure", false, "") - cmd.Flags().Bool("nowordwrap", false, "") - cmd.Flags().String("format", "", "") - cmd.Flags().String("think", "", "") - cmd.Flags().Bool("hidethinking", false, "") - - oldStdout := os.Stdout - readOut, writeOut, _ := os.Pipe() - os.Stdout = writeOut - t.Cleanup(func() { os.Stdout = oldStdout }) - - err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"}) - - _ = writeOut.Close() - var out bytes.Buffer - _, _ = io.Copy(&out, readOut) - - if err != nil { - t.Fatalf("RunHandler returned error: %v", err) - } - - if generateCalled { - t.Fatal("expected run to stop before /api/generate after unauthorized /api/show") - } - - if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") { - t.Fatalf("expected sign-in guidance message, got %q", out.String()) - } - - if !strings.Contains(out.String(), "https://ollama.com/signin") { - t.Fatalf("expected signin_url in output, got %q", out.String()) - } -} - -func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) { - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch { - case r.URL.Path == "/api/show" && r.Method == http.MethodPost: - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(api.ShowResponse{ - Capabilities: []model.Capability{model.CapabilityCompletion}, - }); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - return - case r.URL.Path == "/api/generate" && r.Method == http.MethodPost: - w.WriteHeader(http.StatusUnauthorized) - if err := json.NewEncoder(w).Encode(map[string]string{ - "error": "unauthorized", - "signin_url": "https://ollama.com/signin", - }); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - return - default: - http.NotFound(w, r) - } - })) - - t.Setenv("OLLAMA_HOST", mockServer.URL) - t.Cleanup(mockServer.Close) - - cmd := &cobra.Command{} - cmd.SetContext(t.Context()) - cmd.Flags().String("keepalive", "", "") - cmd.Flags().Bool("truncate", false, "") - cmd.Flags().Int("dimensions", 0, "") - cmd.Flags().Bool("verbose", false, "") - cmd.Flags().Bool("insecure", false, "") - cmd.Flags().Bool("nowordwrap", false, "") - cmd.Flags().String("format", "", "") - cmd.Flags().String("think", "", "") - cmd.Flags().Bool("hidethinking", false, "") - - oldStdout := os.Stdout - readOut, writeOut, _ := os.Pipe() - os.Stdout = writeOut - t.Cleanup(func() { os.Stdout = oldStdout }) - - err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"}) - - _ = writeOut.Close() - var out bytes.Buffer - _, _ = io.Copy(&out, readOut) - - if err != nil { - t.Fatalf("RunHandler returned error: %v", err) - } - - if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") { - t.Fatalf("expected sign-in guidance message, got %q", out.String()) - } - - if !strings.Contains(out.String(), "https://ollama.com/signin") { - t.Fatalf("expected signin_url in output, got %q", out.String()) - } -} - func TestGetModelfileName(t *testing.T) { tests := []struct { name string @@ -1798,26 +1664,20 @@ func TestRunOptions_Copy_Independence(t *testing.T) { func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { tests := []struct { name string - model string remoteHost string - remoteModel string whoamiStatus int whoamiResp any expectedError string }{ { name: "ollama.com cloud model - user signed in", - model: "test-cloud-model", remoteHost: "https://ollama.com", - remoteModel: "test-model", whoamiStatus: http.StatusOK, whoamiResp: api.UserResponse{Name: "testuser"}, }, { name: "ollama.com cloud model - user not signed in", - model: "test-cloud-model", remoteHost: "https://ollama.com", - remoteModel: "test-model", whoamiStatus: http.StatusUnauthorized, whoamiResp: map[string]string{ "error": "unauthorized", @@ -1827,33 +1687,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { }, { name: "non-ollama.com remote - no auth check", - model: "test-cloud-model", remoteHost: "https://other-remote.com", - remoteModel: "test-model", - whoamiStatus: http.StatusUnauthorized, // should not be called - whoamiResp: nil, - }, - { - name: "explicit :cloud model - auth check without remote metadata", - model: "kimi-k2.5:cloud", - remoteHost: "", - remoteModel: "", - whoamiStatus: http.StatusOK, - whoamiResp: api.UserResponse{Name: "testuser"}, - }, - { - name: "explicit -cloud model - auth check without remote metadata", - model: "kimi-k2.5:latest-cloud", - remoteHost: "", - remoteModel: "", - whoamiStatus: http.StatusOK, - whoamiResp: api.UserResponse{Name: "testuser"}, - }, - { - name: "dash cloud-like name without explicit source does not require auth", - model: "test-cloud-model", - remoteHost: "", - remoteModel: "", whoamiStatus: http.StatusUnauthorized, // should not be called whoamiResp: nil, }, @@ -1868,7 +1702,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(api.ShowResponse{ RemoteHost: tt.remoteHost, - RemoteModel: tt.remoteModel, + RemoteModel: "test-model", }); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } @@ -1881,8 +1715,6 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { http.Error(w, err.Error(), http.StatusInternalServerError) } } - case "/api/generate": - w.WriteHeader(http.StatusOK) default: http.NotFound(w, r) } @@ -1895,13 +1727,13 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { cmd.SetContext(t.Context()) opts := &runOptions{ - Model: tt.model, + Model: "test-cloud-model", ShowConnect: false, } err := loadOrUnloadModel(cmd, opts) - if strings.HasPrefix(tt.remoteHost, "https://ollama.com") || modelref.HasExplicitCloudSource(tt.model) { + if strings.HasPrefix(tt.remoteHost, "https://ollama.com") { if !whoamiCalled { t.Error("expected whoami to be called for ollama.com cloud model") } diff --git a/cmd/config/claude.go b/cmd/config/claude.go index 9018d193d..b7ed02af1 100644 --- a/cmd/config/claude.go +++ b/cmd/config/claude.go @@ -107,12 +107,15 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli } if !force && aliases["primary"] != "" { - if isCloudModelName(aliases["primary"]) { - aliases["fast"] = aliases["primary"] + client, _ := api.ClientFromEnvironment() + if isCloudModel(ctx, client, aliases["primary"]) { + if isCloudModel(ctx, client, aliases["fast"]) { + return aliases, false, nil + } + } else { + delete(aliases, "fast") return aliases, false, nil } - delete(aliases, "fast") - return aliases, false, nil } items, existingModels, cloudModels, client, err := listModels(ctx) @@ -136,8 +139,10 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli aliases["primary"] = primary } - if isCloudModelName(aliases["primary"]) { - aliases["fast"] = aliases["primary"] + if isCloudModel(ctx, client, aliases["primary"]) { + if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) { + aliases["fast"] = aliases["primary"] + } } else { delete(aliases, "fast") } diff --git a/cmd/config/config.go b/cmd/config/config.go index 82bfb493d..8eb41f4ae 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -233,9 +233,6 @@ func ModelExists(ctx context.Context, name string) bool { if name == "" { return false } - if isCloudModelName(name) { - return true - } client, err := api.ClientFromEnvironment() if err != nil { return false diff --git a/cmd/config/droid.go b/cmd/config/droid.go index ed88c0177..d1a9f54dc 100644 --- a/cmd/config/droid.go +++ b/cmd/config/droid.go @@ -10,6 +10,7 @@ import ( "path/filepath" "slices" + "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" ) @@ -124,12 +125,13 @@ func (d *Droid) Edit(models []string) error { } // Build new Ollama model entries with sequential indices (0, 1, 2, ...) + client, _ := api.ClientFromEnvironment() var newModels []any var defaultModelID string for i, model := range models { maxOutput := 64000 - if isCloudModelName(model) { + if isCloudModel(context.Background(), client, model) { if l, ok := lookupCloudModelLimit(model); ok { maxOutput = l.Output } diff --git a/cmd/config/integrations.go b/cmd/config/integrations.go index b3de38bb6..acf458abe 100644 --- a/cmd/config/integrations.go +++ b/cmd/config/integrations.go @@ -14,7 +14,6 @@ import ( "github.com/ollama/ollama/api" internalcloud "github.com/ollama/ollama/internal/cloud" - "github.com/ollama/ollama/internal/modelref" "github.com/ollama/ollama/progress" "github.com/spf13/cobra" ) @@ -325,7 +324,12 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri // If the selected model isn't installed, pull it first if !existingModels[selected] { - if !isCloudModelName(selected) { + if cloudModels[selected] { + // Cloud models only pull a small manifest; no confirmation needed + if err := pullModel(ctx, client, selected); err != nil { + return "", fmt.Errorf("failed to pull %s: %w", selected, err) + } + } else { msg := fmt.Sprintf("Download %s?", selected) if ok, err := confirmPrompt(msg); err != nil { return "", err @@ -520,7 +524,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single var toPull []string for _, m := range selected { - if !existingModels[m] && !isCloudModelName(m) { + if !existingModels[m] { toPull = append(toPull, m) } } @@ -546,28 +550,12 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single return selected, nil } -// TODO(parthsareen): consolidate pull logic from call sites func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error { - if isCloudModelName(model) || existingModels[model] { + if existingModels[model] { return nil } - return confirmAndPull(ctx, client, model) -} - -// TODO(parthsareen): pull this out to tui package -// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found. -func ShowOrPull(ctx context.Context, client *api.Client, model string) error { - if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil { - return nil - } - if isCloudModelName(model) { - return nil - } - return confirmAndPull(ctx, client, model) -} - -func confirmAndPull(ctx context.Context, client *api.Client, model string) error { - if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil { + msg := fmt.Sprintf("Download %s?", model) + if ok, err := confirmPrompt(msg); err != nil { return err } else if !ok { return errCancelled @@ -579,6 +567,26 @@ func confirmAndPull(ctx context.Context, client *api.Client, model string) error return nil } +// TODO(parthsareen): pull this out to tui package +// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found. +func ShowOrPull(ctx context.Context, client *api.Client, model string) error { + if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil { + return nil + } + // Cloud models only pull a small manifest; skip the download confirmation + // TODO(parthsareen): consolidate with cloud config changes + if strings.HasSuffix(model, "cloud") { + return pullModel(ctx, client, model) + } + if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil { + return err + } else if !ok { + return errCancelled + } + fmt.Fprintf(os.Stderr, "\n") + return pullModel(ctx, client, model) +} + func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]bool, *api.Client, error) { client, err := api.ClientFromEnvironment() if err != nil { @@ -723,8 +731,10 @@ func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, na } aliases["primary"] = model - if isCloudModelName(model) { - aliases["fast"] = model + if isCloudModel(ctx, client, model) { + if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) { + aliases["fast"] = model + } } else { delete(aliases, "fast") } @@ -1010,7 +1020,7 @@ Examples: existingAliases = aliases // Ensure cloud models are authenticated - if isCloudModelName(model) { + if isCloudModel(cmd.Context(), client, model) { if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil { return err } @@ -1199,7 +1209,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) ( // When user has no models, preserve recommended order. notInstalled := make(map[string]bool) for i := range items { - if !existingModels[items[i].Name] && !cloudModels[items[i].Name] { + if !existingModels[items[i].Name] { notInstalled[items[i].Name] = true var parts []string if items[i].Description != "" { @@ -1293,8 +1303,7 @@ func IsCloudModelDisabled(ctx context.Context, name string) bool { } func isCloudModelName(name string) bool { - // TODO(drifkin): Replace this wrapper with inlining once things stabilize a bit - return modelref.HasExplicitCloudSource(name) + return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud") } func filterCloudModels(existing []modelInfo) []modelInfo { diff --git a/cmd/config/integrations_test.go b/cmd/config/integrations_test.go index 2eca19fd9..914a8f661 100644 --- a/cmd/config/integrations_test.go +++ b/cmd/config/integrations_test.go @@ -426,14 +426,8 @@ func TestBuildModelList_NoExistingModels(t *testing.T) { } for _, item := range items { - if strings.HasSuffix(item.Name, ":cloud") { - if strings.HasSuffix(item.Description, "(not downloaded)") { - t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description) - } - } else { - if !strings.HasSuffix(item.Description, "(not downloaded)") { - t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description) - } + if !strings.HasSuffix(item.Description, "(not downloaded)") { + t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description) } } } @@ -498,14 +492,10 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) { if strings.HasSuffix(item.Description, "(not downloaded)") { t.Errorf("installed recommended %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description) } - case "qwen3:8b": + case "minimax-m2.5:cloud", "kimi-k2.5:cloud", "qwen3:8b": if !strings.HasSuffix(item.Description, "(not downloaded)") { t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description) } - case "minimax-m2.5:cloud", "kimi-k2.5:cloud": - if strings.HasSuffix(item.Description, "(not downloaded)") { - t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description) - } } } } @@ -546,13 +536,7 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes } for _, item := range items { - isCloud := strings.HasSuffix(item.Name, ":cloud") - isInstalled := slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) - if isInstalled || isCloud { - if strings.HasSuffix(item.Description, "(not downloaded)") { - t.Errorf("installed or cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description) - } - } else { + if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) { if !strings.HasSuffix(item.Description, "(not downloaded)") { t.Errorf("non-installed %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description) } @@ -1016,8 +1000,8 @@ func TestShowOrPull_ModelNotFound_ConfirmNo_Cancelled(t *testing.T) { } } -func TestShowOrPull_CloudModel_DoesNotPull(t *testing.T) { - // Confirm prompt should NOT be called for explicit cloud models +func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) { + // Confirm prompt should NOT be called for cloud models oldHook := DefaultConfirmPrompt DefaultConfirmPrompt = func(prompt string) (bool, error) { t.Error("confirm prompt should not be called for cloud models") @@ -1048,115 +1032,8 @@ func TestShowOrPull_CloudModel_DoesNotPull(t *testing.T) { if err != nil { t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err) } - if pullCalled { - t.Error("expected pull not to be called for cloud model") - } -} - -func TestShowOrPull_CloudLegacySuffix_DoesNotPull(t *testing.T) { - // Confirm prompt should NOT be called for explicit cloud models - oldHook := DefaultConfirmPrompt - DefaultConfirmPrompt = func(prompt string) (bool, error) { - t.Error("confirm prompt should not be called for cloud models") - return false, nil - } - defer func() { DefaultConfirmPrompt = oldHook }() - - var pullCalled bool - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/api/show": - w.WriteHeader(http.StatusNotFound) - fmt.Fprintf(w, `{"error":"model not found"}`) - case "/api/pull": - pullCalled = true - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"status":"success"}`) - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer srv.Close() - - u, _ := url.Parse(srv.URL) - client := api.NewClient(u, srv.Client()) - - err := ShowOrPull(context.Background(), client, "gpt-oss:20b-cloud") - if err != nil { - t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err) - } - if pullCalled { - t.Error("expected pull not to be called for cloud model") - } -} - -func TestPullIfNeeded_CloudModel_DoesNotPull(t *testing.T) { - oldHook := DefaultConfirmPrompt - DefaultConfirmPrompt = func(prompt string) (bool, error) { - t.Error("confirm prompt should not be called for cloud models") - return false, nil - } - defer func() { DefaultConfirmPrompt = oldHook }() - - err := pullIfNeeded(context.Background(), nil, map[string]bool{}, "glm-5:cloud") - if err != nil { - t.Fatalf("expected no error for cloud model, got %v", err) - } - - err = pullIfNeeded(context.Background(), nil, map[string]bool{}, "gpt-oss:20b-cloud") - if err != nil { - t.Fatalf("expected no error for cloud model with legacy suffix, got %v", err) - } -} - -func TestSelectModelsWithSelectors_CloudSelection_DoesNotPull(t *testing.T) { - var pullCalled bool - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/api/status": - w.WriteHeader(http.StatusNotFound) - fmt.Fprintf(w, `{"error":"not found"}`) - case "/api/tags": - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"models":[]}`) - case "/api/me": - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"name":"test-user"}`) - case "/api/pull": - pullCalled = true - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"status":"success"}`) - default: - w.WriteHeader(http.StatusNotFound) - fmt.Fprintf(w, `{"error":"not found"}`) - } - })) - defer srv.Close() - t.Setenv("OLLAMA_HOST", srv.URL) - - single := func(title string, items []ModelItem, current string) (string, error) { - for _, item := range items { - if item.Name == "glm-5:cloud" { - return item.Name, nil - } - } - t.Fatalf("expected glm-5:cloud in selector items, got %v", items) - return "", nil - } - - multi := func(title string, items []ModelItem, preChecked []string) ([]string, error) { - return nil, fmt.Errorf("multi selector should not be called") - } - - selected, err := selectModelsWithSelectors(context.Background(), "codex", "", single, multi) - if err != nil { - t.Fatalf("selectModelsWithSelectors returned error: %v", err) - } - if !slices.Equal(selected, []string{"glm-5:cloud"}) { - t.Fatalf("unexpected selected models: %v", selected) - } - if pullCalled { - t.Fatal("expected cloud selection to skip pull") + if !pullCalled { + t.Error("expected pull to be called for cloud model without confirmation") } } diff --git a/cmd/config/opencode.go b/cmd/config/opencode.go index 2f5fb5858..b044359e1 100644 --- a/cmd/config/opencode.go +++ b/cmd/config/opencode.go @@ -12,8 +12,8 @@ import ( "slices" "strings" + "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/internal/modelref" ) // OpenCode implements Runner and Editor for OpenCode integration @@ -26,13 +26,13 @@ type cloudModelLimit struct { } // lookupCloudModelLimit returns the token limits for a cloud model. -// It tries the exact name first, then strips explicit cloud suffixes. +// It tries the exact name first, then strips the ":cloud" suffix. func lookupCloudModelLimit(name string) (cloudModelLimit, bool) { if l, ok := cloudModelLimits[name]; ok { return l, true } - base, stripped := modelref.StripCloudSourceTag(name) - if stripped { + base := strings.TrimSuffix(name, ":cloud") + if base != name { if l, ok := cloudModelLimits[base]; ok { return l, true } @@ -152,6 +152,7 @@ func (o *OpenCode) Edit(modelList []string) error { } } + client, _ := api.ClientFromEnvironment() for _, model := range modelList { if existing, ok := models[model].(map[string]any); ok { @@ -162,7 +163,7 @@ func (o *OpenCode) Edit(modelList []string) error { existing["name"] = strings.TrimSuffix(name, " [Ollama]") } } - if isCloudModelName(model) { + if isCloudModel(context.Background(), client, model) { if l, ok := lookupCloudModelLimit(model); ok { existing["limit"] = map[string]any{ "context": l.Context, @@ -176,7 +177,7 @@ func (o *OpenCode) Edit(modelList []string) error { "name": model, "_launch": true, } - if isCloudModelName(model) { + if isCloudModel(context.Background(), client, model) { if l, ok := lookupCloudModelLimit(model); ok { entry["limit"] = map[string]any{ "context": l.Context, diff --git a/cmd/tui/tui.go b/cmd/tui/tui.go index 5803d98fa..389c875a9 100644 --- a/cmd/tui/tui.go +++ b/cmd/tui/tui.go @@ -11,7 +11,6 @@ import ( "github.com/charmbracelet/lipgloss" "github.com/ollama/ollama/api" "github.com/ollama/ollama/cmd/config" - "github.com/ollama/ollama/internal/modelref" "github.com/ollama/ollama/version" ) @@ -148,13 +147,7 @@ type signInCheckMsg struct { type clearStatusMsg struct{} func (m *model) modelExists(name string) bool { - if name == "" { - return false - } - if modelref.HasExplicitCloudSource(name) { - return true - } - if m.availableModels == nil { + if m.availableModels == nil || name == "" { return false } if m.availableModels[name] { @@ -216,7 +209,7 @@ func (m *model) openMultiModelModal(integration string) { } func isCloudModel(name string) bool { - return modelref.HasExplicitCloudSource(name) + return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud") } func cloudStatusDisabled(client *api.Client) bool { diff --git a/internal/modelref/modelref.go b/internal/modelref/modelref.go deleted file mode 100644 index f62757912..000000000 --- a/internal/modelref/modelref.go +++ /dev/null @@ -1,115 +0,0 @@ -package modelref - -import ( - "errors" - "fmt" - "strings" -) - -type ModelSource uint8 - -const ( - ModelSourceUnspecified ModelSource = iota - ModelSourceLocal - ModelSourceCloud -) - -var ( - ErrConflictingSourceSuffix = errors.New("use either :local or :cloud, not both") - ErrModelRequired = errors.New("model is required") -) - -type ParsedRef struct { - Original string - Base string - Source ModelSource -} - -func ParseRef(raw string) (ParsedRef, error) { - var zero ParsedRef - - raw = strings.TrimSpace(raw) - if raw == "" { - return zero, ErrModelRequired - } - - base, source, explicit := parseSourceSuffix(raw) - if explicit { - if _, _, nested := parseSourceSuffix(base); nested { - return zero, fmt.Errorf("%w: %q", ErrConflictingSourceSuffix, raw) - } - } - - return ParsedRef{ - Original: raw, - Base: base, - Source: source, - }, nil -} - -func HasExplicitCloudSource(raw string) bool { - parsedRef, err := ParseRef(raw) - return err == nil && parsedRef.Source == ModelSourceCloud -} - -func HasExplicitLocalSource(raw string) bool { - parsedRef, err := ParseRef(raw) - return err == nil && parsedRef.Source == ModelSourceLocal -} - -func StripCloudSourceTag(raw string) (string, bool) { - parsedRef, err := ParseRef(raw) - if err != nil || parsedRef.Source != ModelSourceCloud { - return strings.TrimSpace(raw), false - } - - return parsedRef.Base, true -} - -func NormalizePullName(raw string) (string, bool, error) { - parsedRef, err := ParseRef(raw) - if err != nil { - return "", false, err - } - - if parsedRef.Source != ModelSourceCloud { - return parsedRef.Base, false, nil - } - - return toLegacyCloudPullName(parsedRef.Base), true, nil -} - -func toLegacyCloudPullName(base string) string { - if hasExplicitTag(base) { - return base + "-cloud" - } - - return base + ":cloud" -} - -func hasExplicitTag(name string) bool { - lastSlash := strings.LastIndex(name, "/") - lastColon := strings.LastIndex(name, ":") - return lastColon > lastSlash -} - -func parseSourceSuffix(raw string) (string, ModelSource, bool) { - idx := strings.LastIndex(raw, ":") - if idx >= 0 { - suffixRaw := strings.TrimSpace(raw[idx+1:]) - suffix := strings.ToLower(suffixRaw) - - switch suffix { - case "cloud": - return raw[:idx], ModelSourceCloud, true - case "local": - return raw[:idx], ModelSourceLocal, true - } - - if !strings.Contains(suffixRaw, "/") && strings.HasSuffix(suffix, "-cloud") { - return raw[:idx+1] + suffixRaw[:len(suffixRaw)-len("-cloud")], ModelSourceCloud, true - } - } - - return raw, ModelSourceUnspecified, false -} diff --git a/internal/modelref/modelref_test.go b/internal/modelref/modelref_test.go deleted file mode 100644 index 7d1c1bee5..000000000 --- a/internal/modelref/modelref_test.go +++ /dev/null @@ -1,268 +0,0 @@ -package modelref - -import ( - "errors" - "testing" -) - -func TestParseRef(t *testing.T) { - tests := []struct { - name string - input string - wantBase string - wantSource ModelSource - wantErr error - wantCloud bool - wantLocal bool - wantStripped string - wantStripOK bool - }{ - { - name: "cloud suffix", - input: "gpt-oss:20b:cloud", - wantBase: "gpt-oss:20b", - wantSource: ModelSourceCloud, - wantCloud: true, - wantStripped: "gpt-oss:20b", - wantStripOK: true, - }, - { - name: "legacy cloud suffix", - input: "gpt-oss:20b-cloud", - wantBase: "gpt-oss:20b", - wantSource: ModelSourceCloud, - wantCloud: true, - wantStripped: "gpt-oss:20b", - wantStripOK: true, - }, - { - name: "local suffix", - input: "qwen3:8b:local", - wantBase: "qwen3:8b", - wantSource: ModelSourceLocal, - wantLocal: true, - wantStripped: "qwen3:8b:local", - }, - { - name: "no source suffix", - input: "llama3.2", - wantBase: "llama3.2", - wantSource: ModelSourceUnspecified, - wantStripped: "llama3.2", - }, - { - name: "bare cloud name is not explicit cloud", - input: "my-cloud-model", - wantBase: "my-cloud-model", - wantSource: ModelSourceUnspecified, - wantStripped: "my-cloud-model", - }, - { - name: "slash in suffix blocks legacy cloud parsing", - input: "foo:bar-cloud/baz", - wantBase: "foo:bar-cloud/baz", - wantSource: ModelSourceUnspecified, - wantStripped: "foo:bar-cloud/baz", - }, - { - name: "conflicting source suffixes", - input: "foo:cloud:local", - wantErr: ErrConflictingSourceSuffix, - wantSource: ModelSourceUnspecified, - }, - { - name: "empty input", - input: " ", - wantErr: ErrModelRequired, - wantSource: ModelSourceUnspecified, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := ParseRef(tt.input) - if tt.wantErr != nil { - if !errors.Is(err, tt.wantErr) { - t.Fatalf("ParseRef(%q) error = %v, want %v", tt.input, err, tt.wantErr) - } - return - } - if err != nil { - t.Fatalf("ParseRef(%q) returned error: %v", tt.input, err) - } - - if got.Base != tt.wantBase { - t.Fatalf("base = %q, want %q", got.Base, tt.wantBase) - } - - if got.Source != tt.wantSource { - t.Fatalf("source = %v, want %v", got.Source, tt.wantSource) - } - - if HasExplicitCloudSource(tt.input) != tt.wantCloud { - t.Fatalf("HasExplicitCloudSource(%q) = %v, want %v", tt.input, HasExplicitCloudSource(tt.input), tt.wantCloud) - } - - if HasExplicitLocalSource(tt.input) != tt.wantLocal { - t.Fatalf("HasExplicitLocalSource(%q) = %v, want %v", tt.input, HasExplicitLocalSource(tt.input), tt.wantLocal) - } - - stripped, ok := StripCloudSourceTag(tt.input) - if ok != tt.wantStripOK { - t.Fatalf("StripCloudSourceTag(%q) ok = %v, want %v", tt.input, ok, tt.wantStripOK) - } - if stripped != tt.wantStripped { - t.Fatalf("StripCloudSourceTag(%q) base = %q, want %q", tt.input, stripped, tt.wantStripped) - } - }) - } -} - -func TestNormalizePullName(t *testing.T) { - tests := []struct { - name string - input string - wantName string - wantCloud bool - wantErr error - }{ - { - name: "explicit local strips source", - input: "gpt-oss:20b:local", - wantName: "gpt-oss:20b", - }, - { - name: "explicit cloud with size maps to legacy dash cloud tag", - input: "gpt-oss:20b:cloud", - wantName: "gpt-oss:20b-cloud", - wantCloud: true, - }, - { - name: "legacy cloud with size remains stable", - input: "gpt-oss:20b-cloud", - wantName: "gpt-oss:20b-cloud", - wantCloud: true, - }, - { - name: "explicit cloud without tag maps to cloud tag", - input: "qwen3:cloud", - wantName: "qwen3:cloud", - wantCloud: true, - }, - { - name: "host port without tag keeps host port and appends cloud tag", - input: "localhost:11434/library/foo:cloud", - wantName: "localhost:11434/library/foo:cloud", - wantCloud: true, - }, - { - name: "conflicting source suffixes fail", - input: "foo:cloud:local", - wantErr: ErrConflictingSourceSuffix, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotName, gotCloud, err := NormalizePullName(tt.input) - if tt.wantErr != nil { - if !errors.Is(err, tt.wantErr) { - t.Fatalf("NormalizePullName(%q) error = %v, want %v", tt.input, err, tt.wantErr) - } - return - } - if err != nil { - t.Fatalf("NormalizePullName(%q) returned error: %v", tt.input, err) - } - - if gotName != tt.wantName { - t.Fatalf("normalized name = %q, want %q", gotName, tt.wantName) - } - if gotCloud != tt.wantCloud { - t.Fatalf("cloud = %v, want %v", gotCloud, tt.wantCloud) - } - }) - } -} - -func TestParseSourceSuffix(t *testing.T) { - tests := []struct { - name string - input string - wantBase string - wantSource ModelSource - wantExplicit bool - }{ - { - name: "explicit cloud suffix", - input: "gpt-oss:20b:cloud", - wantBase: "gpt-oss:20b", - wantSource: ModelSourceCloud, - wantExplicit: true, - }, - { - name: "explicit local suffix", - input: "qwen3:8b:local", - wantBase: "qwen3:8b", - wantSource: ModelSourceLocal, - wantExplicit: true, - }, - { - name: "legacy cloud suffix on tag", - input: "gpt-oss:20b-cloud", - wantBase: "gpt-oss:20b", - wantSource: ModelSourceCloud, - wantExplicit: true, - }, - { - name: "legacy cloud suffix does not match model segment", - input: "my-cloud-model", - wantBase: "my-cloud-model", - wantSource: ModelSourceUnspecified, - wantExplicit: false, - }, - { - name: "legacy cloud suffix blocked when suffix includes slash", - input: "foo:bar-cloud/baz", - wantBase: "foo:bar-cloud/baz", - wantSource: ModelSourceUnspecified, - wantExplicit: false, - }, - { - name: "unknown suffix is not explicit source", - input: "gpt-oss:clod", - wantBase: "gpt-oss:clod", - wantSource: ModelSourceUnspecified, - wantExplicit: false, - }, - { - name: "uppercase suffix is accepted", - input: "gpt-oss:20b:CLOUD", - wantBase: "gpt-oss:20b", - wantSource: ModelSourceCloud, - wantExplicit: true, - }, - { - name: "no suffix", - input: "llama3.2", - wantBase: "llama3.2", - wantSource: ModelSourceUnspecified, - wantExplicit: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotBase, gotSource, gotExplicit := parseSourceSuffix(tt.input) - if gotBase != tt.wantBase { - t.Fatalf("base = %q, want %q", gotBase, tt.wantBase) - } - if gotSource != tt.wantSource { - t.Fatalf("source = %v, want %v", gotSource, tt.wantSource) - } - if gotExplicit != tt.wantExplicit { - t.Fatalf("explicit = %v, want %v", gotExplicit, tt.wantExplicit) - } - }) - } -} diff --git a/middleware/anthropic.go b/middleware/anthropic.go index d65edd53f..85c95e60c 100644 --- a/middleware/anthropic.go +++ b/middleware/anthropic.go @@ -17,7 +17,6 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" internalcloud "github.com/ollama/ollama/internal/cloud" - "github.com/ollama/ollama/internal/modelref" "github.com/ollama/ollama/logutil" ) @@ -920,7 +919,7 @@ func hasWebSearchTool(tools []anthropic.Tool) bool { } func isCloudModelName(name string) bool { - return modelref.HasExplicitCloudSource(name) + return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud") } // extractQueryFromToolCall extracts the search query from a web_search tool call diff --git a/server/cloud_proxy.go b/server/cloud_proxy.go deleted file mode 100644 index bf91d7694..000000000 --- a/server/cloud_proxy.go +++ /dev/null @@ -1,460 +0,0 @@ -package server - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "log/slog" - "net" - "net/http" - "net/url" - "strconv" - "strings" - "time" - - "github.com/gin-gonic/gin" - - "github.com/ollama/ollama/auth" - "github.com/ollama/ollama/envconfig" - internalcloud "github.com/ollama/ollama/internal/cloud" -) - -const ( - defaultCloudProxyBaseURL = "https://ollama.com:443" - defaultCloudProxySigningHost = "ollama.com" - cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL" - legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search" -) - -var ( - cloudProxyBaseURL = defaultCloudProxyBaseURL - cloudProxySigningHost = defaultCloudProxySigningHost - cloudProxySignRequest = signCloudProxyRequest - cloudProxySigninURL = signinURL -) - -var hopByHopHeaders = map[string]struct{}{ - "connection": {}, - "content-length": {}, - "proxy-connection": {}, - "keep-alive": {}, - "proxy-authenticate": {}, - "proxy-authorization": {}, - "te": {}, - "trailer": {}, - "transfer-encoding": {}, - "upgrade": {}, -} - -func init() { - baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL(envconfig.Var(cloudProxyBaseURLEnv), mode) - if err != nil { - slog.Warn("ignoring cloud base URL override", "env", cloudProxyBaseURLEnv, "error", err) - return - } - - cloudProxyBaseURL = baseURL - cloudProxySigningHost = signingHost - - if overridden { - slog.Info("cloud base URL override enabled", "env", cloudProxyBaseURLEnv, "url", cloudProxyBaseURL, "mode", mode) - } -} - -func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc { - return func(c *gin.Context) { - if c.Request.Method != http.MethodPost { - c.Next() - return - } - - // TODO(drifkin): Avoid full-body buffering here for model detection. - // A future optimization can parse just enough JSON to read "model" (and - // optionally short-circuit cloud-disabled explicit-cloud requests) while - // preserving raw passthrough semantics. - body, err := readRequestBody(c.Request) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - c.Abort() - return - } - - model, ok := extractModelField(body) - if !ok { - c.Next() - return - } - - modelRef, err := parseAndValidateModelRef(model) - if err != nil || modelRef.Source != modelSourceCloud { - c.Next() - return - } - - normalizedBody, err := replaceJSONModelField(body, modelRef.Base) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - c.Abort() - return - } - - // TEMP(drifkin): keep Anthropic web search requests on the local middleware - // path so WebSearchAnthropicWriter can orchestrate follow-up calls. - if c.Request.URL.Path == "/v1/messages" { - if hasAnthropicWebSearchTool(body) { - c.Set(legacyCloudAnthropicKey, true) - c.Next() - return - } - } - - proxyCloudRequest(c, normalizedBody, disabledOperation) - c.Abort() - } -} - -func cloudModelPathPassthroughMiddleware(disabledOperation string) gin.HandlerFunc { - return func(c *gin.Context) { - modelName := strings.TrimSpace(c.Param("model")) - if modelName == "" { - c.Next() - return - } - - modelRef, err := parseAndValidateModelRef(modelName) - if err != nil || modelRef.Source != modelSourceCloud { - c.Next() - return - } - - proxyPath := "/v1/models/" + modelRef.Base - proxyCloudRequestWithPath(c, nil, proxyPath, disabledOperation) - c.Abort() - } -} - -func proxyCloudJSONRequest(c *gin.Context, payload any, disabledOperation string) { - // TEMP(drifkin): we currently split out this `WithPath` method because we are - // mapping `/v1/messages` + web_search to `/api/chat` temporarily. Once we - // stop doing this, we can inline this method. - proxyCloudJSONRequestWithPath(c, payload, c.Request.URL.Path, disabledOperation) -} - -func proxyCloudJSONRequestWithPath(c *gin.Context, payload any, path string, disabledOperation string) { - body, err := json.Marshal(payload) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - proxyCloudRequestWithPath(c, body, path, disabledOperation) -} - -func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) { - proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation) -} - -func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) { - if disabled, _ := internalcloud.Status(); disabled { - c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)}) - return - } - - baseURL, err := url.Parse(cloudProxyBaseURL) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - targetURL := baseURL.ResolveReference(&url.URL{ - Path: path, - RawQuery: c.Request.URL.RawQuery, - }) - - outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body)) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - copyProxyRequestHeaders(outReq.Header, c.Request.Header) - if outReq.Header.Get("Content-Type") == "" && len(body) > 0 { - outReq.Header.Set("Content-Type", "application/json") - } - - if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil { - slog.Warn("cloud proxy signing failed", "error", err) - writeCloudUnauthorized(c) - return - } - - // TODO(drifkin): Add phase-specific proxy timeouts. - // Connect/TLS/TTFB should have bounded timeouts, but once streaming starts - // we should not enforce a short total timeout for long-lived responses. - resp, err := http.DefaultClient.Do(outReq) - if err != nil { - c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) - return - } - defer resp.Body.Close() - - copyProxyResponseHeaders(c.Writer.Header(), resp.Header) - c.Status(resp.StatusCode) - - if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil { - c.Error(err) //nolint:errcheck - } -} - -func replaceJSONModelField(body []byte, model string) ([]byte, error) { - if len(body) == 0 { - return body, nil - } - - var payload map[string]json.RawMessage - if err := json.Unmarshal(body, &payload); err != nil { - return nil, err - } - - modelJSON, err := json.Marshal(model) - if err != nil { - return nil, err - } - payload["model"] = modelJSON - - return json.Marshal(payload) -} - -func readRequestBody(r *http.Request) ([]byte, error) { - if r.Body == nil { - return nil, nil - } - - body, err := io.ReadAll(r.Body) - if err != nil { - return nil, err - } - - r.Body = io.NopCloser(bytes.NewReader(body)) - return body, nil -} - -func extractModelField(body []byte) (string, bool) { - if len(body) == 0 { - return "", false - } - - var payload map[string]json.RawMessage - if err := json.Unmarshal(body, &payload); err != nil { - return "", false - } - - raw, ok := payload["model"] - if !ok { - return "", false - } - - var model string - if err := json.Unmarshal(raw, &model); err != nil { - return "", false - } - - model = strings.TrimSpace(model) - return model, model != "" -} - -func hasAnthropicWebSearchTool(body []byte) bool { - if len(body) == 0 { - return false - } - - var payload struct { - Tools []struct { - Type string `json:"type"` - } `json:"tools"` - } - if err := json.Unmarshal(body, &payload); err != nil { - return false - } - - for _, tool := range payload.Tools { - if strings.HasPrefix(strings.TrimSpace(tool.Type), "web_search") { - return true - } - } - - return false -} - -func writeCloudUnauthorized(c *gin.Context) { - signinURL, err := cloudProxySigninURL() - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) - return - } - - c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": signinURL}) -} - -func signCloudProxyRequest(ctx context.Context, req *http.Request) error { - if !strings.EqualFold(req.URL.Hostname(), cloudProxySigningHost) { - return nil - } - - ts := strconv.FormatInt(time.Now().Unix(), 10) - challenge := buildCloudSignatureChallenge(req, ts) - signature, err := auth.Sign(ctx, []byte(challenge)) - if err != nil { - return err - } - - req.Header.Set("Authorization", signature) - return nil -} - -func buildCloudSignatureChallenge(req *http.Request, ts string) string { - query := req.URL.Query() - query.Set("ts", ts) - req.URL.RawQuery = query.Encode() - - return fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI()) -} - -func resolveCloudProxyBaseURL(rawOverride string, runMode string) (baseURL string, signingHost string, overridden bool, err error) { - baseURL = defaultCloudProxyBaseURL - signingHost = defaultCloudProxySigningHost - - rawOverride = strings.TrimSpace(rawOverride) - if rawOverride == "" { - return baseURL, signingHost, false, nil - } - - u, err := url.Parse(rawOverride) - if err != nil { - return "", "", false, fmt.Errorf("invalid URL: %w", err) - } - if u.Scheme == "" || u.Host == "" { - return "", "", false, fmt.Errorf("invalid URL: scheme and host are required") - } - if u.User != nil { - return "", "", false, fmt.Errorf("invalid URL: userinfo is not allowed") - } - if u.Path != "" && u.Path != "/" { - return "", "", false, fmt.Errorf("invalid URL: path is not allowed") - } - if u.RawQuery != "" || u.Fragment != "" { - return "", "", false, fmt.Errorf("invalid URL: query and fragment are not allowed") - } - - host := u.Hostname() - if host == "" { - return "", "", false, fmt.Errorf("invalid URL: host is required") - } - - loopback := isLoopbackHost(host) - if runMode == gin.ReleaseMode && !loopback { - return "", "", false, fmt.Errorf("non-loopback cloud override is not allowed in release mode") - } - if !loopback && !strings.EqualFold(u.Scheme, "https") { - return "", "", false, fmt.Errorf("non-loopback cloud override must use https") - } - - u.Path = "" - u.RawPath = "" - u.RawQuery = "" - u.Fragment = "" - - return u.String(), strings.ToLower(host), true, nil -} - -func isLoopbackHost(host string) bool { - if strings.EqualFold(host, "localhost") { - return true - } - - ip := net.ParseIP(host) - return ip != nil && ip.IsLoopback() -} - -func copyProxyRequestHeaders(dst, src http.Header) { - connectionTokens := connectionHeaderTokens(src) - for key, values := range src { - if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) { - continue - } - - dst.Del(key) - for _, value := range values { - dst.Add(key, value) - } - } -} - -func copyProxyResponseHeaders(dst, src http.Header) { - connectionTokens := connectionHeaderTokens(src) - for key, values := range src { - if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) { - continue - } - - dst.Del(key) - for _, value := range values { - dst.Add(key, value) - } - } -} - -func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error { - flusher, canFlush := dst.(http.Flusher) - buf := make([]byte, 32*1024) - - for { - n, err := src.Read(buf) - if n > 0 { - if _, writeErr := dst.Write(buf[:n]); writeErr != nil { - return writeErr - } - if canFlush { - // TODO(drifkin): Consider conditional flushing so non-streaming - // responses don't flush every write and can optimize throughput. - flusher.Flush() - } - } - - if err != nil { - if err == io.EOF { - return nil - } - return err - } - } -} - -func isHopByHopHeader(name string) bool { - _, ok := hopByHopHeaders[strings.ToLower(name)] - return ok -} - -func connectionHeaderTokens(header http.Header) map[string]struct{} { - tokens := map[string]struct{}{} - for _, raw := range header.Values("Connection") { - for _, token := range strings.Split(raw, ",") { - token = strings.TrimSpace(strings.ToLower(token)) - if token == "" { - continue - } - tokens[token] = struct{}{} - } - } - return tokens -} - -func isConnectionTokenHeader(name string, tokens map[string]struct{}) bool { - if len(tokens) == 0 { - return false - } - _, ok := tokens[strings.ToLower(name)] - return ok -} diff --git a/server/cloud_proxy_test.go b/server/cloud_proxy_test.go deleted file mode 100644 index 1a7b27956..000000000 --- a/server/cloud_proxy_test.go +++ /dev/null @@ -1,154 +0,0 @@ -package server - -import ( - "net/http" - "testing" - - "github.com/gin-gonic/gin" -) - -func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) { - src := http.Header{} - src.Add("Connection", "keep-alive, X-Trace-Hop, x-alt-hop") - src.Add("X-Trace-Hop", "drop-me") - src.Add("X-Alt-Hop", "drop-me-too") - src.Add("Keep-Alive", "timeout=5") - src.Add("X-End-To-End", "keep-me") - - dst := http.Header{} - copyProxyRequestHeaders(dst, src) - - if got := dst.Get("Connection"); got != "" { - t.Fatalf("expected Connection to be stripped, got %q", got) - } - if got := dst.Get("Keep-Alive"); got != "" { - t.Fatalf("expected Keep-Alive to be stripped, got %q", got) - } - if got := dst.Get("X-Trace-Hop"); got != "" { - t.Fatalf("expected X-Trace-Hop to be stripped via Connection token, got %q", got) - } - if got := dst.Get("X-Alt-Hop"); got != "" { - t.Fatalf("expected X-Alt-Hop to be stripped via Connection token, got %q", got) - } - if got := dst.Get("X-End-To-End"); got != "keep-me" { - t.Fatalf("expected X-End-To-End to be forwarded, got %q", got) - } -} - -func TestCopyProxyResponseHeaders_StripsConnectionTokenHeaders(t *testing.T) { - src := http.Header{} - src.Add("Connection", "X-Upstream-Hop") - src.Add("X-Upstream-Hop", "drop-me") - src.Add("Content-Type", "application/json") - src.Add("X-Server-Trace", "keep-me") - - dst := http.Header{} - copyProxyResponseHeaders(dst, src) - - if got := dst.Get("Connection"); got != "" { - t.Fatalf("expected Connection to be stripped, got %q", got) - } - if got := dst.Get("X-Upstream-Hop"); got != "" { - t.Fatalf("expected X-Upstream-Hop to be stripped via Connection token, got %q", got) - } - if got := dst.Get("Content-Type"); got != "application/json" { - t.Fatalf("expected Content-Type to be forwarded, got %q", got) - } - if got := dst.Get("X-Server-Trace"); got != "keep-me" { - t.Fatalf("expected X-Server-Trace to be forwarded, got %q", got) - } -} - -func TestResolveCloudProxyBaseURL_Default(t *testing.T) { - baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("", gin.ReleaseMode) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if overridden { - t.Fatal("expected override=false for empty input") - } - if baseURL != defaultCloudProxyBaseURL { - t.Fatalf("expected default base URL %q, got %q", defaultCloudProxyBaseURL, baseURL) - } - if signingHost != defaultCloudProxySigningHost { - t.Fatalf("expected default signing host %q, got %q", defaultCloudProxySigningHost, signingHost) - } -} - -func TestResolveCloudProxyBaseURL_ReleaseAllowsLoopback(t *testing.T) { - baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("http://localhost:8080", gin.ReleaseMode) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !overridden { - t.Fatal("expected override=true") - } - if baseURL != "http://localhost:8080" { - t.Fatalf("unexpected base URL: %q", baseURL) - } - if signingHost != "localhost" { - t.Fatalf("unexpected signing host: %q", signingHost) - } -} - -func TestResolveCloudProxyBaseURL_ReleaseRejectsNonLoopback(t *testing.T) { - _, _, _, err := resolveCloudProxyBaseURL("https://example.com", gin.ReleaseMode) - if err == nil { - t.Fatal("expected error for non-loopback override in release mode") - } -} - -func TestResolveCloudProxyBaseURL_DevAllowsNonLoopbackHTTPS(t *testing.T) { - baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("https://example.com:8443", gin.DebugMode) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !overridden { - t.Fatal("expected override=true") - } - if baseURL != "https://example.com:8443" { - t.Fatalf("unexpected base URL: %q", baseURL) - } - if signingHost != "example.com" { - t.Fatalf("unexpected signing host: %q", signingHost) - } -} - -func TestResolveCloudProxyBaseURL_DevRejectsNonLoopbackHTTP(t *testing.T) { - _, _, _, err := resolveCloudProxyBaseURL("http://example.com", gin.DebugMode) - if err == nil { - t.Fatal("expected error for non-loopback http override in dev mode") - } -} - -func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&foo=bar", nil) - if err != nil { - t.Fatalf("failed to create request: %v", err) - } - - got := buildCloudSignatureChallenge(req, "123") - want := "POST,/v1/messages?beta=true&foo=bar&ts=123" - if got != want { - t.Fatalf("challenge mismatch: got %q want %q", got, want) - } - if req.URL.RawQuery != "beta=true&foo=bar&ts=123" { - t.Fatalf("unexpected signed query: %q", req.URL.RawQuery) - } -} - -func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil) - if err != nil { - t.Fatalf("failed to create request: %v", err) - } - - got := buildCloudSignatureChallenge(req, "123") - want := "POST,/v1/messages?beta=true&ts=123" - if got != want { - t.Fatalf("challenge mismatch: got %q want %q", got, want) - } - if req.URL.RawQuery != "beta=true&ts=123" { - t.Fatalf("unexpected signed query: %q", req.URL.RawQuery) - } -} diff --git a/server/create.go b/server/create.go index 9797384fd..c9ade530e 100644 --- a/server/create.go +++ b/server/create.go @@ -110,26 +110,19 @@ func (s *Server) CreateHandler(c *gin.Context) { if r.From != "" { slog.Debug("create model from model name", "from", r.From) - fromRef, err := parseAndValidateModelRef(r.From) - if err != nil { + fromName := model.ParseName(r.From) + if !fromName.IsValid() { ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest} return } - - fromName := fromRef.Name - remoteHost := r.RemoteHost - if fromRef.Source == modelSourceCloud && remoteHost == "" { - remoteHost = cloudProxyBaseURL - } - - if remoteHost != "" { - ru, err := remoteURL(remoteHost) + if r.RemoteHost != "" { + ru, err := remoteURL(r.RemoteHost) if err != nil { ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest} return } - config.RemoteModel = fromRef.Base + config.RemoteModel = r.From config.RemoteHost = ru remote = true } else { diff --git a/server/model_resolver.go b/server/model_resolver.go deleted file mode 100644 index cbbeffa37..000000000 --- a/server/model_resolver.go +++ /dev/null @@ -1,81 +0,0 @@ -package server - -import ( - "github.com/ollama/ollama/internal/modelref" - "github.com/ollama/ollama/types/model" -) - -type modelSource = modelref.ModelSource - -const ( - modelSourceUnspecified modelSource = modelref.ModelSourceUnspecified - modelSourceLocal modelSource = modelref.ModelSourceLocal - modelSourceCloud modelSource = modelref.ModelSourceCloud -) - -var ( - errConflictingModelSource = modelref.ErrConflictingSourceSuffix - errModelRequired = modelref.ErrModelRequired -) - -type parsedModelRef struct { - // Original is the caller-provided model string before source parsing. - // Example: "gpt-oss:20b:cloud". - Original string - // Base is the model string after source suffix normalization. - // Example: "gpt-oss:20b:cloud" -> "gpt-oss:20b". - Base string - // Name is Base parsed as a fully-qualified model.Name with defaults applied. - // Example: "registry.ollama.ai/library/gpt-oss:20b". - Name model.Name - // Source captures explicit source intent from the original input. - // Example: "gpt-oss:20b:cloud" -> modelSourceCloud. - Source modelSource -} - -func parseAndValidateModelRef(raw string) (parsedModelRef, error) { - var zero parsedModelRef - - parsed, err := modelref.ParseRef(raw) - if err != nil { - return zero, err - } - - name := model.ParseName(parsed.Base) - if !name.IsValid() { - return zero, model.Unqualified(name) - } - - return parsedModelRef{ - Original: parsed.Original, - Base: parsed.Base, - Name: name, - Source: parsed.Source, - }, nil -} - -func parseNormalizePullModelRef(raw string) (parsedModelRef, error) { - var zero parsedModelRef - - parsedRef, err := modelref.ParseRef(raw) - if err != nil { - return zero, err - } - - normalizedName, _, err := modelref.NormalizePullName(raw) - if err != nil { - return zero, err - } - - name := model.ParseName(normalizedName) - if !name.IsValid() { - return zero, model.Unqualified(name) - } - - return parsedModelRef{ - Original: parsedRef.Original, - Base: normalizedName, - Name: name, - Source: parsedRef.Source, - }, nil -} diff --git a/server/model_resolver_test.go b/server/model_resolver_test.go deleted file mode 100644 index c0926ec30..000000000 --- a/server/model_resolver_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package server - -import ( - "errors" - "strings" - "testing" -) - -func TestParseModelSelector(t *testing.T) { - t.Run("cloud suffix", func(t *testing.T) { - got, err := parseAndValidateModelRef("gpt-oss:20b:cloud") - if err != nil { - t.Fatalf("parseModelSelector returned error: %v", err) - } - - if got.Source != modelSourceCloud { - t.Fatalf("expected source cloud, got %v", got.Source) - } - - if got.Base != "gpt-oss:20b" { - t.Fatalf("expected base gpt-oss:20b, got %q", got.Base) - } - - if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b" { - t.Fatalf("unexpected resolved name: %q", got.Name.String()) - } - }) - - t.Run("legacy cloud suffix", func(t *testing.T) { - got, err := parseAndValidateModelRef("gpt-oss:20b-cloud") - if err != nil { - t.Fatalf("parseModelSelector returned error: %v", err) - } - - if got.Source != modelSourceCloud { - t.Fatalf("expected source cloud, got %v", got.Source) - } - - if got.Base != "gpt-oss:20b" { - t.Fatalf("expected base gpt-oss:20b, got %q", got.Base) - } - }) - - t.Run("bare dash cloud name is not explicit cloud", func(t *testing.T) { - got, err := parseAndValidateModelRef("my-cloud-model") - if err != nil { - t.Fatalf("parseModelSelector returned error: %v", err) - } - - if got.Source != modelSourceUnspecified { - t.Fatalf("expected source unspecified, got %v", got.Source) - } - - if got.Base != "my-cloud-model" { - t.Fatalf("expected base my-cloud-model, got %q", got.Base) - } - }) - - t.Run("local suffix", func(t *testing.T) { - got, err := parseAndValidateModelRef("qwen3:8b:local") - if err != nil { - t.Fatalf("parseModelSelector returned error: %v", err) - } - - if got.Source != modelSourceLocal { - t.Fatalf("expected source local, got %v", got.Source) - } - - if got.Base != "qwen3:8b" { - t.Fatalf("expected base qwen3:8b, got %q", got.Base) - } - }) - - t.Run("conflicting source suffixes fail", func(t *testing.T) { - _, err := parseAndValidateModelRef("foo:cloud:local") - if !errors.Is(err, errConflictingModelSource) { - t.Fatalf("expected errConflictingModelSource, got %v", err) - } - }) - - t.Run("unspecified source", func(t *testing.T) { - got, err := parseAndValidateModelRef("llama3") - if err != nil { - t.Fatalf("parseModelSelector returned error: %v", err) - } - - if got.Source != modelSourceUnspecified { - t.Fatalf("expected source unspecified, got %v", got.Source) - } - - if got.Name.Tag != "latest" { - t.Fatalf("expected default latest tag, got %q", got.Name.Tag) - } - }) - - t.Run("unknown suffix is treated as tag", func(t *testing.T) { - got, err := parseAndValidateModelRef("gpt-oss:clod") - if err != nil { - t.Fatalf("parseModelSelector returned error: %v", err) - } - - if got.Source != modelSourceUnspecified { - t.Fatalf("expected source unspecified, got %v", got.Source) - } - - if got.Name.Tag != "clod" { - t.Fatalf("expected tag clod, got %q", got.Name.Tag) - } - }) - - t.Run("empty model fails", func(t *testing.T) { - _, err := parseAndValidateModelRef("") - if !errors.Is(err, errModelRequired) { - t.Fatalf("expected errModelRequired, got %v", err) - } - }) - - t.Run("invalid model fails", func(t *testing.T) { - _, err := parseAndValidateModelRef("::cloud") - if err == nil { - t.Fatal("expected error for invalid model") - } - if !strings.Contains(err.Error(), "unqualified") { - t.Fatalf("expected unqualified model error, got %v", err) - } - }) -} - -func TestParsePullModelRef(t *testing.T) { - t.Run("explicit local is normalized", func(t *testing.T) { - got, err := parseNormalizePullModelRef("gpt-oss:20b:local") - if err != nil { - t.Fatalf("parseNormalizePullModelRef returned error: %v", err) - } - - if got.Source != modelSourceLocal { - t.Fatalf("expected source local, got %v", got.Source) - } - - if got.Base != "gpt-oss:20b" { - t.Fatalf("expected base gpt-oss:20b, got %q", got.Base) - } - }) - - t.Run("explicit cloud with size maps to legacy cloud suffix", func(t *testing.T) { - got, err := parseNormalizePullModelRef("gpt-oss:20b:cloud") - if err != nil { - t.Fatalf("parseNormalizePullModelRef returned error: %v", err) - } - if got.Base != "gpt-oss:20b-cloud" { - t.Fatalf("expected base gpt-oss:20b-cloud, got %q", got.Base) - } - if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b-cloud" { - t.Fatalf("unexpected resolved name: %q", got.Name.String()) - } - }) - - t.Run("explicit cloud without size maps to cloud tag", func(t *testing.T) { - got, err := parseNormalizePullModelRef("qwen3:cloud") - if err != nil { - t.Fatalf("parseNormalizePullModelRef returned error: %v", err) - } - if got.Base != "qwen3:cloud" { - t.Fatalf("expected base qwen3:cloud, got %q", got.Base) - } - if got.Name.String() != "registry.ollama.ai/library/qwen3:cloud" { - t.Fatalf("unexpected resolved name: %q", got.Name.String()) - } - }) -} diff --git a/server/routes.go b/server/routes.go index 43f76ea08..a27ce3a96 100644 --- a/server/routes.go +++ b/server/routes.go @@ -64,17 +64,6 @@ const ( cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable" ) -func writeModelRefParseError(c *gin.Context, err error, fallbackStatus int, fallbackMessage string) { - switch { - case errors.Is(err, errConflictingModelSource): - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - case errors.Is(err, model.ErrUnqualifiedName): - c.JSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg}) - default: - c.JSON(fallbackStatus, gin.H{"error": fallbackMessage}) - } -} - func shouldUseHarmony(model *Model) bool { if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { // heuristic to check whether the template expects to be parsed via harmony: @@ -207,22 +196,14 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - modelRef, err := parseAndValidateModelRef(req.Model) - if err != nil { - writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model)) + name := model.ParseName(req.Model) + if !name.IsValid() { + // Ideally this is "invalid model name" but we're keeping with + // what the API currently returns until we can change it. + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) return } - if modelRef.Source == modelSourceCloud { - // TODO(drifkin): evaluate an `/api/*` passthrough for cloud where the - // original body (modulo model name normalization) is sent to cloud. - req.Model = modelRef.Base - proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable) - return - } - - name := modelRef.Name - resolvedName, _, err := s.resolveAlias(name) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -256,11 +237,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) - return - } - if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { if disabled, _ := internalcloud.Status(); disabled { c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)}) @@ -700,18 +676,6 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - modelRef, err := parseAndValidateModelRef(req.Model) - if err != nil { - writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model)) - return - } - - if modelRef.Source == modelSourceCloud { - req.Model = modelRef.Base - proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable) - return - } - var input []string switch i := req.Input.(type) { @@ -734,7 +698,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { } } - name, err := getExistingName(modelRef.Name) + name, err := getExistingName(model.ParseName(req.Model)) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) return @@ -881,20 +845,12 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - modelRef, err := parseAndValidateModelRef(req.Model) - if err != nil { - writeModelRefParseError(c, err, http.StatusBadRequest, "model is required") + name := model.ParseName(req.Model) + if !name.IsValid() { + c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) return } - if modelRef.Source == modelSourceCloud { - req.Model = modelRef.Base - proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable) - return - } - - name := modelRef.Name - r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) @@ -936,19 +892,12 @@ func (s *Server) PullHandler(c *gin.Context) { return } - // TEMP(drifkin): we're temporarily allowing to continue pulling cloud model - // stub-files until we integrate cloud models into `/api/tags` (in which case - // this roundabout way of "adding" cloud models won't be needed anymore). So - // right here normalize any `:cloud` models into the legacy-style suffixes - // `:-cloud` and `:cloud` - modelRef, err := parseNormalizePullModelRef(cmp.Or(req.Model, req.Name)) - if err != nil { - writeModelRefParseError(c, err, http.StatusBadRequest, errtypes.InvalidModelNameErrMsg) + name := model.ParseName(cmp.Or(req.Model, req.Name)) + if !name.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg}) return } - name := modelRef.Name - name, err = getExistingName(name) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -1075,20 +1024,13 @@ func (s *Server) DeleteHandler(c *gin.Context) { return } - modelRef, err := parseNormalizePullModelRef(cmp.Or(r.Model, r.Name)) - if err != nil { - switch { - case errors.Is(err, errConflictingModelSource): - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - case errors.Is(err, model.ErrUnqualifiedName): - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))}) - default: - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - } + n := model.ParseName(cmp.Or(r.Model, r.Name)) + if !n.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))}) return } - n, err := getExistingName(modelRef.Name) + n, err := getExistingName(n) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))}) return @@ -1137,20 +1079,6 @@ func (s *Server) ShowHandler(c *gin.Context) { return } - modelRef, err := parseAndValidateModelRef(req.Model) - if err != nil { - writeModelRefParseError(c, err, http.StatusBadRequest, err.Error()) - return - } - - if modelRef.Source == modelSourceCloud { - req.Model = modelRef.Base - proxyCloudJSONRequest(c, req, cloudErrRemoteModelDetailsUnavailable) - return - } - - req.Model = modelRef.Base - resp, err := GetModelInfo(req) if err != nil { var statusErr api.StatusError @@ -1167,11 +1095,6 @@ func (s *Server) ShowHandler(c *gin.Context) { return } - if modelRef.Source == modelSourceLocal && resp.RemoteHost != "" { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", modelRef.Original)}) - return - } - c.JSON(http.StatusOK, resp) } @@ -1708,20 +1631,18 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/embeddings", s.EmbeddingsHandler) // Inference (OpenAI compatibility) - // TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud - // parents on v1 request families while preserving this explicit :cloud passthrough. - r.POST("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler) - r.POST("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler) - r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler) + r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler) + r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler) + r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler) r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler) - r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler) - r.POST("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler) + r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler) + r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler) // OpenAI-compatible image generation endpoints - r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler) - r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler) + r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler) + r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler) // Inference (Anthropic compatibility) - r.POST("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler) + r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler) if rc != nil { // wrap old with new @@ -2080,24 +2001,12 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - modelRef, err := parseAndValidateModelRef(req.Model) - if err != nil { - writeModelRefParseError(c, err, http.StatusBadRequest, "model is required") + name := model.ParseName(req.Model) + if !name.IsValid() { + c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) return } - if modelRef.Source == modelSourceCloud { - req.Model = modelRef.Base - if c.GetBool(legacyCloudAnthropicKey) { - proxyCloudJSONRequestWithPath(c, req, "/api/chat", cloudErrRemoteInferenceUnavailable) - return - } - proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable) - return - } - - name := modelRef.Name - resolvedName, _, err := s.resolveAlias(name) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -2129,11 +2038,6 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) - return - } - // expire the runner if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { s.sched.expireRunner(m) diff --git a/server/routes_cloud_test.go b/server/routes_cloud_test.go index d6311582c..b0ee126ea 100644 --- a/server/routes_cloud_test.go +++ b/server/routes_cloud_test.go @@ -1,22 +1,13 @@ package server import ( - "bufio" - "bytes" - "context" "encoding/json" - "errors" - "io" "net/http" - "net/http/httptest" - "strings" "testing" - "time" "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" internalcloud "github.com/ollama/ollama/internal/cloud" - "github.com/ollama/ollama/middleware" ) func TestStatusHandler(t *testing.T) { @@ -101,982 +92,3 @@ func TestCloudDisabledBlocksRemoteOperations(t *testing.T) { } }) } - -func TestDeleteHandlerNormalizesExplicitSourceSuffixes(t *testing.T) { - gin.SetMode(gin.TestMode) - setTestHome(t, t.TempDir()) - - s := Server{} - - tests := []string{ - "gpt-oss:20b:local", - "gpt-oss:20b:cloud", - "qwen3:cloud", - } - - for _, modelName := range tests { - t.Run(modelName, func(t *testing.T) { - w := createRequest(t, s.DeleteHandler, api.DeleteRequest{ - Model: modelName, - }) - if w.Code != http.StatusNotFound { - t.Fatalf("expected status 404, got %d (%s)", w.Code, w.Body.String()) - } - - var resp map[string]string - if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { - t.Fatal(err) - } - want := "model '" + modelName + "' not found" - if resp["error"] != want { - t.Fatalf("unexpected error: got %q, want %q", resp["error"], want) - } - }) - } -} - -func TestExplicitCloudPassthroughAPIAndV1(t *testing.T) { - gin.SetMode(gin.TestMode) - setTestHome(t, t.TempDir()) - - type upstreamCapture struct { - path string - body string - header http.Header - } - - newUpstream := func(t *testing.T, responseBody string) (*httptest.Server, *upstreamCapture) { - t.Helper() - capture := &upstreamCapture{} - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - payload, _ := io.ReadAll(r.Body) - capture.path = r.URL.Path - capture.body = string(payload) - capture.header = r.Header.Clone() - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(responseBody)) - })) - - return srv, capture - } - - t.Run("api generate", func(t *testing.T) { - upstream, capture := newUpstream(t, `{"ok":"api"}`) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{"model":"kimi-k2.5:cloud","prompt":"hello","stream":false}` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/generate", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Test-Header", "api-header") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) - } - - if capture.path != "/api/generate" { - t.Fatalf("expected upstream path /api/generate, got %q", capture.path) - } - - if !strings.Contains(capture.body, `"model":"kimi-k2.5"`) { - t.Fatalf("expected normalized model in upstream body, got %q", capture.body) - } - - if got := capture.header.Get("X-Test-Header"); got != "api-header" { - t.Fatalf("expected forwarded X-Test-Header=api-header, got %q", got) - } - }) - - t.Run("api chat", func(t *testing.T) { - upstream, capture := newUpstream(t, `{"message":{"role":"assistant","content":"ok"},"done":true}`) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{"model":"kimi-k2.5:cloud","messages":[{"role":"user","content":"hello"}],"stream":false}` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/chat", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) - } - - if capture.path != "/api/chat" { - t.Fatalf("expected upstream path /api/chat, got %q", capture.path) - } - - if !strings.Contains(capture.body, `"model":"kimi-k2.5"`) { - t.Fatalf("expected normalized model in upstream body, got %q", capture.body) - } - }) - - t.Run("api embed", func(t *testing.T) { - upstream, capture := newUpstream(t, `{"model":"kimi-k2.5:cloud","embeddings":[[0.1,0.2]]}`) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{"model":"kimi-k2.5:cloud","input":"hello"}` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/embed", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) - } - - if capture.path != "/api/embed" { - t.Fatalf("expected upstream path /api/embed, got %q", capture.path) - } - - if !strings.Contains(capture.body, `"model":"kimi-k2.5"`) { - t.Fatalf("expected normalized model in upstream body, got %q", capture.body) - } - }) - - t.Run("api embeddings", func(t *testing.T) { - upstream, capture := newUpstream(t, `{"embedding":[0.1,0.2]}`) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{"model":"kimi-k2.5:cloud","prompt":"hello"}` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/embeddings", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) - } - - if capture.path != "/api/embeddings" { - t.Fatalf("expected upstream path /api/embeddings, got %q", capture.path) - } - - if !strings.Contains(capture.body, `"model":"kimi-k2.5"`) { - t.Fatalf("expected normalized model in upstream body, got %q", capture.body) - } - }) - - t.Run("api show", func(t *testing.T) { - upstream, capture := newUpstream(t, `{"details":{"format":"gguf"}}`) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{"model":"kimi-k2.5:cloud"}` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/show", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) - } - - if capture.path != "/api/show" { - t.Fatalf("expected upstream path /api/show, got %q", capture.path) - } - - if !strings.Contains(capture.body, `"model":"kimi-k2.5"`) { - t.Fatalf("expected normalized model in upstream body, got %q", capture.body) - } - }) - - t.Run("v1 chat completions bypasses conversion", func(t *testing.T) { - upstream, capture := newUpstream(t, `{"id":"chatcmpl_test","object":"chat.completion"}`) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{"model":"gpt-oss:120b:cloud","messages":[{"role":"user","content":"hi"}],"max_tokens":7}` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/chat/completions", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Test-Header", "v1-header") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) - } - - if capture.path != "/v1/chat/completions" { - t.Fatalf("expected upstream path /v1/chat/completions, got %q", capture.path) - } - - if !strings.Contains(capture.body, `"max_tokens":7`) { - t.Fatalf("expected original OpenAI request body, got %q", capture.body) - } - - if !strings.Contains(capture.body, `"model":"gpt-oss:120b"`) { - t.Fatalf("expected normalized model in upstream body, got %q", capture.body) - } - - if strings.Contains(capture.body, `"options"`) { - t.Fatalf("expected no converted Ollama options in upstream body, got %q", capture.body) - } - - if got := capture.header.Get("X-Test-Header"); got != "v1-header" { - t.Fatalf("expected forwarded X-Test-Header=v1-header, got %q", got) - } - }) - - t.Run("v1 chat completions bypasses conversion with legacy cloud suffix", func(t *testing.T) { - upstream, capture := newUpstream(t, `{"id":"chatcmpl_test","object":"chat.completion"}`) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{"model":"gpt-oss:120b-cloud","messages":[{"role":"user","content":"hi"}],"max_tokens":7}` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/chat/completions", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Test-Header", "v1-legacy-header") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) - } - - if capture.path != "/v1/chat/completions" { - t.Fatalf("expected upstream path /v1/chat/completions, got %q", capture.path) - } - - if !strings.Contains(capture.body, `"max_tokens":7`) { - t.Fatalf("expected original OpenAI request body, got %q", capture.body) - } - - if !strings.Contains(capture.body, `"model":"gpt-oss:120b"`) { - t.Fatalf("expected normalized model in upstream body, got %q", capture.body) - } - - if strings.Contains(capture.body, `"options"`) { - t.Fatalf("expected no converted Ollama options in upstream body, got %q", capture.body) - } - - if got := capture.header.Get("X-Test-Header"); got != "v1-legacy-header" { - t.Fatalf("expected forwarded X-Test-Header=v1-legacy-header, got %q", got) - } - }) - - t.Run("v1 messages bypasses conversion", func(t *testing.T) { - upstream, capture := newUpstream(t, `{"id":"msg_1","type":"message"}`) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{"model":"kimi-k2.5:cloud","max_tokens":10,"messages":[{"role":"user","content":"hi"}]}` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) - } - - if capture.path != "/v1/messages" { - t.Fatalf("expected upstream path /v1/messages, got %q", capture.path) - } - - if !strings.Contains(capture.body, `"max_tokens":10`) { - t.Fatalf("expected original Anthropic request body, got %q", capture.body) - } - - if !strings.Contains(capture.body, `"model":"kimi-k2.5"`) { - t.Fatalf("expected normalized model in upstream body, got %q", capture.body) - } - - if strings.Contains(capture.body, `"options"`) { - t.Fatalf("expected no converted Ollama options in upstream body, got %q", capture.body) - } - }) - - t.Run("v1 messages bypasses conversion with legacy cloud suffix", func(t *testing.T) { - upstream, capture := newUpstream(t, `{"id":"msg_1","type":"message"}`) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{"model":"kimi-k2.5:latest-cloud","max_tokens":10,"messages":[{"role":"user","content":"hi"}]}` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) - } - - if capture.path != "/v1/messages" { - t.Fatalf("expected upstream path /v1/messages, got %q", capture.path) - } - - if !strings.Contains(capture.body, `"max_tokens":10`) { - t.Fatalf("expected original Anthropic request body, got %q", capture.body) - } - - if !strings.Contains(capture.body, `"model":"kimi-k2.5:latest"`) { - t.Fatalf("expected normalized model in upstream body, got %q", capture.body) - } - - if strings.Contains(capture.body, `"options"`) { - t.Fatalf("expected no converted Ollama options in upstream body, got %q", capture.body) - } - }) - - t.Run("v1 messages web_search fallback uses legacy cloud /api/chat path", func(t *testing.T) { - upstream, capture := newUpstream(t, `{"model":"gpt-oss:120b","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"hello"},"done":true}`) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{ - "model":"gpt-oss:120b-cloud", - "max_tokens":10, - "messages":[{"role":"user","content":"search the web"}], - "tools":[{"type":"web_search_20250305","name":"web_search"}], - "stream":false - }` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages?beta=true", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) - } - - if capture.path != "/api/chat" { - t.Fatalf("expected upstream path /api/chat for web_search fallback, got %q", capture.path) - } - - if !strings.Contains(capture.body, `"model":"gpt-oss:120b"`) { - t.Fatalf("expected normalized model in upstream body, got %q", capture.body) - } - - if !strings.Contains(capture.body, `"num_predict":10`) { - t.Fatalf("expected converted ollama options in upstream body, got %q", capture.body) - } - }) - - t.Run("v1 model retrieve bypasses conversion", func(t *testing.T) { - upstream, capture := newUpstream(t, `{"id":"kimi-k2.5:cloud","object":"model","created":1,"owned_by":"ollama"}`) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - local := httptest.NewServer(router) - defer local.Close() - - req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, local.URL+"/v1/models/kimi-k2.5:cloud", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("X-Test-Header", "v1-model-header") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) - } - - if capture.path != "/v1/models/kimi-k2.5" { - t.Fatalf("expected upstream path /v1/models/kimi-k2.5, got %q", capture.path) - } - - if capture.body != "" { - t.Fatalf("expected empty request body, got %q", capture.body) - } - - if got := capture.header.Get("X-Test-Header"); got != "v1-model-header" { - t.Fatalf("expected forwarded X-Test-Header=v1-model-header, got %q", got) - } - }) - - t.Run("v1 model retrieve normalizes legacy cloud suffix", func(t *testing.T) { - upstream, capture := newUpstream(t, `{"id":"kimi-k2.5:latest","object":"model","created":1,"owned_by":"ollama"}`) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - local := httptest.NewServer(router) - defer local.Close() - - req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, local.URL+"/v1/models/kimi-k2.5:latest-cloud", nil) - if err != nil { - t.Fatal(err) - } - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) - } - - if capture.path != "/v1/models/kimi-k2.5:latest" { - t.Fatalf("expected upstream path /v1/models/kimi-k2.5:latest, got %q", capture.path) - } - }) -} - -func TestCloudDisabledBlocksExplicitCloudPassthrough(t *testing.T) { - gin.SetMode(gin.TestMode) - setTestHome(t, t.TempDir()) - t.Setenv("OLLAMA_NO_CLOUD", "1") - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - - local := httptest.NewServer(router) - defer local.Close() - - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/chat/completions", bytes.NewBufferString(`{"model":"kimi-k2.5:cloud","messages":[{"role":"user","content":"hi"}]}`)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusForbidden { - t.Fatalf("expected status 403, got %d (%s)", resp.StatusCode, string(body)) - } - - var got map[string]string - if err := json.Unmarshal(body, &got); err != nil { - t.Fatalf("expected json error body, got: %q", string(body)) - } - - if got["error"] != internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable) { - t.Fatalf("unexpected error message: %q", got["error"]) - } -} - -func TestCloudPassthroughStreamsPromptly(t *testing.T) { - gin.SetMode(gin.TestMode) - setTestHome(t, t.TempDir()) - - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/x-ndjson") - flusher, ok := w.(http.Flusher) - if !ok { - t.Fatal("upstream writer is not a flusher") - } - - _, _ = w.Write([]byte(`{"response":"first"}` + "\n")) - flusher.Flush() - - time.Sleep(700 * time.Millisecond) - - _, _ = w.Write([]byte(`{"response":"second"}` + "\n")) - flusher.Flush() - })) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{"model":"kimi-k2.5:cloud","messages":[{"role":"user","content":"hi"}],"stream":true}` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/chat", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) - } - - reader := bufio.NewReader(resp.Body) - - start := time.Now() - firstLine, err := reader.ReadString('\n') - if err != nil { - t.Fatalf("failed reading first streamed line: %v", err) - } - if elapsed := time.Since(start); elapsed > 400*time.Millisecond { - t.Fatalf("first streamed line arrived too late (%s), likely not flushing", elapsed) - } - if !strings.Contains(firstLine, `"first"`) { - t.Fatalf("expected first line to contain first chunk, got %q", firstLine) - } - - secondLine, err := reader.ReadString('\n') - if err != nil { - t.Fatalf("failed reading second streamed line: %v", err) - } - if !strings.Contains(secondLine, `"second"`) { - t.Fatalf("expected second line to contain second chunk, got %q", secondLine) - } -} - -func TestCloudPassthroughSkipsAnthropicWebSearch(t *testing.T) { - gin.SetMode(gin.TestMode) - setTestHome(t, t.TempDir()) - - type upstreamCapture struct { - path string - } - capture := &upstreamCapture{} - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capture.path = r.URL.Path - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"id":"msg_1","type":"message"}`)) - })) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - router := gin.New() - router.POST( - "/v1/messages", - cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), - middleware.AnthropicMessagesMiddleware(), - func(c *gin.Context) { c.Status(http.StatusTeapot) }, - ) - - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{ - "model":"kimi-k2.5:cloud", - "max_tokens":10, - "messages":[{"role":"user","content":"hi"}], - "tools":[{"type":"web_search_20250305","name":"web_search"}] - }` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusTeapot { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("expected local middleware path status %d, got %d (%s)", http.StatusTeapot, resp.StatusCode, string(body)) - } - - if capture.path != "" { - t.Fatalf("expected no passthrough for web_search requests, got upstream path %q", capture.path) - } -} - -func TestCloudPassthroughSkipsAnthropicWebSearchLegacySuffix(t *testing.T) { - gin.SetMode(gin.TestMode) - setTestHome(t, t.TempDir()) - - type upstreamCapture struct { - path string - } - capture := &upstreamCapture{} - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capture.path = r.URL.Path - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"id":"msg_1","type":"message"}`)) - })) - defer upstream.Close() - - original := cloudProxyBaseURL - cloudProxyBaseURL = upstream.URL - t.Cleanup(func() { cloudProxyBaseURL = original }) - - router := gin.New() - router.POST( - "/v1/messages", - cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), - middleware.AnthropicMessagesMiddleware(), - func(c *gin.Context) { c.Status(http.StatusTeapot) }, - ) - - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{ - "model":"kimi-k2.5:latest-cloud", - "max_tokens":10, - "messages":[{"role":"user","content":"hi"}], - "tools":[{"type":"web_search_20250305","name":"web_search"}] - }` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusTeapot { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("expected local middleware path status %d, got %d (%s)", http.StatusTeapot, resp.StatusCode, string(body)) - } - - if capture.path != "" { - t.Fatalf("expected no passthrough for web_search requests, got upstream path %q", capture.path) - } -} - -func TestCloudPassthroughSigningFailureReturnsUnauthorized(t *testing.T) { - gin.SetMode(gin.TestMode) - setTestHome(t, t.TempDir()) - - origSignRequest := cloudProxySignRequest - origSigninURL := cloudProxySigninURL - cloudProxySignRequest = func(context.Context, *http.Request) error { - return errors.New("ssh: no key found") - } - cloudProxySigninURL = func() (string, error) { - return "https://ollama.com/signin/example", nil - } - t.Cleanup(func() { - cloudProxySignRequest = origSignRequest - cloudProxySigninURL = origSigninURL - }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{"model":"kimi-k2.5:cloud","prompt":"hello","stream":false}` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/generate", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusUnauthorized { - t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body)) - } - - var got map[string]any - if err := json.Unmarshal(body, &got); err != nil { - t.Fatalf("expected json error body, got: %q", string(body)) - } - - if got["error"] != "unauthorized" { - t.Fatalf("unexpected error message: %v", got["error"]) - } - - if got["signin_url"] != "https://ollama.com/signin/example" { - t.Fatalf("unexpected signin_url: %v", got["signin_url"]) - } -} - -func TestCloudPassthroughSigningFailureWithoutSigninURL(t *testing.T) { - gin.SetMode(gin.TestMode) - setTestHome(t, t.TempDir()) - - origSignRequest := cloudProxySignRequest - origSigninURL := cloudProxySigninURL - cloudProxySignRequest = func(context.Context, *http.Request) error { - return errors.New("ssh: no key found") - } - cloudProxySigninURL = func() (string, error) { - return "", errors.New("key missing") - } - t.Cleanup(func() { - cloudProxySignRequest = origSignRequest - cloudProxySigninURL = origSigninURL - }) - - s := &Server{} - router, err := s.GenerateRoutes(nil) - if err != nil { - t.Fatal(err) - } - - local := httptest.NewServer(router) - defer local.Close() - - reqBody := `{"model":"kimi-k2.5:cloud","prompt":"hello","stream":false}` - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/generate", bytes.NewBufferString(reqBody)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := local.Client().Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusUnauthorized { - t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body)) - } - - var got map[string]any - if err := json.Unmarshal(body, &got); err != nil { - t.Fatalf("expected json error body, got: %q", string(body)) - } - - if got["error"] != "unauthorized" { - t.Fatalf("unexpected error message: %v", got["error"]) - } - - if _, ok := got["signin_url"]; ok { - t.Fatalf("did not expect signin_url when helper fails, got %v", got["signin_url"]) - } -} diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 401f98d9d..0d0ac6dbc 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -794,43 +794,6 @@ func TestCreateAndShowRemoteModel(t *testing.T) { fmt.Printf("resp = %#v\n", resp) } -func TestCreateFromCloudSourceSuffix(t *testing.T) { - gin.SetMode(gin.TestMode) - - var s Server - - w := createRequest(t, s.CreateHandler, api.CreateRequest{ - Model: "test-cloud-from-suffix", - From: "gpt-oss:20b:cloud", - Info: map[string]any{ - "capabilities": []string{"completion"}, - }, - Stream: &stream, - }) - - if w.Code != http.StatusOK { - t.Fatalf("expected status code 200, got %d", w.Code) - } - - w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test-cloud-from-suffix"}) - if w.Code != http.StatusOK { - t.Fatalf("expected status code 200, got %d", w.Code) - } - - var resp api.ShowResponse - if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { - t.Fatal(err) - } - - if resp.RemoteHost != "https://ollama.com:443" { - t.Fatalf("expected remote host https://ollama.com:443, got %q", resp.RemoteHost) - } - - if resp.RemoteModel != "gpt-oss:20b" { - t.Fatalf("expected remote model gpt-oss:20b, got %q", resp.RemoteModel) - } -} - func TestCreateLicenses(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/server/routes_delete_test.go b/server/routes_delete_test.go index 444c76ed6..a1a5f5424 100644 --- a/server/routes_delete_test.go +++ b/server/routes_delete_test.go @@ -111,32 +111,3 @@ func TestDeleteDuplicateLayers(t *testing.T) { checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{}) } - -func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) { - gin.SetMode(gin.TestMode) - - p := t.TempDir() - t.Setenv("OLLAMA_MODELS", p) - - var s Server - - _, digest := createBinFile(t, nil, nil) - w := createRequest(t, s.CreateHandler, api.CreateRequest{ - Name: "gpt-oss:20b-cloud", - Files: map[string]string{"test.gguf": digest}, - }) - if w.Code != http.StatusOK { - t.Fatalf("expected status code 200, actual %d", w.Code) - } - - checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ - filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"), - }) - - w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"}) - if w.Code != http.StatusOK { - t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String()) - } - - checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{}) -} diff --git a/x/cmd/run.go b/x/cmd/run.go index e96c8385e..e5d7ea25e 100644 --- a/x/cmd/run.go +++ b/x/cmd/run.go @@ -20,7 +20,6 @@ import ( "github.com/ollama/ollama/api" internalcloud "github.com/ollama/ollama/internal/cloud" - "github.com/ollama/ollama/internal/modelref" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/readline" "github.com/ollama/ollama/types/model" @@ -44,7 +43,7 @@ const ( // isLocalModel checks if the model is running locally (not a cloud model). // TODO: Improve local/cloud model identification - could check model metadata func isLocalModel(modelName string) bool { - return !modelref.HasExplicitCloudSource(modelName) + return !strings.HasSuffix(modelName, "-cloud") } // isLocalServer checks if connecting to a local Ollama server. diff --git a/x/cmd/run_test.go b/x/cmd/run_test.go index 75429f8ac..a65e8cc80 100644 --- a/x/cmd/run_test.go +++ b/x/cmd/run_test.go @@ -22,22 +22,12 @@ func TestIsLocalModel(t *testing.T) { }, { name: "cloud model", - modelName: "gpt-oss:latest-cloud", - expected: false, - }, - { - name: "cloud model with :cloud suffix", - modelName: "gpt-oss:cloud", + modelName: "gpt-4-cloud", expected: false, }, { name: "cloud model with version", - modelName: "gpt-oss:20b-cloud", - expected: false, - }, - { - name: "cloud model with version and :cloud suffix", - modelName: "gpt-oss:20b:cloud", + modelName: "claude-3-cloud", expected: false, }, { @@ -144,7 +134,7 @@ func TestTruncateToolOutput(t *testing.T) { { name: "long output cloud model - uses 10k limit", output: string(localLimitOutput), // 20k chars, under 10k token limit - modelName: "gpt-oss:latest-cloud", + modelName: "gpt-4-cloud", host: "", shouldTrim: false, expectedLimit: defaultTokenLimit, @@ -152,7 +142,7 @@ func TestTruncateToolOutput(t *testing.T) { { name: "very long output cloud model - trimmed at 10k", output: string(defaultLimitOutput), - modelName: "gpt-oss:latest-cloud", + modelName: "gpt-4-cloud", host: "", shouldTrim: true, expectedLimit: defaultTokenLimit,