diff --git a/app/cmd/app/app.go b/app/cmd/app/app.go index 7e183b8df..904807fd7 100644 --- a/app/cmd/app/app.go +++ b/app/cmd/app/app.go @@ -253,6 +253,8 @@ func main() { done <- osrv.Run(octx) }() + upd := &updater.Updater{Store: st} + uiServer := ui.Server{ Token: token, Restart: func() { @@ -267,6 +269,10 @@ func main() { ToolRegistry: toolRegistry, Dev: devMode, Logger: slog.Default(), + Updater: upd, + UpdateAvailableFunc: func() { + UpdateAvailable("") + }, } srv := &http.Server{ @@ -284,8 +290,13 @@ func main() { slog.Debug("background desktop server done") }() - updater := &updater.Updater{Store: st} - updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable) + upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable) + + // Check for pending updates on startup (show tray notification if update is ready) + if updater.IsUpdatePending() { + slog.Debug("update pending on startup, showing tray notification") + UpdateAvailable("") + } hasCompletedFirstRun, err := st.HasCompletedFirstRun() if err != nil { @@ -348,6 +359,18 @@ func startHiddenTasks() { // CLI triggered app startup use-case slog.Info("deferring pending update for fast startup") } else { + // Check if auto-update is enabled before automatically upgrading + st := &store.Store{} + settings, err := st.Settings() + if err != nil { + slog.Warn("failed to load settings for upgrade check", "error", err) + } else if !settings.AutoUpdateEnabled { + slog.Info("auto-update disabled, skipping automatic upgrade at startup") + // Still show tray notification so user knows update is ready + UpdateAvailable("") + return + } + if err := updater.DoUpgradeAtStartup(); err != nil { slog.Info("unable to perform upgrade at startup", "error", err) // Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization diff --git a/app/store/database.go b/app/store/database.go index 20f384bb7..82e2f238e 100644 --- a/app/store/database.go +++ b/app/store/database.go @@ -9,12 +9,12 @@ import ( "strings" "time" - sqlite3 "github.com/mattn/go-sqlite3" + _ "github.com/mattn/go-sqlite3" ) // currentSchemaVersion defines the current database schema version. // Increment this when making schema changes that require migrations. -const currentSchemaVersion = 14 +const currentSchemaVersion = 15 // database wraps the SQLite connection. // SQLite handles its own locking for concurrent access: @@ -86,6 +86,7 @@ func (db *database) init() error { think_level TEXT NOT NULL DEFAULT '', cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0, remote TEXT NOT NULL DEFAULT '', -- deprecated + auto_update_enabled BOOLEAN NOT NULL DEFAULT 1, schema_version INTEGER NOT NULL DEFAULT %d ); @@ -257,6 +258,12 @@ func (db *database) migrate() error { return fmt.Errorf("migrate v13 to v14: %w", err) } version = 14 + case 14: + // add auto_update_enabled column to settings table + if err := db.migrateV14ToV15(); err != nil { + return fmt.Errorf("migrate v14 to v15: %w", err) + } + version = 15 default: // If we have a version we don't recognize, just set it to current // This might happen during development @@ -496,6 +503,21 @@ func (db *database) migrateV13ToV14() error { return nil } +// migrateV14ToV15 adds the auto_update_enabled column to the settings table +func (db *database) migrateV14ToV15() error { + _, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`) + if err != nil && !duplicateColumnError(err) { + return fmt.Errorf("add auto_update_enabled column: %w", err) + } + + _, err = db.conn.Exec(`UPDATE settings SET schema_version = 15`) + if err != nil { + return fmt.Errorf("update schema version: %w", err) + } + + return nil +} + // cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug func (db *database) cleanupOrphanedData() error { _, err := db.conn.Exec(` @@ -526,19 +548,11 @@ func (db *database) cleanupOrphanedData() error { } func duplicateColumnError(err error) bool { - if sqlite3Err, ok := err.(sqlite3.Error); ok { - return sqlite3Err.Code == sqlite3.ErrError && - strings.Contains(sqlite3Err.Error(), "duplicate column name") - } - return false + return err != nil && strings.Contains(err.Error(), "duplicate column name") } func columnNotExists(err error) bool { - if sqlite3Err, ok := err.(sqlite3.Error); ok { - return sqlite3Err.Code == sqlite3.ErrError && - strings.Contains(sqlite3Err.Error(), "no such column") - } - return false + return err != nil && strings.Contains(err.Error(), "no such column") } func (db *database) getAllChats() ([]Chat, error) { @@ -1152,9 +1166,9 @@ func (db *database) getSettings() (Settings, error) { var s Settings err := db.conn.QueryRow(` - SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level + SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled FROM settings - `).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel) + `).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled) if err != nil { return Settings{}, fmt.Errorf("get settings: %w", err) } @@ -1164,9 +1178,9 @@ func (db *database) getSettings() (Settings, error) { func (db *database) setSettings(s Settings) error { _, err := db.conn.Exec(` - UPDATE settings - SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ? - `, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel) + UPDATE settings + SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ? + `, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled) if err != nil { return fmt.Errorf("set settings: %w", err) } diff --git a/app/store/store.go b/app/store/store.go index 171ead8e2..a4c03b239 100644 --- a/app/store/store.go +++ b/app/store/store.go @@ -166,6 +166,9 @@ type Settings struct { // SidebarOpen indicates if the chat sidebar is open SidebarOpen bool + + // AutoUpdateEnabled indicates if automatic updates should be downloaded + AutoUpdateEnabled bool } type Store struct { diff --git a/app/ui/app/codegen/gotypes.gen.ts b/app/ui/app/codegen/gotypes.gen.ts index 61140bf7f..51928e685 100644 --- a/app/ui/app/codegen/gotypes.gen.ts +++ b/app/ui/app/codegen/gotypes.gen.ts @@ -414,6 +414,7 @@ export class Settings { ThinkLevel: string; SelectedModel: string; SidebarOpen: boolean; + AutoUpdateEnabled: boolean; constructor(source: any = {}) { if ('string' === typeof source) source = JSON.parse(source); @@ -431,6 +432,7 @@ export class Settings { this.ThinkLevel = source["ThinkLevel"]; this.SelectedModel = source["SelectedModel"]; this.SidebarOpen = source["SidebarOpen"]; + this.AutoUpdateEnabled = source["AutoUpdateEnabled"]; } } export class SettingsResponse { diff --git a/app/ui/app/src/components/Settings.tsx b/app/ui/app/src/components/Settings.tsx index ef0bf4c53..77b9fe632 100644 --- a/app/ui/app/src/components/Settings.tsx +++ b/app/ui/app/src/components/Settings.tsx @@ -15,6 +15,7 @@ import { XMarkIcon, CogIcon, ArrowLeftIcon, + ArrowDownTrayIcon, } from "@heroicons/react/20/solid"; import { Settings as SettingsType } from "@/gotypes"; import { useNavigate } from "@tanstack/react-router"; @@ -440,6 +441,29 @@ export default function Settings() { + {/* Auto Update */} + +
+
+ +
+ + + {settings.AutoUpdateEnabled + ? "Automatically download updates when available." + : "Updates will not be downloaded automatically."} + +
+
+
+ handleChange("AutoUpdateEnabled", checked)} + /> +
+
+
+ {/* Expose Ollama */}
diff --git a/app/ui/ui.go b/app/ui/ui.go index 3f2e73f37..f720fe05a 100644 --- a/app/ui/ui.go +++ b/app/ui/ui.go @@ -28,6 +28,7 @@ import ( "github.com/ollama/ollama/app/tools" "github.com/ollama/ollama/app/types/not" "github.com/ollama/ollama/app/ui/responses" + "github.com/ollama/ollama/app/updater" "github.com/ollama/ollama/app/version" ollamaAuth "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" @@ -106,6 +107,10 @@ type Server struct { // Dev is true if the server is running in development mode Dev bool + + // Updater for checking and downloading updates + Updater *updater.Updater + UpdateAvailableFunc func() } func (s *Server) log() *slog.Logger { @@ -1447,6 +1452,24 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error { return fmt.Errorf("failed to save settings: %w", err) } + // Handle auto-update toggle changes + if old.AutoUpdateEnabled != settings.AutoUpdateEnabled { + if !settings.AutoUpdateEnabled { + // Auto-update disabled: cancel any ongoing download + if s.Updater != nil { + s.Updater.CancelOngoingDownload() + } + } else { + // Auto-update re-enabled: show notification if update is already staged, or trigger immediate check + if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil { + s.UpdateAvailableFunc() + } else if s.Updater != nil { + // Trigger the background checker to run immediately + s.Updater.TriggerImmediateCheck() + } + } + } + if old.ContextLength != settings.ContextLength || old.Models != settings.Models || old.Expose != settings.Expose { diff --git a/app/ui/ui_test.go b/app/ui/ui_test.go index 6f29732b2..270f3145f 100644 --- a/app/ui/ui_test.go +++ b/app/ui/ui_test.go @@ -4,6 +4,7 @@ package ui import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -11,9 +12,11 @@ import ( "path/filepath" "runtime" "strings" + "sync/atomic" "testing" "github.com/ollama/ollama/app/store" + "github.com/ollama/ollama/app/updater" ) func TestHandlePostApiSettings(t *testing.T) { @@ -629,3 +632,183 @@ func TestWebSearchToolRegistration(t *testing.T) { }) } } + +func TestSettingsToggleAutoUpdateOff_CancelsDownload(t *testing.T) { + testStore := &store.Store{ + DBPath: filepath.Join(t.TempDir(), "db.sqlite"), + } + defer testStore.Close() + + // Start with auto-update enabled + settings, err := testStore.Settings() + if err != nil { + t.Fatal(err) + } + settings.AutoUpdateEnabled = true + if err := testStore.SetSettings(settings); err != nil { + t.Fatal(err) + } + + upd := &updater.Updater{Store: &store.Store{ + DBPath: filepath.Join(t.TempDir(), "db2.sqlite"), + }} + defer upd.Store.Close() + + // We can't easily mock CancelOngoingDownload, but we can verify + // the full settings handler flow works without error + server := &Server{ + Store: testStore, + Restart: func() {}, + Updater: upd, + } + + // Disable auto-update via settings API + settings.AutoUpdateEnabled = false + body, err := json.Marshal(settings) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + if err := server.settings(rr, req); err != nil { + t.Fatalf("settings() error = %v", err) + } + if rr.Code != http.StatusOK { + t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK) + } + + // Verify settings were saved with auto-update disabled + saved, err := testStore.Settings() + if err != nil { + t.Fatal(err) + } + if saved.AutoUpdateEnabled { + t.Fatal("expected AutoUpdateEnabled to be false after toggle off") + } +} + +func TestSettingsToggleAutoUpdateOn_WithPendingUpdate_ShowsNotification(t *testing.T) { + testStore := &store.Store{ + DBPath: filepath.Join(t.TempDir(), "db.sqlite"), + } + defer testStore.Close() + + // Start with auto-update disabled + settings, err := testStore.Settings() + if err != nil { + t.Fatal(err) + } + settings.AutoUpdateEnabled = false + if err := testStore.SetSettings(settings); err != nil { + t.Fatal(err) + } + + // Simulate that an update was previously downloaded + oldVal := updater.UpdateDownloaded + updater.UpdateDownloaded = true + defer func() { updater.UpdateDownloaded = oldVal }() + + var notificationCalled atomic.Bool + server := &Server{ + Store: testStore, + Restart: func() {}, + UpdateAvailableFunc: func() { + notificationCalled.Store(true) + }, + } + + // Re-enable auto-update via settings API + settings.AutoUpdateEnabled = true + body, err := json.Marshal(settings) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + if err := server.settings(rr, req); err != nil { + t.Fatalf("settings() error = %v", err) + } + if rr.Code != http.StatusOK { + t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK) + } + + if !notificationCalled.Load() { + t.Fatal("expected UpdateAvailableFunc to be called when re-enabling with a downloaded update") + } +} + +func TestSettingsToggleAutoUpdateOn_NoPendingUpdate_TriggersCheck(t *testing.T) { + testStore := &store.Store{ + DBPath: filepath.Join(t.TempDir(), "db.sqlite"), + } + defer testStore.Close() + + // Start with auto-update disabled + settings, err := testStore.Settings() + if err != nil { + t.Fatal(err) + } + settings.AutoUpdateEnabled = false + if err := testStore.SetSettings(settings); err != nil { + t.Fatal(err) + } + + // Ensure no pending update - clear both the downloaded flag and the stage dir + oldVal := updater.UpdateDownloaded + updater.UpdateDownloaded = false + defer func() { updater.UpdateDownloaded = oldVal }() + + oldStageDir := updater.UpdateStageDir + updater.UpdateStageDir = t.TempDir() // empty dir means IsUpdatePending() returns false + defer func() { updater.UpdateStageDir = oldStageDir }() + + upd := &updater.Updater{Store: &store.Store{ + DBPath: filepath.Join(t.TempDir(), "db2.sqlite"), + }} + defer upd.Store.Close() + + // Initialize the checkNow channel by starting (and immediately stopping) the checker + // so TriggerImmediateCheck doesn't panic on nil channel + ctx, cancel := context.WithCancel(t.Context()) + upd.StartBackgroundUpdaterChecker(ctx, func(string) error { return nil }) + defer cancel() + + var notificationCalled atomic.Bool + server := &Server{ + Store: testStore, + Restart: func() {}, + Updater: upd, + UpdateAvailableFunc: func() { + notificationCalled.Store(true) + }, + } + + // Re-enable auto-update via settings API + settings.AutoUpdateEnabled = true + body, err := json.Marshal(settings) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + if err := server.settings(rr, req); err != nil { + t.Fatalf("settings() error = %v", err) + } + if rr.Code != http.StatusOK { + t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK) + } + + // UpdateAvailableFunc should NOT be called since there's no pending update + if notificationCalled.Load() { + t.Fatal("UpdateAvailableFunc should not be called when there is no pending update") + } +} diff --git a/app/updater/updater.go b/app/updater/updater.go index 473ecf466..29b6dbd3c 100644 --- a/app/updater/updater.go +++ b/app/updater/updater.go @@ -19,6 +19,7 @@ import ( "runtime" "strconv" "strings" + "sync" "time" "github.com/ollama/ollama/app/store" @@ -58,7 +59,8 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) { query := requestURL.Query() query.Add("os", runtime.GOOS) query.Add("arch", runtime.GOARCH) - query.Add("version", version.Version) + currentVersion := version.Version + query.Add("version", currentVersion) query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10)) // The original macOS app used to use the device ID @@ -131,15 +133,27 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) { } func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error { + // Create a cancellable context for this download + downloadCtx, cancel := context.WithCancel(ctx) + u.cancelDownloadLock.Lock() + u.cancelDownload = cancel + u.cancelDownloadLock.Unlock() + defer func() { + u.cancelDownloadLock.Lock() + u.cancelDownload = nil + u.cancelDownloadLock.Unlock() + cancel() + }() + // Do a head first to check etag info - req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil) + req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil) if err != nil { return err } // In case of slow downloads, continue the update check in the background - bgctx, cancel := context.WithCancel(ctx) - defer cancel() + bgctx, bgcancel := context.WithCancel(downloadCtx) + defer bgcancel() go func() { for { select { @@ -176,6 +190,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo _, err = os.Stat(stageFilename) if err == nil { slog.Info("update already downloaded", "bundle", stageFilename) + UpdateDownloaded = true return nil } @@ -244,33 +259,84 @@ func cleanupOldDownloads(stageDir string) { } type Updater struct { - Store *store.Store + Store *store.Store + cancelDownload context.CancelFunc + cancelDownloadLock sync.Mutex + checkNow chan struct{} +} + +// CancelOngoingDownload cancels any currently running download +func (u *Updater) CancelOngoingDownload() { + u.cancelDownloadLock.Lock() + defer u.cancelDownloadLock.Unlock() + if u.cancelDownload != nil { + slog.Info("cancelling ongoing update download") + u.cancelDownload() + u.cancelDownload = nil + } +} + +// TriggerImmediateCheck signals the background checker to check for updates immediately +func (u *Updater) TriggerImmediateCheck() { + if u.checkNow != nil { + select { + case u.checkNow <- struct{}{}: + default: + // Check already pending, no need to queue another + } + } } func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) { + u.checkNow = make(chan struct{}, 1) go func() { // Don't blast an update message immediately after startup time.Sleep(UpdateCheckInitialDelay) slog.Info("beginning update checker", "interval", UpdateCheckInterval) + ticker := time.NewTicker(UpdateCheckInterval) + defer ticker.Stop() + for { - available, resp := u.checkForUpdate(ctx) - if available { - err := u.DownloadNewRelease(ctx, resp) - if err != nil { - slog.Error(fmt.Sprintf("failed to download new release: %s", err)) - } else { - err = cb(resp.UpdateVersion) - if err != nil { - slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err)) - } - } - } select { case <-ctx.Done(): slog.Debug("stopping background update checker") return - default: - time.Sleep(UpdateCheckInterval) + case <-u.checkNow: + // Immediate check triggered + case <-ticker.C: + // Regular interval check + } + + // Always check for updates + available, resp := u.checkForUpdate(ctx) + if !available { + continue + } + + // Update is available - check if auto-update is enabled for downloading + settings, err := u.Store.Settings() + if err != nil { + slog.Error("failed to load settings", "error", err) + continue + } + + if !settings.AutoUpdateEnabled { + // Auto-update disabled - don't download, just log + slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion) + continue + } + + // Auto-update is enabled - download + err = u.DownloadNewRelease(ctx, resp) + if err != nil { + slog.Error("failed to download new release", "error", err) + continue + } + + // Download successful - show tray notification (regardless of toggle state) + err = cb(resp.UpdateVersion) + if err != nil { + slog.Warn("failed to register update available with tray", "error", err) } } }() diff --git a/app/updater/updater_test.go b/app/updater/updater_test.go index dea820c28..c6857edbe 100644 --- a/app/updater/updater_test.go +++ b/app/updater/updater_test.go @@ -11,6 +11,8 @@ import ( "log/slog" "net/http" "net/http/httptest" + "path/filepath" + "sync/atomic" "testing" "time" @@ -33,7 +35,7 @@ func TestIsNewReleaseAvailable(t *testing.T) { defer server.Close() slog.Debug("server", "url", server.URL) - updater := &Updater{Store: &store.Store{}} + updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}} defer updater.Store.Close() // Ensure database is closed UpdateCheckURLBase = server.URL + "/update.json" updatePresent, resp := updater.checkForUpdate(t.Context()) @@ -84,8 +86,18 @@ func TestBackgoundChecker(t *testing.T) { defer server.Close() UpdateCheckURLBase = server.URL + "/update.json" - updater := &Updater{Store: &store.Store{}} - defer updater.Store.Close() // Ensure database is closed + updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}} + defer updater.Store.Close() + + settings, err := updater.Store.Settings() + if err != nil { + t.Fatal(err) + } + settings.AutoUpdateEnabled = true + if err := updater.Store.SetSettings(settings); err != nil { + t.Fatal(err) + } + updater.StartBackgroundUpdaterChecker(ctx, cb) select { case <-stallTimer.C: @@ -99,3 +111,264 @@ func TestBackgoundChecker(t *testing.T) { } } } + +func TestAutoUpdateDisabledSkipsDownload(t *testing.T) { + UpdateStageDir = t.TempDir() + var downloadAttempted atomic.Bool + done := make(chan struct{}) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + UpdateCheckInitialDelay = 5 * time.Millisecond + UpdateCheckInterval = 5 * time.Millisecond + VerifyDownload = func() error { + return nil + } + + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/update.json" { + w.Write([]byte( + fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`, + server.URL+"/9.9.9/"+Installer))) + } else if r.URL.Path == "/9.9.9/"+Installer { + downloadAttempted.Store(true) + buf := &bytes.Buffer{} + zw := zip.NewWriter(buf) + zw.Close() + io.Copy(w, buf) + } + })) + defer server.Close() + UpdateCheckURLBase = server.URL + "/update.json" + + updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}} + defer updater.Store.Close() + + // Ensure auto-update is disabled + settings, err := updater.Store.Settings() + if err != nil { + t.Fatal(err) + } + settings.AutoUpdateEnabled = false + if err := updater.Store.SetSettings(settings); err != nil { + t.Fatal(err) + } + + cb := func(ver string) error { + t.Fatal("callback should not be called when auto-update is disabled") + return nil + } + + updater.StartBackgroundUpdaterChecker(ctx, cb) + + // Wait enough time for multiple check cycles + time.Sleep(50 * time.Millisecond) + close(done) + + if downloadAttempted.Load() { + t.Fatal("download should not be attempted when auto-update is disabled") + } +} + +func TestAutoUpdateReenabledDownloadsUpdate(t *testing.T) { + UpdateStageDir = t.TempDir() + var downloadAttempted atomic.Bool + callbackCalled := make(chan struct{}, 1) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + UpdateCheckInitialDelay = 5 * time.Millisecond + UpdateCheckInterval = 5 * time.Millisecond + VerifyDownload = func() error { + return nil + } + + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/update.json" { + w.Write([]byte( + fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`, + server.URL+"/9.9.9/"+Installer))) + } else if r.URL.Path == "/9.9.9/"+Installer { + downloadAttempted.Store(true) + buf := &bytes.Buffer{} + zw := zip.NewWriter(buf) + zw.Close() + io.Copy(w, buf) + } + })) + defer server.Close() + UpdateCheckURLBase = server.URL + "/update.json" + + upd := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}} + defer upd.Store.Close() + + // Start with auto-update disabled + settings, err := upd.Store.Settings() + if err != nil { + t.Fatal(err) + } + settings.AutoUpdateEnabled = false + if err := upd.Store.SetSettings(settings); err != nil { + t.Fatal(err) + } + + cb := func(ver string) error { + select { + case callbackCalled <- struct{}{}: + default: + } + return nil + } + + upd.StartBackgroundUpdaterChecker(ctx, cb) + + // Wait for a few cycles with auto-update disabled - no download should happen + time.Sleep(50 * time.Millisecond) + if downloadAttempted.Load() { + t.Fatal("download should not happen while auto-update is disabled") + } + + // Re-enable auto-update + settings.AutoUpdateEnabled = true + if err := upd.Store.SetSettings(settings); err != nil { + t.Fatal(err) + } + + // Wait for the checker to pick it up and download + select { + case <-callbackCalled: + // Success: download happened and callback was called after re-enabling + if !downloadAttempted.Load() { + t.Fatal("expected download to be attempted after re-enabling") + } + case <-time.After(5 * time.Second): + t.Fatal("expected download and callback after re-enabling auto-update") + } +} + +func TestCancelOngoingDownload(t *testing.T) { + UpdateStageDir = t.TempDir() + downloadStarted := make(chan struct{}) + downloadCancelled := make(chan struct{}) + + ctx := t.Context() + VerifyDownload = func() error { + return nil + } + + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/update.json" { + w.Write([]byte( + fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`, + server.URL+"/9.9.9/"+Installer))) + } else if r.URL.Path == "/9.9.9/"+Installer { + if r.Method == http.MethodHead { + w.Header().Set("Content-Length", "1000000") + w.WriteHeader(http.StatusOK) + return + } + // Signal that download has started + close(downloadStarted) + // Wait for cancellation or timeout + select { + case <-r.Context().Done(): + close(downloadCancelled) + return + case <-time.After(5 * time.Second): + t.Error("download was not cancelled in time") + } + } + })) + defer server.Close() + UpdateCheckURLBase = server.URL + "/update.json" + + updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}} + defer updater.Store.Close() + + _, resp := updater.checkForUpdate(ctx) + + // Start download in goroutine + go func() { + _ = updater.DownloadNewRelease(ctx, resp) + }() + + // Wait for download to start + select { + case <-downloadStarted: + case <-time.After(2 * time.Second): + t.Fatal("download did not start in time") + } + + // Cancel the download + updater.CancelOngoingDownload() + + // Verify cancellation was received + select { + case <-downloadCancelled: + // Success + case <-time.After(2 * time.Second): + t.Fatal("download cancellation was not received by server") + } +} + +func TestTriggerImmediateCheck(t *testing.T) { + UpdateStageDir = t.TempDir() + checkCount := atomic.Int32{} + checkDone := make(chan struct{}, 10) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + // Set a very long interval so only TriggerImmediateCheck causes checks + UpdateCheckInitialDelay = 1 * time.Millisecond + UpdateCheckInterval = 1 * time.Hour + VerifyDownload = func() error { + return nil + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/update.json" { + checkCount.Add(1) + select { + case checkDone <- struct{}{}: + default: + } + // Return no update available + w.WriteHeader(http.StatusNoContent) + } + })) + defer server.Close() + UpdateCheckURLBase = server.URL + "/update.json" + + updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}} + defer updater.Store.Close() + + cb := func(ver string) error { + return nil + } + + updater.StartBackgroundUpdaterChecker(ctx, cb) + + // Wait for goroutine to start and pass initial delay + time.Sleep(10 * time.Millisecond) + + // With 1 hour interval, no check should have happened yet + initialCount := checkCount.Load() + + // Trigger immediate check + updater.TriggerImmediateCheck() + + // Wait for the triggered check + select { + case <-checkDone: + case <-time.After(2 * time.Second): + t.Fatal("triggered check did not happen") + } + + finalCount := checkCount.Load() + if finalCount <= initialCount { + t.Fatalf("TriggerImmediateCheck did not cause additional check: initial=%d, final=%d", initialCount, finalCount) + } +} diff --git a/app/wintray/tray.go b/app/wintray/tray.go index 71d4bc767..179fbc1a4 100644 --- a/app/wintray/tray.go +++ b/app/wintray/tray.go @@ -369,25 +369,6 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error { return nil } -// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error { -// const ERROR_SUCCESS syscall.Errno = 0 - -// t.muMenus.RLock() -// menu := uintptr(t.menus[parentId]) -// t.muMenus.RUnlock() -// res, _, err := pRemoveMenu.Call( -// menu, -// uintptr(menuItemId), -// MF_BYCOMMAND, -// ) -// if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS { -// return err -// } -// t.delFromVisibleItems(parentId, menuItemId) - -// return nil -// } - func (t *winTray) showMenu() error { p := point{} boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p))) diff --git a/app/wintray/w32api.go b/app/wintray/w32api.go index 10e255816..861e090c1 100644 --- a/app/wintray/w32api.go +++ b/app/wintray/w32api.go @@ -51,7 +51,6 @@ const ( IMAGE_ICON = 1 // Loads an icon LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file - MF_BYCOMMAND = 0x00000000 MFS_DISABLED = 0x00000003 MFT_SEPARATOR = 0x00000800 MFT_STRING = 0x00000000