diff --git a/frontend/tests/factories/license.ts b/frontend/tests/factories/license.ts new file mode 100644 index 000000000..75b99a321 --- /dev/null +++ b/frontend/tests/factories/license.ts @@ -0,0 +1,35 @@ +import {Factory} from '../support/factory' + +export class LicenseFactory extends Factory { + static table = 'license_status' + + static async enable(features: string[]) { + const now = new Date().toISOString() + const response = JSON.stringify({ + valid: true, + features, + max_users: 0, + expires_at: '2099-01-01T00:00:00Z', + }) + await this.seed(this.table, [{ + id: 1, + instance_id: '00000000-0000-0000-0000-000000000000', + response, + validated_at: now, + created: now, + updated: now, + }]) + } + + static async disable() { + const now = new Date().toISOString() + await this.seed(this.table, [{ + id: 1, + instance_id: '00000000-0000-0000-0000-000000000000', + response: '{}', + validated_at: null, + created: now, + updated: now, + }]) + } +} diff --git a/pkg/db/fixtures/license_status.yml b/pkg/db/fixtures/license_status.yml new file mode 100644 index 000000000..fb850cb24 --- /dev/null +++ b/pkg/db/fixtures/license_status.yml @@ -0,0 +1,7 @@ +- + id: 1 + instance_id: '00000000-0000-0000-0000-000000000000' + response: '{}' + validated_at: null + created: 2026-01-01 00:00:00 + updated: 2026-01-01 00:00:00 diff --git a/pkg/license/license.go b/pkg/license/license.go index 59f95748a..a088f01ed 100644 --- a/pkg/license/license.go +++ b/pkg/license/license.go @@ -42,6 +42,7 @@ import ( "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/log" + "code.vikunja.io/api/pkg/modules/keyvalue" "code.vikunja.io/api/pkg/user" "github.com/google/uuid" @@ -121,20 +122,20 @@ func (Status) TableName() string { return "license_status" } -// state holds the current in-memory license state. +// state is persisted through keyvalue so all replicas share the same activation status. type state struct { - mu sync.RWMutex - licensed bool - features map[Feature]bool - maxUsers int64 - expiresAt time.Time - lastCheckFailed bool + Licensed bool + Features map[Feature]bool + MaxUsers int64 + ExpiresAt time.Time + LastCheckFailed bool } +const stateKey = "license.state" + var ( - currentState = &state{ - features: make(map[Feature]bool), - } + // stateMu serialises read-modify-write cycles against the keyvalue store within a single replica; across replicas the license server is authoritative. + stateMu sync.Mutex stopCh chan struct{} instanceID string ) @@ -151,9 +152,12 @@ func Init() { log.Fatalf("Could not initialize license system: %s", err) } - // No license key configured — free mode + // No license key configured — free mode. Clear any state persisted by a + // previous run (e.g. via Redis) so a removed key can't leave stale + // Licensed=true entitlements behind. if key == "" { log.Debugf("No license key configured.") + degradeToFree("No license key configured.") return } @@ -193,42 +197,102 @@ func Init() { go backgroundLoop(key) } -// EnabledProFeatures returns the string keys of all currently enabled licensed features. -// Returns an empty slice in free mode. -func EnabledProFeatures() []string { - currentState.mu.RLock() - defer currentState.mu.RUnlock() - if !currentState.licensed { - return []string{} +// SetForTests enables the given features. Pair with ResetForTests to avoid bleeding state between tests. +func SetForTests(features []Feature) { + feats := make([]Feature, 0, len(features)) + feats = append(feats, features...) + applyResponse(&Response{ + Valid: true, + Features: feats, + ExpiresAt: time.Now().Add(365 * 24 * time.Hour), + }) +} + +func ResetForTests() { + degradeToFree("reset for tests") +} + +// ReloadFromCache applies the cached license_status row; empty or missing cache degrades to free mode. +func ReloadFromCache() error { + cached, err := loadCachedStatus() + if err != nil { + return err } - out := make([]string, 0, len(currentState.features)) - for f, on := range currentState.features { + if cached == nil || cached.Response == "" || cached.Response == "{}" { + degradeToFree("License cache is empty.") + return nil + } + return applyFromCache(cached) +} + +type Info struct { + Licensed bool `json:"licensed"` + InstanceID string `json:"instance_id"` + Features []string `json:"features"` + MaxUsers int64 `json:"max_users"` + ExpiresAt time.Time `json:"expires_at"` + ValidatedAt time.Time `json:"validated_at"` + LastCheckFailed bool `json:"last_check_failed"` +} + +// CurrentInfo returns a snapshot of the current license state; on DB errors it omits the cache-backed fields rather than failing. +func CurrentInfo() Info { + st := loadState() + info := Info{ + Licensed: st.Licensed, + InstanceID: instanceID, + Features: make([]string, 0, len(st.Features)), + MaxUsers: st.MaxUsers, + ExpiresAt: st.ExpiresAt, + LastCheckFailed: st.LastCheckFailed, + } + for f, on := range st.Features { if !on { continue } - name := f.String() - out = append(out, name) + info.Features = append(info.Features, f.String()) } - sort.Strings(out) + sort.Strings(info.Features) + + if cached, err := loadCachedStatus(); err == nil && cached != nil { + info.ValidatedAt = cached.ValidatedAt + } + return info +} + +// EnabledProFeatures returns enabled features (empty slice in free mode); Feature values marshal to their JSON string key. +func EnabledProFeatures() []Feature { + st := loadState() + if !st.Licensed { + return []Feature{} + } + out := make([]Feature, 0, len(st.Features)) + for f, on := range st.Features { + if !on { + continue + } + out = append(out, f) + } + sort.Slice(out, func(i, j int) bool { + return out[i].String() < out[j].String() + }) return out } // IsFeatureEnabled returns whether a specific licensed feature is enabled. func IsFeatureEnabled(feature Feature) bool { - currentState.mu.RLock() - defer currentState.mu.RUnlock() - if !currentState.licensed { + st := loadState() + if !st.Licensed { return false } - return currentState.features[feature] + return st.Features[feature] } // MaxUsersReached returns whether the licensed user limit has been reached. // Returns false in free mode (no limit). func MaxUsersReached() bool { - currentState.mu.RLock() - defer currentState.mu.RUnlock() - if !currentState.licensed || currentState.maxUsers <= 0 { + st := loadState() + if !st.Licensed || st.MaxUsers <= 0 { return false } @@ -241,7 +305,7 @@ func MaxUsersReached() bool { return false } - return count >= currentState.maxUsers + return count >= st.MaxUsers } // Shutdown stops the background license check goroutine. @@ -251,6 +315,30 @@ func Shutdown() { } } +// loadState returns state from keyvalue; missing or unreadable state degrades to a zero-value (free mode) snapshot. +func loadState() state { + st := state{Features: make(map[Feature]bool)} + exists, err := keyvalue.GetWithValue(stateKey, &st) + if err != nil { + log.Errorf("Error loading license state from keyvalue: %s", err) + return state{Features: make(map[Feature]bool)} + } + if !exists { + return state{Features: make(map[Feature]bool)} + } + if st.Features == nil { + st.Features = make(map[Feature]bool) + } + return st +} + +// saveState persists through keyvalue so state is visible to every replica, not just the one that performed the check. +func saveState(st state) { + if err := keyvalue.Put(stateKey, st); err != nil { + log.Errorf("Error saving license state to keyvalue: %s", err) + } +} + func loadOrCreateInstanceID() (string, error) { s := db.NewSession() defer s.Close() @@ -297,20 +385,23 @@ func loadCachedStatus() (*Status, error) { } func applyResponse(resp *Response) { - currentState.mu.Lock() - defer currentState.mu.Unlock() + stateMu.Lock() + defer stateMu.Unlock() - currentState.licensed = true - currentState.features = make(map[Feature]bool) + st := state{ + Licensed: true, + Features: make(map[Feature]bool), + MaxUsers: resp.MaxUsers, + ExpiresAt: resp.ExpiresAt, + LastCheckFailed: false, + } for _, f := range resp.Features { if f == FeatureUnknown { continue } - currentState.features[f] = true + st.Features[f] = true } - currentState.maxUsers = resp.MaxUsers - currentState.expiresAt = resp.ExpiresAt - currentState.lastCheckFailed = false + saveState(st) } func applyFromCache(cached *Status) error { @@ -323,17 +414,29 @@ func applyFromCache(cached *Status) error { } func degradeToFree(reason string) { - currentState.mu.Lock() - defer currentState.mu.Unlock() + stateMu.Lock() + defer stateMu.Unlock() - currentState.licensed = false - currentState.features = make(map[Feature]bool) - currentState.maxUsers = 0 - currentState.lastCheckFailed = true + saveState(state{ + Licensed: false, + Features: make(map[Feature]bool), + MaxUsers: 0, + LastCheckFailed: true, + }) log.Warningf("%s Pro features have been disabled.", reason) } +// markCheckFailed flips LastCheckFailed while preserving other fields so cached-valid replicas still serve requests. +func markCheckFailed() { + stateMu.Lock() + defer stateMu.Unlock() + + st := loadState() + st.LastCheckFailed = true + saveState(st) +} + func cacheResponse(resp *Response) error { raw, err := serializeResponse(resp) if err != nil { @@ -343,7 +446,6 @@ func cacheResponse(resp *Response) error { s := db.NewSession() defer s.Close() - // Update the existing row _, err = s.Where("instance_id = ?", instanceID).Update(&Status{ Response: raw, ValidatedAt: time.Now(), @@ -358,13 +460,13 @@ func cacheResponse(resp *Response) error { func backgroundLoop(key string) { for { interval := 24 * time.Hour - currentState.mu.RLock() - if currentState.lastCheckFailed { + st := loadState() + switch { + case st.LastCheckFailed: interval = 1 * time.Hour - } else if !currentState.expiresAt.IsZero() && time.Until(currentState.expiresAt) < 72*time.Hour { + case !st.ExpiresAt.IsZero() && time.Until(st.ExpiresAt) < 72*time.Hour: interval = 1 * time.Hour } - currentState.mu.RUnlock() select { case <-stopCh: @@ -375,23 +477,19 @@ func backgroundLoop(key string) { log.Debugf("Running background license check...") resp, err := checkLicense(key) if err != nil { - // Servers unreachable log.Debugf("Background license check failed: %s", err) cached, cacheErr := loadCachedStatus() if cacheErr != nil || cached == nil || time.Since(cached.ValidatedAt) >= 72*time.Hour { degradeToFree("License cache expired and no license server is reachable.") log.Warningf("Next retry in 1 hour.") } else { - currentState.mu.Lock() - currentState.lastCheckFailed = true - currentState.mu.Unlock() + markCheckFailed() log.Warningf("License check failed, using cached validation from %s. Next retry in 1 hour.", cached.ValidatedAt.Format(time.RFC3339)) } continue } if !resp.Valid { - // Clear cache if err := clearCache(); err != nil { log.Errorf("Error clearing license cache: %s", err) } @@ -399,11 +497,8 @@ func backgroundLoop(key string) { continue } - // Success - wasFailure := false - currentState.mu.RLock() - wasFailure = currentState.lastCheckFailed || !currentState.licensed - currentState.mu.RUnlock() + prev := loadState() + wasFailure := prev.LastCheckFailed || !prev.Licensed applyResponse(resp) if err := cacheResponse(resp); err != nil { diff --git a/pkg/routes/api/v1/testing.go b/pkg/routes/api/v1/testing.go index 1b86ea86f..98f5aeca1 100644 --- a/pkg/routes/api/v1/testing.go +++ b/pkg/routes/api/v1/testing.go @@ -24,6 +24,7 @@ import ( "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/events" + "code.vikunja.io/api/pkg/license" "code.vikunja.io/api/pkg/log" "github.com/labstack/echo/v5" @@ -100,6 +101,17 @@ func HandleTesting(c *echo.Context) error { }) } + // License state is cached at startup; re-apply so tests take effect without a restart. + if table == "license_status" { + if err := license.ReloadFromCache(); err != nil { + log.Errorf("Error reloading license from seeded cache: %v", err) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ + "error": true, + "message": err.Error(), + }) + } + } + s := db.NewSession() defer s.Close() data := []map[string]interface{}{} @@ -140,6 +152,11 @@ func HandleTestingTruncateAll(c *echo.Context) error { }) } + // Reload after truncate; otherwise features enabled by a prior test outlive the now-empty license_status table. + if err := license.ReloadFromCache(); err != nil { + log.Errorf("Error reloading license after truncate: %v", err) + } + return c.JSON(http.StatusOK, map[string]string{ "message": "ok", })