mirror of
https://github.com/ollama/ollama.git
synced 2026-03-11 17:34:04 -05:00
app: add upgrade configuration to settings page (#13512)
This commit is contained in:
@@ -253,6 +253,8 @@ func main() {
|
||||
done <- osrv.Run(octx)
|
||||
}()
|
||||
|
||||
upd := &updater.Updater{Store: st}
|
||||
|
||||
uiServer := ui.Server{
|
||||
Token: token,
|
||||
Restart: func() {
|
||||
@@ -267,6 +269,10 @@ func main() {
|
||||
ToolRegistry: toolRegistry,
|
||||
Dev: devMode,
|
||||
Logger: slog.Default(),
|
||||
Updater: upd,
|
||||
UpdateAvailableFunc: func() {
|
||||
UpdateAvailable("")
|
||||
},
|
||||
}
|
||||
|
||||
srv := &http.Server{
|
||||
@@ -284,8 +290,13 @@ func main() {
|
||||
slog.Debug("background desktop server done")
|
||||
}()
|
||||
|
||||
updater := &updater.Updater{Store: st}
|
||||
updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
|
||||
upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
|
||||
|
||||
// Check for pending updates on startup (show tray notification if update is ready)
|
||||
if updater.IsUpdatePending() {
|
||||
slog.Debug("update pending on startup, showing tray notification")
|
||||
UpdateAvailable("")
|
||||
}
|
||||
|
||||
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
|
||||
if err != nil {
|
||||
@@ -348,6 +359,18 @@ func startHiddenTasks() {
|
||||
// CLI triggered app startup use-case
|
||||
slog.Info("deferring pending update for fast startup")
|
||||
} else {
|
||||
// Check if auto-update is enabled before automatically upgrading
|
||||
st := &store.Store{}
|
||||
settings, err := st.Settings()
|
||||
if err != nil {
|
||||
slog.Warn("failed to load settings for upgrade check", "error", err)
|
||||
} else if !settings.AutoUpdateEnabled {
|
||||
slog.Info("auto-update disabled, skipping automatic upgrade at startup")
|
||||
// Still show tray notification so user knows update is ready
|
||||
UpdateAvailable("")
|
||||
return
|
||||
}
|
||||
|
||||
if err := updater.DoUpgradeAtStartup(); err != nil {
|
||||
slog.Info("unable to perform upgrade at startup", "error", err)
|
||||
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization
|
||||
|
||||
@@ -9,12 +9,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// currentSchemaVersion defines the current database schema version.
|
||||
// Increment this when making schema changes that require migrations.
|
||||
const currentSchemaVersion = 14
|
||||
const currentSchemaVersion = 15
|
||||
|
||||
// database wraps the SQLite connection.
|
||||
// SQLite handles its own locking for concurrent access:
|
||||
@@ -86,6 +86,7 @@ func (db *database) init() error {
|
||||
think_level TEXT NOT NULL DEFAULT '',
|
||||
cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
|
||||
remote TEXT NOT NULL DEFAULT '', -- deprecated
|
||||
auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
|
||||
schema_version INTEGER NOT NULL DEFAULT %d
|
||||
);
|
||||
|
||||
@@ -257,6 +258,12 @@ func (db *database) migrate() error {
|
||||
return fmt.Errorf("migrate v13 to v14: %w", err)
|
||||
}
|
||||
version = 14
|
||||
case 14:
|
||||
// add auto_update_enabled column to settings table
|
||||
if err := db.migrateV14ToV15(); err != nil {
|
||||
return fmt.Errorf("migrate v14 to v15: %w", err)
|
||||
}
|
||||
version = 15
|
||||
default:
|
||||
// If we have a version we don't recognize, just set it to current
|
||||
// This might happen during development
|
||||
@@ -496,6 +503,21 @@ func (db *database) migrateV13ToV14() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateV14ToV15 adds the auto_update_enabled column to the settings table
|
||||
func (db *database) migrateV14ToV15() error {
|
||||
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
|
||||
if err != nil && !duplicateColumnError(err) {
|
||||
return fmt.Errorf("add auto_update_enabled column: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 15`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update schema version: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
||||
func (db *database) cleanupOrphanedData() error {
|
||||
_, err := db.conn.Exec(`
|
||||
@@ -526,19 +548,11 @@ func (db *database) cleanupOrphanedData() error {
|
||||
}
|
||||
|
||||
func duplicateColumnError(err error) bool {
|
||||
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
||||
return sqlite3Err.Code == sqlite3.ErrError &&
|
||||
strings.Contains(sqlite3Err.Error(), "duplicate column name")
|
||||
}
|
||||
return false
|
||||
return err != nil && strings.Contains(err.Error(), "duplicate column name")
|
||||
}
|
||||
|
||||
func columnNotExists(err error) bool {
|
||||
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
||||
return sqlite3Err.Code == sqlite3.ErrError &&
|
||||
strings.Contains(sqlite3Err.Error(), "no such column")
|
||||
}
|
||||
return false
|
||||
return err != nil && strings.Contains(err.Error(), "no such column")
|
||||
}
|
||||
|
||||
func (db *database) getAllChats() ([]Chat, error) {
|
||||
@@ -1152,9 +1166,9 @@ func (db *database) getSettings() (Settings, error) {
|
||||
var s Settings
|
||||
|
||||
err := db.conn.QueryRow(`
|
||||
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
|
||||
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
|
||||
FROM settings
|
||||
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
|
||||
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
|
||||
if err != nil {
|
||||
return Settings{}, fmt.Errorf("get settings: %w", err)
|
||||
}
|
||||
@@ -1164,9 +1178,9 @@ func (db *database) getSettings() (Settings, error) {
|
||||
|
||||
func (db *database) setSettings(s Settings) error {
|
||||
_, err := db.conn.Exec(`
|
||||
UPDATE settings
|
||||
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
|
||||
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
|
||||
UPDATE settings
|
||||
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
|
||||
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set settings: %w", err)
|
||||
}
|
||||
|
||||
@@ -166,6 +166,9 @@ type Settings struct {
|
||||
|
||||
// SidebarOpen indicates if the chat sidebar is open
|
||||
SidebarOpen bool
|
||||
|
||||
// AutoUpdateEnabled indicates if automatic updates should be downloaded
|
||||
AutoUpdateEnabled bool
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
|
||||
@@ -414,6 +414,7 @@ export class Settings {
|
||||
ThinkLevel: string;
|
||||
SelectedModel: string;
|
||||
SidebarOpen: boolean;
|
||||
AutoUpdateEnabled: boolean;
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
@@ -431,6 +432,7 @@ export class Settings {
|
||||
this.ThinkLevel = source["ThinkLevel"];
|
||||
this.SelectedModel = source["SelectedModel"];
|
||||
this.SidebarOpen = source["SidebarOpen"];
|
||||
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
|
||||
}
|
||||
}
|
||||
export class SettingsResponse {
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
XMarkIcon,
|
||||
CogIcon,
|
||||
ArrowLeftIcon,
|
||||
ArrowDownTrayIcon,
|
||||
} from "@heroicons/react/20/solid";
|
||||
import { Settings as SettingsType } from "@/gotypes";
|
||||
import { useNavigate } from "@tanstack/react-router";
|
||||
@@ -440,6 +441,29 @@ export default function Settings() {
|
||||
</div>
|
||||
</Field>
|
||||
|
||||
{/* Auto Update */}
|
||||
<Field>
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
<div className="flex items-start space-x-3 flex-1">
|
||||
<ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
|
||||
<div>
|
||||
<Label>Auto-download updates</Label>
|
||||
<Description>
|
||||
{settings.AutoUpdateEnabled
|
||||
? "Automatically download updates when available."
|
||||
: "Updates will not be downloaded automatically."}
|
||||
</Description>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex-shrink-0">
|
||||
<Switch
|
||||
checked={settings.AutoUpdateEnabled}
|
||||
onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Field>
|
||||
|
||||
{/* Expose Ollama */}
|
||||
<Field>
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
|
||||
23
app/ui/ui.go
23
app/ui/ui.go
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/ollama/ollama/app/tools"
|
||||
"github.com/ollama/ollama/app/types/not"
|
||||
"github.com/ollama/ollama/app/ui/responses"
|
||||
"github.com/ollama/ollama/app/updater"
|
||||
"github.com/ollama/ollama/app/version"
|
||||
ollamaAuth "github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
@@ -106,6 +107,10 @@ type Server struct {
|
||||
|
||||
// Dev is true if the server is running in development mode
|
||||
Dev bool
|
||||
|
||||
// Updater for checking and downloading updates
|
||||
Updater *updater.Updater
|
||||
UpdateAvailableFunc func()
|
||||
}
|
||||
|
||||
func (s *Server) log() *slog.Logger {
|
||||
@@ -1447,6 +1452,24 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
|
||||
return fmt.Errorf("failed to save settings: %w", err)
|
||||
}
|
||||
|
||||
// Handle auto-update toggle changes
|
||||
if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
|
||||
if !settings.AutoUpdateEnabled {
|
||||
// Auto-update disabled: cancel any ongoing download
|
||||
if s.Updater != nil {
|
||||
s.Updater.CancelOngoingDownload()
|
||||
}
|
||||
} else {
|
||||
// Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
|
||||
if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
|
||||
s.UpdateAvailableFunc()
|
||||
} else if s.Updater != nil {
|
||||
// Trigger the background checker to run immediately
|
||||
s.Updater.TriggerImmediateCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if old.ContextLength != settings.ContextLength ||
|
||||
old.Models != settings.Models ||
|
||||
old.Expose != settings.Expose {
|
||||
|
||||
@@ -4,6 +4,7 @@ package ui
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -11,9 +12,11 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/app/store"
|
||||
"github.com/ollama/ollama/app/updater"
|
||||
)
|
||||
|
||||
func TestHandlePostApiSettings(t *testing.T) {
|
||||
@@ -629,3 +632,183 @@ func TestWebSearchToolRegistration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSettingsToggleAutoUpdateOff_CancelsDownload(t *testing.T) {
|
||||
testStore := &store.Store{
|
||||
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
|
||||
}
|
||||
defer testStore.Close()
|
||||
|
||||
// Start with auto-update enabled
|
||||
settings, err := testStore.Settings()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
settings.AutoUpdateEnabled = true
|
||||
if err := testStore.SetSettings(settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
upd := &updater.Updater{Store: &store.Store{
|
||||
DBPath: filepath.Join(t.TempDir(), "db2.sqlite"),
|
||||
}}
|
||||
defer upd.Store.Close()
|
||||
|
||||
// We can't easily mock CancelOngoingDownload, but we can verify
|
||||
// the full settings handler flow works without error
|
||||
server := &Server{
|
||||
Store: testStore,
|
||||
Restart: func() {},
|
||||
Updater: upd,
|
||||
}
|
||||
|
||||
// Disable auto-update via settings API
|
||||
settings.AutoUpdateEnabled = false
|
||||
body, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
if err := server.settings(rr, req); err != nil {
|
||||
t.Fatalf("settings() error = %v", err)
|
||||
}
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Verify settings were saved with auto-update disabled
|
||||
saved, err := testStore.Settings()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if saved.AutoUpdateEnabled {
|
||||
t.Fatal("expected AutoUpdateEnabled to be false after toggle off")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSettingsToggleAutoUpdateOn_WithPendingUpdate_ShowsNotification(t *testing.T) {
|
||||
testStore := &store.Store{
|
||||
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
|
||||
}
|
||||
defer testStore.Close()
|
||||
|
||||
// Start with auto-update disabled
|
||||
settings, err := testStore.Settings()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
settings.AutoUpdateEnabled = false
|
||||
if err := testStore.SetSettings(settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Simulate that an update was previously downloaded
|
||||
oldVal := updater.UpdateDownloaded
|
||||
updater.UpdateDownloaded = true
|
||||
defer func() { updater.UpdateDownloaded = oldVal }()
|
||||
|
||||
var notificationCalled atomic.Bool
|
||||
server := &Server{
|
||||
Store: testStore,
|
||||
Restart: func() {},
|
||||
UpdateAvailableFunc: func() {
|
||||
notificationCalled.Store(true)
|
||||
},
|
||||
}
|
||||
|
||||
// Re-enable auto-update via settings API
|
||||
settings.AutoUpdateEnabled = true
|
||||
body, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
if err := server.settings(rr, req); err != nil {
|
||||
t.Fatalf("settings() error = %v", err)
|
||||
}
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
if !notificationCalled.Load() {
|
||||
t.Fatal("expected UpdateAvailableFunc to be called when re-enabling with a downloaded update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSettingsToggleAutoUpdateOn_NoPendingUpdate_TriggersCheck(t *testing.T) {
|
||||
testStore := &store.Store{
|
||||
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
|
||||
}
|
||||
defer testStore.Close()
|
||||
|
||||
// Start with auto-update disabled
|
||||
settings, err := testStore.Settings()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
settings.AutoUpdateEnabled = false
|
||||
if err := testStore.SetSettings(settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Ensure no pending update - clear both the downloaded flag and the stage dir
|
||||
oldVal := updater.UpdateDownloaded
|
||||
updater.UpdateDownloaded = false
|
||||
defer func() { updater.UpdateDownloaded = oldVal }()
|
||||
|
||||
oldStageDir := updater.UpdateStageDir
|
||||
updater.UpdateStageDir = t.TempDir() // empty dir means IsUpdatePending() returns false
|
||||
defer func() { updater.UpdateStageDir = oldStageDir }()
|
||||
|
||||
upd := &updater.Updater{Store: &store.Store{
|
||||
DBPath: filepath.Join(t.TempDir(), "db2.sqlite"),
|
||||
}}
|
||||
defer upd.Store.Close()
|
||||
|
||||
// Initialize the checkNow channel by starting (and immediately stopping) the checker
|
||||
// so TriggerImmediateCheck doesn't panic on nil channel
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
upd.StartBackgroundUpdaterChecker(ctx, func(string) error { return nil })
|
||||
defer cancel()
|
||||
|
||||
var notificationCalled atomic.Bool
|
||||
server := &Server{
|
||||
Store: testStore,
|
||||
Restart: func() {},
|
||||
Updater: upd,
|
||||
UpdateAvailableFunc: func() {
|
||||
notificationCalled.Store(true)
|
||||
},
|
||||
}
|
||||
|
||||
// Re-enable auto-update via settings API
|
||||
settings.AutoUpdateEnabled = true
|
||||
body, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
if err := server.settings(rr, req); err != nil {
|
||||
t.Fatalf("settings() error = %v", err)
|
||||
}
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// UpdateAvailableFunc should NOT be called since there's no pending update
|
||||
if notificationCalled.Load() {
|
||||
t.Fatal("UpdateAvailableFunc should not be called when there is no pending update")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/app/store"
|
||||
@@ -58,7 +59,8 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
|
||||
query := requestURL.Query()
|
||||
query.Add("os", runtime.GOOS)
|
||||
query.Add("arch", runtime.GOARCH)
|
||||
query.Add("version", version.Version)
|
||||
currentVersion := version.Version
|
||||
query.Add("version", currentVersion)
|
||||
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||
|
||||
// The original macOS app used to use the device ID
|
||||
@@ -131,15 +133,27 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
|
||||
}
|
||||
|
||||
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
|
||||
// Create a cancellable context for this download
|
||||
downloadCtx, cancel := context.WithCancel(ctx)
|
||||
u.cancelDownloadLock.Lock()
|
||||
u.cancelDownload = cancel
|
||||
u.cancelDownloadLock.Unlock()
|
||||
defer func() {
|
||||
u.cancelDownloadLock.Lock()
|
||||
u.cancelDownload = nil
|
||||
u.cancelDownloadLock.Unlock()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Do a head first to check etag info
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
|
||||
req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// In case of slow downloads, continue the update check in the background
|
||||
bgctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
bgctx, bgcancel := context.WithCancel(downloadCtx)
|
||||
defer bgcancel()
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
@@ -176,6 +190,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
|
||||
_, err = os.Stat(stageFilename)
|
||||
if err == nil {
|
||||
slog.Info("update already downloaded", "bundle", stageFilename)
|
||||
UpdateDownloaded = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -244,33 +259,84 @@ func cleanupOldDownloads(stageDir string) {
|
||||
}
|
||||
|
||||
type Updater struct {
|
||||
Store *store.Store
|
||||
Store *store.Store
|
||||
cancelDownload context.CancelFunc
|
||||
cancelDownloadLock sync.Mutex
|
||||
checkNow chan struct{}
|
||||
}
|
||||
|
||||
// CancelOngoingDownload cancels any currently running download
|
||||
func (u *Updater) CancelOngoingDownload() {
|
||||
u.cancelDownloadLock.Lock()
|
||||
defer u.cancelDownloadLock.Unlock()
|
||||
if u.cancelDownload != nil {
|
||||
slog.Info("cancelling ongoing update download")
|
||||
u.cancelDownload()
|
||||
u.cancelDownload = nil
|
||||
}
|
||||
}
|
||||
|
||||
// TriggerImmediateCheck signals the background checker to check for updates immediately
|
||||
func (u *Updater) TriggerImmediateCheck() {
|
||||
if u.checkNow != nil {
|
||||
select {
|
||||
case u.checkNow <- struct{}{}:
|
||||
default:
|
||||
// Check already pending, no need to queue another
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
|
||||
u.checkNow = make(chan struct{}, 1)
|
||||
go func() {
|
||||
// Don't blast an update message immediately after startup
|
||||
time.Sleep(UpdateCheckInitialDelay)
|
||||
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
|
||||
ticker := time.NewTicker(UpdateCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
available, resp := u.checkForUpdate(ctx)
|
||||
if available {
|
||||
err := u.DownloadNewRelease(ctx, resp)
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("failed to download new release: %s", err))
|
||||
} else {
|
||||
err = cb(resp.UpdateVersion)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Debug("stopping background update checker")
|
||||
return
|
||||
default:
|
||||
time.Sleep(UpdateCheckInterval)
|
||||
case <-u.checkNow:
|
||||
// Immediate check triggered
|
||||
case <-ticker.C:
|
||||
// Regular interval check
|
||||
}
|
||||
|
||||
// Always check for updates
|
||||
available, resp := u.checkForUpdate(ctx)
|
||||
if !available {
|
||||
continue
|
||||
}
|
||||
|
||||
// Update is available - check if auto-update is enabled for downloading
|
||||
settings, err := u.Store.Settings()
|
||||
if err != nil {
|
||||
slog.Error("failed to load settings", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !settings.AutoUpdateEnabled {
|
||||
// Auto-update disabled - don't download, just log
|
||||
slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
|
||||
continue
|
||||
}
|
||||
|
||||
// Auto-update is enabled - download
|
||||
err = u.DownloadNewRelease(ctx, resp)
|
||||
if err != nil {
|
||||
slog.Error("failed to download new release", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Download successful - show tray notification (regardless of toggle state)
|
||||
err = cb(resp.UpdateVersion)
|
||||
if err != nil {
|
||||
slog.Warn("failed to register update available with tray", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -33,7 +35,7 @@ func TestIsNewReleaseAvailable(t *testing.T) {
|
||||
defer server.Close()
|
||||
slog.Debug("server", "url", server.URL)
|
||||
|
||||
updater := &Updater{Store: &store.Store{}}
|
||||
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||
defer updater.Store.Close() // Ensure database is closed
|
||||
UpdateCheckURLBase = server.URL + "/update.json"
|
||||
updatePresent, resp := updater.checkForUpdate(t.Context())
|
||||
@@ -84,8 +86,18 @@ func TestBackgoundChecker(t *testing.T) {
|
||||
defer server.Close()
|
||||
UpdateCheckURLBase = server.URL + "/update.json"
|
||||
|
||||
updater := &Updater{Store: &store.Store{}}
|
||||
defer updater.Store.Close() // Ensure database is closed
|
||||
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||
defer updater.Store.Close()
|
||||
|
||||
settings, err := updater.Store.Settings()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
settings.AutoUpdateEnabled = true
|
||||
if err := updater.Store.SetSettings(settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
@@ -99,3 +111,264 @@ func TestBackgoundChecker(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoUpdateDisabledSkipsDownload(t *testing.T) {
|
||||
UpdateStageDir = t.TempDir()
|
||||
var downloadAttempted atomic.Bool
|
||||
done := make(chan struct{})
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
UpdateCheckInitialDelay = 5 * time.Millisecond
|
||||
UpdateCheckInterval = 5 * time.Millisecond
|
||||
VerifyDownload = func() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/update.json" {
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||
server.URL+"/9.9.9/"+Installer)))
|
||||
} else if r.URL.Path == "/9.9.9/"+Installer {
|
||||
downloadAttempted.Store(true)
|
||||
buf := &bytes.Buffer{}
|
||||
zw := zip.NewWriter(buf)
|
||||
zw.Close()
|
||||
io.Copy(w, buf)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
UpdateCheckURLBase = server.URL + "/update.json"
|
||||
|
||||
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||
defer updater.Store.Close()
|
||||
|
||||
// Ensure auto-update is disabled
|
||||
settings, err := updater.Store.Settings()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
settings.AutoUpdateEnabled = false
|
||||
if err := updater.Store.SetSettings(settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb := func(ver string) error {
|
||||
t.Fatal("callback should not be called when auto-update is disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||
|
||||
// Wait enough time for multiple check cycles
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
close(done)
|
||||
|
||||
if downloadAttempted.Load() {
|
||||
t.Fatal("download should not be attempted when auto-update is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoUpdateReenabledDownloadsUpdate(t *testing.T) {
|
||||
UpdateStageDir = t.TempDir()
|
||||
var downloadAttempted atomic.Bool
|
||||
callbackCalled := make(chan struct{}, 1)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
UpdateCheckInitialDelay = 5 * time.Millisecond
|
||||
UpdateCheckInterval = 5 * time.Millisecond
|
||||
VerifyDownload = func() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/update.json" {
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||
server.URL+"/9.9.9/"+Installer)))
|
||||
} else if r.URL.Path == "/9.9.9/"+Installer {
|
||||
downloadAttempted.Store(true)
|
||||
buf := &bytes.Buffer{}
|
||||
zw := zip.NewWriter(buf)
|
||||
zw.Close()
|
||||
io.Copy(w, buf)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
UpdateCheckURLBase = server.URL + "/update.json"
|
||||
|
||||
upd := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||
defer upd.Store.Close()
|
||||
|
||||
// Start with auto-update disabled
|
||||
settings, err := upd.Store.Settings()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
settings.AutoUpdateEnabled = false
|
||||
if err := upd.Store.SetSettings(settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb := func(ver string) error {
|
||||
select {
|
||||
case callbackCalled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
upd.StartBackgroundUpdaterChecker(ctx, cb)
|
||||
|
||||
// Wait for a few cycles with auto-update disabled - no download should happen
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if downloadAttempted.Load() {
|
||||
t.Fatal("download should not happen while auto-update is disabled")
|
||||
}
|
||||
|
||||
// Re-enable auto-update
|
||||
settings.AutoUpdateEnabled = true
|
||||
if err := upd.Store.SetSettings(settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Wait for the checker to pick it up and download
|
||||
select {
|
||||
case <-callbackCalled:
|
||||
// Success: download happened and callback was called after re-enabling
|
||||
if !downloadAttempted.Load() {
|
||||
t.Fatal("expected download to be attempted after re-enabling")
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("expected download and callback after re-enabling auto-update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelOngoingDownload(t *testing.T) {
|
||||
UpdateStageDir = t.TempDir()
|
||||
downloadStarted := make(chan struct{})
|
||||
downloadCancelled := make(chan struct{})
|
||||
|
||||
ctx := t.Context()
|
||||
VerifyDownload = func() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/update.json" {
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||
server.URL+"/9.9.9/"+Installer)))
|
||||
} else if r.URL.Path == "/9.9.9/"+Installer {
|
||||
if r.Method == http.MethodHead {
|
||||
w.Header().Set("Content-Length", "1000000")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
// Signal that download has started
|
||||
close(downloadStarted)
|
||||
// Wait for cancellation or timeout
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
close(downloadCancelled)
|
||||
return
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("download was not cancelled in time")
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
UpdateCheckURLBase = server.URL + "/update.json"
|
||||
|
||||
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||
defer updater.Store.Close()
|
||||
|
||||
_, resp := updater.checkForUpdate(ctx)
|
||||
|
||||
// Start download in goroutine
|
||||
go func() {
|
||||
_ = updater.DownloadNewRelease(ctx, resp)
|
||||
}()
|
||||
|
||||
// Wait for download to start
|
||||
select {
|
||||
case <-downloadStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("download did not start in time")
|
||||
}
|
||||
|
||||
// Cancel the download
|
||||
updater.CancelOngoingDownload()
|
||||
|
||||
// Verify cancellation was received
|
||||
select {
|
||||
case <-downloadCancelled:
|
||||
// Success
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("download cancellation was not received by server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTriggerImmediateCheck(t *testing.T) {
|
||||
UpdateStageDir = t.TempDir()
|
||||
checkCount := atomic.Int32{}
|
||||
checkDone := make(chan struct{}, 10)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
// Set a very long interval so only TriggerImmediateCheck causes checks
|
||||
UpdateCheckInitialDelay = 1 * time.Millisecond
|
||||
UpdateCheckInterval = 1 * time.Hour
|
||||
VerifyDownload = func() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/update.json" {
|
||||
checkCount.Add(1)
|
||||
select {
|
||||
case checkDone <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
// Return no update available
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
UpdateCheckURLBase = server.URL + "/update.json"
|
||||
|
||||
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||
defer updater.Store.Close()
|
||||
|
||||
cb := func(ver string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||
|
||||
// Wait for goroutine to start and pass initial delay
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// With 1 hour interval, no check should have happened yet
|
||||
initialCount := checkCount.Load()
|
||||
|
||||
// Trigger immediate check
|
||||
updater.TriggerImmediateCheck()
|
||||
|
||||
// Wait for the triggered check
|
||||
select {
|
||||
case <-checkDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("triggered check did not happen")
|
||||
}
|
||||
|
||||
finalCount := checkCount.Load()
|
||||
if finalCount <= initialCount {
|
||||
t.Fatalf("TriggerImmediateCheck did not cause additional check: initial=%d, final=%d", initialCount, finalCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -369,25 +369,6 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error {
|
||||
// const ERROR_SUCCESS syscall.Errno = 0
|
||||
|
||||
// t.muMenus.RLock()
|
||||
// menu := uintptr(t.menus[parentId])
|
||||
// t.muMenus.RUnlock()
|
||||
// res, _, err := pRemoveMenu.Call(
|
||||
// menu,
|
||||
// uintptr(menuItemId),
|
||||
// MF_BYCOMMAND,
|
||||
// )
|
||||
// if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
|
||||
// return err
|
||||
// }
|
||||
// t.delFromVisibleItems(parentId, menuItemId)
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
func (t *winTray) showMenu() error {
|
||||
p := point{}
|
||||
boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p)))
|
||||
|
||||
@@ -51,7 +51,6 @@ const (
|
||||
IMAGE_ICON = 1 // Loads an icon
|
||||
LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero
|
||||
LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file
|
||||
MF_BYCOMMAND = 0x00000000
|
||||
MFS_DISABLED = 0x00000003
|
||||
MFT_SEPARATOR = 0x00000800
|
||||
MFT_STRING = 0x00000000
|
||||
|
||||
Reference in New Issue
Block a user