Files
ollama/server/model_recommendations_test.go
2026-05-05 14:40:18 -07:00

588 lines
18 KiB
Go

package server
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"slices"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
)
// TestModelRecommendationsDefaultOrder pins the model names, and their order,
// in the built-in defaultModelRecommendations list.
func TestModelRecommendationsDefaultOrder(t *testing.T) {
	expected := []string{
		"kimi-k2.6:cloud",
		"glm-5.1:cloud",
		"qwen3.5:cloud",
		"minimax-m2.7:cloud",
		"gemma4",
		"qwen3.5",
	}

	actual := modelRecommendationNames(defaultModelRecommendations)
	if slices.Equal(actual, expected) {
		return
	}
	t.Fatalf("recommendations = %v, want %v", actual, expected)
}
// TestModelRecommendationsCacheRefreshAppliesServerSideChanges runs two
// consecutive refreshes against a stubbed transport that serves a different
// payload on the second request. It checks that the cache reflects each
// payload in turn (with surrounding whitespace trimmed from the first one)
// and that the snapshot file on disk holds the most recent list.
func TestModelRecommendationsCacheRefreshAppliesServerSideChanges(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "")

	// The first payload is deliberately padded with spaces; the expectation
	// after the first refresh is the same data without the padding.
	firstPayload := []api.ModelRecommendation{
		{Model: " first-cloud:cloud ", Description: " first ", ContextLength: 2048, MaxOutputTokens: 512},
		{Model: " first-local ", Description: " first local ", VRAMBytes: 3 * format.GigaByte},
	}
	secondPayload := []api.ModelRecommendation{
		{Model: "second-cloud:cloud", Description: "second", ContextLength: 4096, MaxOutputTokens: 1024},
		{Model: "second-local", Description: "second local", VRAMBytes: 6 * format.GigaByte},
	}

	var requests int
	respond := func(req *http.Request) (*http.Response, error) {
		if req.Method != http.MethodGet {
			t.Fatalf("method = %q, want GET", req.Method)
		}
		if target := req.URL.String(); target != modelRecommendationsURL {
			t.Fatalf("url = %q, want %q", target, modelRecommendationsURL)
		}

		requests++
		recs := firstPayload
		if requests > 1 {
			recs = secondPayload
		}
		body, err := json.Marshal(api.ModelRecommendationsResponse{Recommendations: recs})
		if err != nil {
			t.Fatalf("marshal payload failed: %v", err)
		}
		return jsonHTTPResponse(http.StatusOK, string(body)), nil
	}

	cache := newModelRecommendationsCache()
	cache.client = &http.Client{Transport: roundTripFunc(respond)}

	if err := cache.refresh(context.Background()); err != nil {
		t.Fatalf("first refresh failed: %v", err)
	}
	trimmedFirst := []api.ModelRecommendation{
		{Model: "first-cloud:cloud", Description: "first", ContextLength: 2048, MaxOutputTokens: 512},
		{Model: "first-local", Description: "first local", VRAMBytes: 3 * format.GigaByte},
	}
	if got := cache.Get(); !slices.Equal(got, trimmedFirst) {
		t.Fatalf("after first refresh recommendations = %#v, want %#v", got, trimmedFirst)
	}

	if err := cache.refresh(context.Background()); err != nil {
		t.Fatalf("second refresh failed: %v", err)
	}
	if got := cache.Get(); !slices.Equal(got, secondPayload) {
		t.Fatalf("after second refresh recommendations = %#v, want %#v", got, secondPayload)
	}

	// The on-disk snapshot must match what the second refresh produced.
	snapshotPath, err := modelRecommendationsSnapshotPath()
	if err != nil {
		t.Fatalf("snapshot path failed: %v", err)
	}
	raw, err := os.ReadFile(snapshotPath)
	if err != nil {
		t.Fatalf("read snapshot failed: %v", err)
	}
	var persisted api.ModelRecommendationsResponse
	if err := json.Unmarshal(raw, &persisted); err != nil {
		t.Fatalf("unmarshal snapshot failed: %v", err)
	}
	if !slices.Equal(persisted.Recommendations, secondPayload) {
		t.Fatalf("snapshot recommendations = %#v, want %#v", persisted.Recommendations, secondPayload)
	}
}
// TestModelRecommendationsCacheRefreshErrorCasesPreserveCurrentData walks a
// table of failing refresh scenarios. For each one it checks that refresh
// returns an error containing the expected text, that the previously cached
// recommendations are left untouched, and that no snapshot file was written.
func TestModelRecommendationsCacheRefreshErrorCasesPreserveCurrentData(t *testing.T) {
	// canned builds a transport that always answers with the given status
	// code and body.
	canned := func(status int, body string) roundTripFunc {
		return func(*http.Request) (*http.Response, error) {
			return jsonHTTPResponse(status, body), nil
		}
	}

	type refreshFailure struct {
		name      string
		transport roundTripFunc
		errSubstr string
	}
	scenarios := []refreshFailure{
		{
			name: "transport error",
			transport: func(*http.Request) (*http.Response, error) {
				return nil, errors.New("network down")
			},
			errSubstr: "network down",
		},
		{
			name:      "remote status error",
			transport: canned(http.StatusInternalServerError, "upstream broken"),
			errSubstr: "status 500: upstream broken",
		},
		{
			name:      "invalid json payload",
			transport: canned(http.StatusOK, "{"),
			errSubstr: "unexpected EOF",
		},
		{
			name:      "duplicate recommendations",
			transport: canned(http.StatusOK, `{"recommendations":[{"model":"dup","description":"a"},{"model":"dup","description":"b"}]}`),
			errSubstr: `duplicate recommendation "dup"`,
		},
		{
			name:      "empty recommendations",
			transport: canned(http.StatusOK, `{"recommendations":[]}`),
			errSubstr: "empty recommendations",
		},
		{
			name:      "only invalid cloud recommendations",
			transport: canned(http.StatusOK, `{"recommendations":[{"model":"bad:cloud","description":"missing limits"}]}`),
			errSubstr: "no valid recommendations",
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			setupModelRecommendationsTestEnv(t, "")

			seeded := []api.ModelRecommendation{{Model: "stable-local", Description: "stable desc", VRAMBytes: 2 * format.GigaByte}}
			cache := newModelRecommendationsCache()
			cache.set(seeded)
			cache.client = &http.Client{Transport: sc.transport}

			err := cache.refresh(context.Background())
			switch {
			case err == nil:
				t.Fatalf("refresh returned nil error")
			case !strings.Contains(err.Error(), sc.errSubstr):
				t.Fatalf("error = %q, want substring %q", err.Error(), sc.errSubstr)
			}

			if current := cache.Get(); !slices.Equal(current, seeded) {
				t.Fatalf("recommendations changed on error: got %#v, want %#v", current, seeded)
			}

			// A failed refresh must not leave a snapshot behind.
			snapshotPath, pathErr := modelRecommendationsSnapshotPath()
			if pathErr != nil {
				t.Fatalf("snapshot path failed: %v", pathErr)
			}
			_, statErr := os.Stat(snapshotPath)
			if !errors.Is(statErr, os.ErrNotExist) {
				t.Fatalf("snapshot file should not be written on error, stat err = %v", statErr)
			}
		})
	}
}
// TestModelRecommendationsCacheRefreshNoCloudShortCircuits checks that with
// OLLAMA_NO_CLOUD set to "1" refresh returns errModelRecommendationsNoCloud
// and never invokes the HTTP transport.
func TestModelRecommendationsCacheRefreshNoCloudShortCircuits(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "1")

	var contacted bool
	transport := roundTripFunc(func(*http.Request) (*http.Response, error) {
		contacted = true
		return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"should-not-be-used","description":"n/a"}]}`), nil
	})

	cache := newModelRecommendationsCache()
	cache.client = &http.Client{Transport: transport}

	if err := cache.refresh(context.Background()); !errors.Is(err, errModelRecommendationsNoCloud) {
		t.Fatalf("refresh error = %v, want %v", err, errModelRecommendationsNoCloud)
	}
	if contacted {
		t.Fatalf("remote endpoint should not be called when cloud is disabled")
	}
}
// TestModelRecommendationsSnapshotPersistAndLoad writes a snapshot through
// one cache instance and checks that a second instance, pre-seeded with
// different data, ends up holding the persisted list after loadSnapshot.
func TestModelRecommendationsSnapshotPersistAndLoad(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "")

	persisted := []api.ModelRecommendation{
		{Model: "persist-cloud:cloud", Description: "persisted", ContextLength: 8192, MaxOutputTokens: 2048},
		{Model: "persist-local", Description: "persisted local", VRAMBytes: 5 * format.GigaByte},
	}

	source := newModelRecommendationsCache()
	if err := source.persistSnapshot(persisted); err != nil {
		t.Fatalf("persistSnapshot failed: %v", err)
	}

	// Seed the reader with unrelated data so a successful load is observable.
	reader := newModelRecommendationsCache()
	reader.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})
	reader.loadSnapshot()

	loaded := reader.Get()
	if !slices.Equal(loaded, persisted) {
		t.Fatalf("loaded recommendations = %#v, want %#v", loaded, persisted)
	}
}
// TestModelRecommendationsLoadSnapshotInvalidDoesNotOverwrite places a
// malformed snapshot file on disk and checks that loadSnapshot leaves the
// cache's existing recommendations in place.
func TestModelRecommendationsLoadSnapshotInvalidDoesNotOverwrite(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "")

	snapshotPath, err := modelRecommendationsSnapshotPath()
	if err != nil {
		t.Fatalf("snapshot path failed: %v", err)
	}
	dir := filepath.Dir(snapshotPath)
	if err = os.MkdirAll(dir, 0o755); err != nil {
		t.Fatalf("mkdir failed: %v", err)
	}
	corrupt := []byte("{invalid")
	if err = os.WriteFile(snapshotPath, corrupt, 0o644); err != nil {
		t.Fatalf("write invalid snapshot failed: %v", err)
	}

	current := []api.ModelRecommendation{{Model: "existing", Description: "existing description"}}
	cache := newModelRecommendationsCache()
	cache.set(current)
	cache.loadSnapshot()

	after := cache.Get()
	if !slices.Equal(after, current) {
		t.Fatalf("recommendations overwritten by invalid snapshot: got %#v, want %#v", after, current)
	}
}
// TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries feeds
// validateModelRecommendations a mix of padded entries and a ":cloud" entry
// with no limits set. It expects the padded entries back with whitespace
// removed and the limit-less cloud entry omitted.
func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing.T) {
	raw := []api.ModelRecommendation{
		{Model: " good-cloud:cloud ", Description: " good cloud ", ContextLength: 1024, MaxOutputTokens: 256},
		{Model: "bad-cloud:cloud", Description: "missing limits"},
		{Model: " good-local ", Description: " good local ", VRAMBytes: 2 * format.GigaByte},
	}
	expected := []api.ModelRecommendation{
		{Model: "good-cloud:cloud", Description: "good cloud", ContextLength: 1024, MaxOutputTokens: 256},
		{Model: "good-local", Description: "good local", VRAMBytes: 2 * format.GigaByte},
	}

	validated, err := validateModelRecommendations(raw)
	if err != nil {
		t.Fatalf("validateModelRecommendations failed: %v", err)
	}
	if slices.Equal(validated, expected) {
		return
	}
	t.Fatalf("validated recommendations = %#v, want %#v", validated, expected)
}
// TestModelRecommendationsHandlerReturnsDefaults calls the handler on a
// zero-value Server and expects a 200 response listing the same model names
// as defaultModelRecommendations.
func TestModelRecommendationsHandlerReturnsDefaults(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)

	server := &Server{}
	server.ModelRecommendationsExperimentalHandler(ginCtx)

	if status := recorder.Code; status != http.StatusOK {
		t.Fatalf("status = %d, want %d", status, http.StatusOK)
	}

	names := decodeRecommendationNames(t, recorder)
	defaults := modelRecommendationNames(defaultModelRecommendations)
	if !slices.Equal(names, defaults) {
		t.Fatalf("models = %v, want %v", names, defaults)
	}
}
// TestModelRecommendationsHandlerUsesCache wires a pre-populated cache into
// the Server and expects the handler to answer with that cache's contents.
// The environment is set up with OLLAMA_NO_CLOUD="1".
func TestModelRecommendationsHandlerUsesCache(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setupModelRecommendationsTestEnv(t, "1")

	cache := newModelRecommendationsCache()
	cache.set([]api.ModelRecommendation{{Model: "test-model", Description: "test description"}})
	server := &Server{modelCaches: &modelCaches{recommendations: cache}}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)
	server.ModelRecommendationsExperimentalHandler(ginCtx)

	if status := recorder.Code; status != http.StatusOK {
		t.Fatalf("status = %d, want %d", status, http.StatusOK)
	}
	names := decodeRecommendationNames(t, recorder)
	if !slices.Equal(names, []string{"test-model"}) {
		t.Fatalf("models = %v, want %v", names, []string{"test-model"})
	}

	// Wait until the cache's refreshing flag is clear before the test ends.
	waitForCacheIdle(t, cache)
}
// TestModelRecommendationsRouteRegistration builds the router through
// GenerateRoutes and checks that the model-recommendations path answers GET
// with the cached list and rejects POST with 405 Method Not Allowed.
func TestModelRecommendationsRouteRegistration(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setupModelRecommendationsTestEnv(t, "1")

	cache := newModelRecommendationsCache()
	cache.set([]api.ModelRecommendation{{Model: "route-model", Description: "route description"}})

	server := &Server{modelCaches: &modelCaches{recommendations: cache}}
	router, err := server.GenerateRoutes(nil)
	if err != nil {
		t.Fatalf("GenerateRoutes failed: %v", err)
	}

	// serve issues a body-less request with the given method against the
	// recommendations route and returns the recorded response.
	serve := func(method string) *httptest.ResponseRecorder {
		request := httptest.NewRequest(method, "/api/experimental/model-recommendations", nil)
		recorder := httptest.NewRecorder()
		router.ServeHTTP(recorder, request)
		return recorder
	}

	getResp := serve(http.MethodGet)
	if getResp.Code != http.StatusOK {
		t.Fatalf("GET status = %d, want %d", getResp.Code, http.StatusOK)
	}
	names := decodeRecommendationNames(t, getResp)
	if !slices.Equal(names, []string{"route-model"}) {
		t.Fatalf("GET models = %v, want %v", names, []string{"route-model"})
	}

	postResp := serve(http.MethodPost)
	if postResp.Code != http.StatusMethodNotAllowed {
		t.Fatalf("POST status = %d, want %d", postResp.Code, http.StatusMethodNotAllowed)
	}

	waitForCacheIdle(t, cache)
}
// TestModelRecommendationsGetSWRTriggersRefreshOnRead checks the
// stale-while-revalidate read path: GetSWR must hand back the currently
// cached list right away, the stubbed transport must then be hit, and the
// cache must eventually hold the list served by that transport.
func TestModelRecommendationsGetSWRTriggersRefreshOnRead(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "")

	stale := []api.ModelRecommendation{{Model: "old", Description: "old"}}
	fresh := []api.ModelRecommendation{{Model: "new-cloud:cloud", Description: "new", ContextLength: 1024, MaxOutputTokens: 256}}

	cache := newModelRecommendationsCache()
	cache.set(stale)

	// fetched is closed when the transport call returns.
	fetched := make(chan struct{})
	cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
		defer close(fetched)
		return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"new-cloud:cloud","description":"new","context_length":1024,"max_output_tokens":256}]}`), nil
	})}

	if served := cache.GetSWR(context.Background()); !slices.Equal(served, stale) {
		t.Fatalf("GetSWR should return current cache immediately: got %#v, want %#v", served, stale)
	}

	select {
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for async refresh")
	case <-fetched:
	}

	updated := func() bool { return slices.Equal(cache.Get(), fresh) }
	waitForCondition(t, 2*time.Second, updated)
	waitForCacheIdle(t, cache)
}
// TestModelRecommendationsGetSWRSkipsWhenRefreshAlreadyInFlight verifies that
// GetSWR calls made while an earlier refresh is still blocked inside the HTTP
// transport do not start additional requests: the transport must have been
// invoked exactly once by the time the first request is released.
func TestModelRecommendationsGetSWRSkipsWhenRefreshAlreadyInFlight(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "")
	cache := newModelRecommendationsCache()
	cache.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})

	started := make(chan struct{})
	release := make(chan struct{})

	// unblock closes release at most once. It is deferred so that a failing
	// assertion below (t.Fatal exits through runtime.Goexit, which still runs
	// deferred calls) cannot leave transport calls parked on <-release for
	// the remainder of the test binary. Only the test goroutine touches
	// released, so no synchronization is needed.
	released := false
	unblock := func() {
		if released {
			return
		}
		released = true
		close(release)
	}
	defer unblock()

	var calls atomic.Int32
	cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
		n := calls.Add(1)
		if n == 1 {
			close(started)
		}
		// Hold every request open until the test releases it.
		<-release
		return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"updated","description":"ok"}]}`), nil
	})}

	cache.GetSWR(context.Background())
	select {
	case <-started:
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for first refresh call")
	}

	for range 5 {
		cache.GetSWR(context.Background())
	}
	// Give any wrongly started extra refresh a moment to reach the transport.
	time.Sleep(50 * time.Millisecond)
	if got := calls.Load(); got != 1 {
		t.Fatalf("calls during in-flight refresh = %d, want 1", got)
	}

	unblock()
	waitForCacheIdle(t, cache)
}
// TestModelRecommendationsGetSWRThrottlesRefreshAfterCompletion shortens the
// read-refresh cooldown to 100ms, keeps the first refresh blocked for twice
// that long, and then checks that a GetSWR issued right after the refresh
// finishes does not reach the transport a second time.
func TestModelRecommendationsGetSWRThrottlesRefreshAfterCompletion(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "")
	withModelRecommendationsReadRefreshCooldown(t, 100*time.Millisecond)

	cache := newModelRecommendationsCache()
	cache.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})

	var (
		fetches      atomic.Int32
		firstStarted = make(chan struct{})
		unblockFirst = make(chan struct{})
	)
	cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
		// Only the very first request is held open; later ones return at once.
		if fetches.Add(1) == 1 {
			close(firstStarted)
			<-unblockFirst
		}
		return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"updated","description":"ok"}]}`), nil
	})}

	cache.GetSWR(context.Background())
	select {
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for first refresh call")
	case <-firstStarted:
	}

	// Keep the refresh in flight for longer than the cooldown before letting
	// it finish. NOTE(review): presumably this distinguishes a cooldown
	// measured from refresh completion from one measured from refresh start.
	time.Sleep(2 * modelRecommendationsReadRefreshCooldown)
	close(unblockFirst)
	waitForCacheIdle(t, cache)

	cache.GetSWR(context.Background())
	time.Sleep(25 * time.Millisecond)
	if n := fetches.Load(); n != 1 {
		t.Fatalf("calls during read refresh cooldown = %d, want 1", n)
	}
}
// TestModelRecommendationsGetSWRRetriesAfterReadRefreshCooldown makes the
// first refresh fail at the transport and then checks three things: the cache
// keeps its old contents, a GetSWR issued shortly afterwards does not hit the
// transport again, and repeated GetSWR calls eventually trigger a second
// request whose payload ends up in the cache.
func TestModelRecommendationsGetSWRRetriesAfterReadRefreshCooldown(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "")
	withModelRecommendationsReadRefreshCooldown(t, 100*time.Millisecond)

	stale := []api.ModelRecommendation{{Model: "old", Description: "old"}}
	cache := newModelRecommendationsCache()
	cache.set(stale)

	var fetches atomic.Int32
	cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
		// Fail the first attempt only; every later attempt succeeds.
		if fetches.Add(1) == 1 {
			return nil, errors.New("temporary upstream failure")
		}
		return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"recovered","description":"ok"}]}`), nil
	})}

	cache.GetSWR(context.Background())
	firstAttempted := func() bool { return fetches.Load() >= 1 }
	waitForCondition(t, 2*time.Second, firstAttempted)
	waitForCacheIdle(t, cache)
	if !slices.Equal(cache.Get(), stale) {
		t.Fatalf("cache should remain unchanged after failed refresh, got %#v", cache.Get())
	}

	// A read shortly after the failure must not reach the transport.
	cache.GetSWR(context.Background())
	time.Sleep(25 * time.Millisecond)
	if n := fetches.Load(); n != 1 {
		t.Fatalf("calls during read refresh cooldown after failure = %d, want 1", n)
	}

	// Keep reading until a second request goes out.
	retried := func() bool {
		cache.GetSWR(context.Background())
		return fetches.Load() >= 2
	}
	waitForCondition(t, 2*time.Second, retried)

	recovered := func() bool {
		return slices.Equal(cache.Get(), []api.ModelRecommendation{{Model: "recovered", Description: "ok"}})
	}
	waitForCondition(t, 2*time.Second, recovered)
	waitForCacheIdle(t, cache)
}
// roundTripFunc adapts a plain function to the http.RoundTripper interface so
// tests can stub an http.Client's transport inline.
type roundTripFunc func(*http.Request) (*http.Response, error)

// Compile-time check that roundTripFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripFunc(nil)

// RoundTrip invokes the wrapped function with r and returns its results
// unchanged.
func (fn roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	resp, err := fn(r)
	return resp, err
}
// jsonHTTPResponse builds a minimal *http.Response for stubbed transports:
// it carries the given status code, an empty non-nil header map, and a body
// that yields the supplied string.
func jsonHTTPResponse(statusCode int, body string) *http.Response {
	resp := new(http.Response)
	resp.StatusCode = statusCode
	resp.Header = http.Header{}
	resp.Body = io.NopCloser(strings.NewReader(body))
	return resp
}
// setupModelRecommendationsTestEnv points the home-directory environment
// variables (HOME, USERPROFILE, HOMEDRIVE, HOMEPATH) at a fresh temporary
// directory, sets OLLAMA_NO_CLOUD, and reloads the server config. The config
// is reloaded once more during test cleanup.
//
// noCloudEnv is the value for OLLAMA_NO_CLOUD; an empty string is replaced
// with "false".
func setupModelRecommendationsTestEnv(t *testing.T, noCloudEnv string) {
	t.Helper()

	home := t.TempDir()
	volume := filepath.VolumeName(home)
	homeVars := [][2]string{
		{"HOME", home},
		{"USERPROFILE", home},
		{"HOMEDRIVE", volume},
		{"HOMEPATH", strings.TrimPrefix(home, volume)},
	}
	for _, kv := range homeVars {
		t.Setenv(kv[0], kv[1])
	}

	// Use explicit false rather than empty to avoid platform/env ambiguity.
	noCloud := noCloudEnv
	if noCloud == "" {
		noCloud = "false"
	}
	t.Setenv("OLLAMA_NO_CLOUD", noCloud)

	envconfig.ReloadServerConfig()
	t.Cleanup(envconfig.ReloadServerConfig)
}
// withModelRecommendationsReadRefreshCooldown overrides the package-level
// modelRecommendationsReadRefreshCooldown with d for the duration of the
// test and restores the previous value during cleanup.
func withModelRecommendationsReadRefreshCooldown(t *testing.T, d time.Duration) {
	t.Helper()

	previous := modelRecommendationsReadRefreshCooldown
	t.Cleanup(func() { modelRecommendationsReadRefreshCooldown = previous })
	modelRecommendationsReadRefreshCooldown = d
}
func waitForCondition(t *testing.T, timeout time.Duration, cond func() bool) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if cond() {
return
}
time.Sleep(10 * time.Millisecond)
}
t.Fatal("timed out waiting for condition")
}
// waitForCacheIdle blocks until cache.refreshing, read under the cache's
// read lock, is false. It waits up to two seconds via waitForCondition.
func waitForCacheIdle(t *testing.T, cache *modelRecommendationsCache) {
	t.Helper()

	idle := func() bool {
		cache.mu.RLock()
		defer cache.mu.RUnlock()
		return !cache.refreshing
	}
	waitForCondition(t, 2*time.Second, idle)
}
// decodeRecommendationNames decodes the recorded response body as JSON of the
// form {"recommendations":[{"model":...}, ...]} and returns the model names
// in response order. The test fails if the body cannot be decoded.
func decodeRecommendationNames(t *testing.T, w *httptest.ResponseRecorder) []string {
	t.Helper()

	type entry struct {
		Model string `json:"model"`
	}
	var payload struct {
		Recommendations []entry `json:"recommendations"`
	}
	if err := json.NewDecoder(w.Body).Decode(&payload); err != nil {
		t.Fatalf("decode failed: %v", err)
	}

	models := make([]string, len(payload.Recommendations))
	for i := range payload.Recommendations {
		models[i] = payload.Recommendations[i].Model
	}
	return models
}
// modelRecommendationNames returns the Model field of every recommendation
// in recs, preserving order.
func modelRecommendationNames(recs []api.ModelRecommendation) []string {
	models := make([]string, 0, len(recs))
	for i := range recs {
		models = append(models, recs[i].Model)
	}
	return models
}