diff --git a/cmd/cmd.go b/cmd/cmd.go index abc970f63..00611a3d1 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -367,14 +367,25 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error { return err } else if info.RemoteHost != "" { // Cloud model, no need to load/unload + + isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com") + + // Check if user is signed in for ollama.com cloud models + if isCloud { + if _, err := client.Whoami(cmd.Context()); err != nil { + return err + } + } + if opts.ShowConnect { p.StopAndClear() - if strings.HasPrefix(info.RemoteHost, "https://ollama.com") { + if isCloud { fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel) } else { fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost) } } + return nil } diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 433b2ab1b..7217c3d13 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -3,6 +3,7 @@ package cmd import ( "bytes" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -1659,3 +1660,103 @@ func TestRunOptions_Copy_Independence(t *testing.T) { t.Error("Copy Think should not be affected by original modification") } } + +func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { + tests := []struct { + name string + remoteHost string + whoamiStatus int + whoamiResp any + expectedError string + }{ + { + name: "ollama.com cloud model - user signed in", + remoteHost: "https://ollama.com", + whoamiStatus: http.StatusOK, + whoamiResp: api.UserResponse{Name: "testuser"}, + }, + { + name: "ollama.com cloud model - user not signed in", + remoteHost: "https://ollama.com", + whoamiStatus: http.StatusUnauthorized, + whoamiResp: map[string]string{ + "error": "unauthorized", + "signin_url": "https://ollama.com/signin", + }, + expectedError: "unauthorized", + }, + { + name: "non-ollama.com remote - no auth check", + remoteHost: "https://other-remote.com", + whoamiStatus: http.StatusUnauthorized, // should not be called + whoamiResp: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + whoamiCalled := false + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/show": + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(api.ShowResponse{ + RemoteHost: tt.remoteHost, + RemoteModel: "test-model", + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + case "/api/me": + whoamiCalled = true + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(tt.whoamiStatus) + if tt.whoamiResp != nil { + if err := json.NewEncoder(w).Encode(tt.whoamiResp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } + default: + http.NotFound(w, r) + } + })) + defer mockServer.Close() + + t.Setenv("OLLAMA_HOST", mockServer.URL) + + cmd := &cobra.Command{} + cmd.SetContext(t.Context()) + + opts := &runOptions{ + Model: "test-cloud-model", + ShowConnect: false, + } + + err := loadOrUnloadModel(cmd, opts) + + if strings.HasPrefix(tt.remoteHost, "https://ollama.com") { + if !whoamiCalled { + t.Error("expected whoami to be called for ollama.com cloud model") + } + } else { + if whoamiCalled { + t.Error("whoami should not be called for non-ollama.com remote") + } + } + + if tt.expectedError != "" { + if err == nil { + t.Errorf("expected error containing %q, got nil", tt.expectedError) + } else { + var authErr api.AuthorizationError + if !errors.As(err, &authErr) { + t.Errorf("expected AuthorizationError, got %T: %v", err, err) + } + } + } else { + if err != nil { + t.Errorf("expected no error, got %v", err) + } + } + }) + } +}