mirror of
https://github.com/go-vikunja/vikunja.git
synced 2026-05-10 15:15:41 -05:00
feat(license): add runtime state snapshot and reload helpers
This commit is contained in:
35
frontend/tests/factories/license.ts
Normal file
35
frontend/tests/factories/license.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import {Factory} from '../support/factory'
|
||||
|
||||
export class LicenseFactory extends Factory {
|
||||
static table = 'license_status'
|
||||
|
||||
static async enable(features: string[]) {
|
||||
const now = new Date().toISOString()
|
||||
const response = JSON.stringify({
|
||||
valid: true,
|
||||
features,
|
||||
max_users: 0,
|
||||
expires_at: '2099-01-01T00:00:00Z',
|
||||
})
|
||||
await this.seed(this.table, [{
|
||||
id: 1,
|
||||
instance_id: '00000000-0000-0000-0000-000000000000',
|
||||
response,
|
||||
validated_at: now,
|
||||
created: now,
|
||||
updated: now,
|
||||
}])
|
||||
}
|
||||
|
||||
static async disable() {
|
||||
const now = new Date().toISOString()
|
||||
await this.seed(this.table, [{
|
||||
id: 1,
|
||||
instance_id: '00000000-0000-0000-0000-000000000000',
|
||||
response: '{}',
|
||||
validated_at: null,
|
||||
created: now,
|
||||
updated: now,
|
||||
}])
|
||||
}
|
||||
}
|
||||
7
pkg/db/fixtures/license_status.yml
Normal file
7
pkg/db/fixtures/license_status.yml
Normal file
@@ -0,0 +1,7 @@
|
||||
-
|
||||
id: 1
|
||||
instance_id: '00000000-0000-0000-0000-000000000000'
|
||||
response: '{}'
|
||||
validated_at: null
|
||||
created: 2026-01-01 00:00:00
|
||||
updated: 2026-01-01 00:00:00
|
||||
@@ -42,6 +42,7 @@ import (
|
||||
"code.vikunja.io/api/pkg/config"
|
||||
"code.vikunja.io/api/pkg/db"
|
||||
"code.vikunja.io/api/pkg/log"
|
||||
"code.vikunja.io/api/pkg/modules/keyvalue"
|
||||
"code.vikunja.io/api/pkg/user"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -121,20 +122,20 @@ func (Status) TableName() string {
|
||||
return "license_status"
|
||||
}
|
||||
|
||||
// state holds the current in-memory license state.
|
||||
// state is persisted through keyvalue so all replicas share the same activation status.
|
||||
type state struct {
|
||||
mu sync.RWMutex
|
||||
licensed bool
|
||||
features map[Feature]bool
|
||||
maxUsers int64
|
||||
expiresAt time.Time
|
||||
lastCheckFailed bool
|
||||
Licensed bool
|
||||
Features map[Feature]bool
|
||||
MaxUsers int64
|
||||
ExpiresAt time.Time
|
||||
LastCheckFailed bool
|
||||
}
|
||||
|
||||
const stateKey = "license.state"
|
||||
|
||||
var (
|
||||
currentState = &state{
|
||||
features: make(map[Feature]bool),
|
||||
}
|
||||
// stateMu serialises read-modify-write cycles against the keyvalue store within a single replica; across replicas the license server is authoritative.
|
||||
stateMu sync.Mutex
|
||||
stopCh chan struct{}
|
||||
instanceID string
|
||||
)
|
||||
@@ -151,9 +152,12 @@ func Init() {
|
||||
log.Fatalf("Could not initialize license system: %s", err)
|
||||
}
|
||||
|
||||
// No license key configured — free mode
|
||||
// No license key configured — free mode. Clear any state persisted by a
|
||||
// previous run (e.g. via Redis) so a removed key can't leave stale
|
||||
// Licensed=true entitlements behind.
|
||||
if key == "" {
|
||||
log.Debugf("No license key configured.")
|
||||
degradeToFree("No license key configured.")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -193,42 +197,102 @@ func Init() {
|
||||
go backgroundLoop(key)
|
||||
}
|
||||
|
||||
// EnabledProFeatures returns the string keys of all currently enabled licensed features.
|
||||
// Returns an empty slice in free mode.
|
||||
func EnabledProFeatures() []string {
|
||||
currentState.mu.RLock()
|
||||
defer currentState.mu.RUnlock()
|
||||
if !currentState.licensed {
|
||||
return []string{}
|
||||
// SetForTests enables the given features. Pair with ResetForTests to avoid bleeding state between tests.
|
||||
func SetForTests(features []Feature) {
|
||||
feats := make([]Feature, 0, len(features))
|
||||
feats = append(feats, features...)
|
||||
applyResponse(&Response{
|
||||
Valid: true,
|
||||
Features: feats,
|
||||
ExpiresAt: time.Now().Add(365 * 24 * time.Hour),
|
||||
})
|
||||
}
|
||||
|
||||
func ResetForTests() {
|
||||
degradeToFree("reset for tests")
|
||||
}
|
||||
|
||||
// ReloadFromCache applies the cached license_status row; empty or missing cache degrades to free mode.
|
||||
func ReloadFromCache() error {
|
||||
cached, err := loadCachedStatus()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out := make([]string, 0, len(currentState.features))
|
||||
for f, on := range currentState.features {
|
||||
if cached == nil || cached.Response == "" || cached.Response == "{}" {
|
||||
degradeToFree("License cache is empty.")
|
||||
return nil
|
||||
}
|
||||
return applyFromCache(cached)
|
||||
}
|
||||
|
||||
type Info struct {
|
||||
Licensed bool `json:"licensed"`
|
||||
InstanceID string `json:"instance_id"`
|
||||
Features []string `json:"features"`
|
||||
MaxUsers int64 `json:"max_users"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
ValidatedAt time.Time `json:"validated_at"`
|
||||
LastCheckFailed bool `json:"last_check_failed"`
|
||||
}
|
||||
|
||||
// CurrentInfo returns a snapshot of the current license state; on DB errors it omits the cache-backed fields rather than failing.
|
||||
func CurrentInfo() Info {
|
||||
st := loadState()
|
||||
info := Info{
|
||||
Licensed: st.Licensed,
|
||||
InstanceID: instanceID,
|
||||
Features: make([]string, 0, len(st.Features)),
|
||||
MaxUsers: st.MaxUsers,
|
||||
ExpiresAt: st.ExpiresAt,
|
||||
LastCheckFailed: st.LastCheckFailed,
|
||||
}
|
||||
for f, on := range st.Features {
|
||||
if !on {
|
||||
continue
|
||||
}
|
||||
name := f.String()
|
||||
out = append(out, name)
|
||||
info.Features = append(info.Features, f.String())
|
||||
}
|
||||
sort.Strings(out)
|
||||
sort.Strings(info.Features)
|
||||
|
||||
if cached, err := loadCachedStatus(); err == nil && cached != nil {
|
||||
info.ValidatedAt = cached.ValidatedAt
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// EnabledProFeatures returns enabled features (empty slice in free mode); Feature values marshal to their JSON string key.
|
||||
func EnabledProFeatures() []Feature {
|
||||
st := loadState()
|
||||
if !st.Licensed {
|
||||
return []Feature{}
|
||||
}
|
||||
out := make([]Feature, 0, len(st.Features))
|
||||
for f, on := range st.Features {
|
||||
if !on {
|
||||
continue
|
||||
}
|
||||
out = append(out, f)
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool {
|
||||
return out[i].String() < out[j].String()
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
// IsFeatureEnabled returns whether a specific licensed feature is enabled.
|
||||
func IsFeatureEnabled(feature Feature) bool {
|
||||
currentState.mu.RLock()
|
||||
defer currentState.mu.RUnlock()
|
||||
if !currentState.licensed {
|
||||
st := loadState()
|
||||
if !st.Licensed {
|
||||
return false
|
||||
}
|
||||
return currentState.features[feature]
|
||||
return st.Features[feature]
|
||||
}
|
||||
|
||||
// MaxUsersReached returns whether the licensed user limit has been reached.
|
||||
// Returns false in free mode (no limit).
|
||||
func MaxUsersReached() bool {
|
||||
currentState.mu.RLock()
|
||||
defer currentState.mu.RUnlock()
|
||||
if !currentState.licensed || currentState.maxUsers <= 0 {
|
||||
st := loadState()
|
||||
if !st.Licensed || st.MaxUsers <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -241,7 +305,7 @@ func MaxUsersReached() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
return count >= currentState.maxUsers
|
||||
return count >= st.MaxUsers
|
||||
}
|
||||
|
||||
// Shutdown stops the background license check goroutine.
|
||||
@@ -251,6 +315,30 @@ func Shutdown() {
|
||||
}
|
||||
}
|
||||
|
||||
// loadState returns state from keyvalue; missing or unreadable state degrades to a zero-value (free mode) snapshot.
|
||||
func loadState() state {
|
||||
st := state{Features: make(map[Feature]bool)}
|
||||
exists, err := keyvalue.GetWithValue(stateKey, &st)
|
||||
if err != nil {
|
||||
log.Errorf("Error loading license state from keyvalue: %s", err)
|
||||
return state{Features: make(map[Feature]bool)}
|
||||
}
|
||||
if !exists {
|
||||
return state{Features: make(map[Feature]bool)}
|
||||
}
|
||||
if st.Features == nil {
|
||||
st.Features = make(map[Feature]bool)
|
||||
}
|
||||
return st
|
||||
}
|
||||
|
||||
// saveState persists through keyvalue so state is visible to every replica, not just the one that performed the check.
|
||||
func saveState(st state) {
|
||||
if err := keyvalue.Put(stateKey, st); err != nil {
|
||||
log.Errorf("Error saving license state to keyvalue: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func loadOrCreateInstanceID() (string, error) {
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
@@ -297,20 +385,23 @@ func loadCachedStatus() (*Status, error) {
|
||||
}
|
||||
|
||||
func applyResponse(resp *Response) {
|
||||
currentState.mu.Lock()
|
||||
defer currentState.mu.Unlock()
|
||||
stateMu.Lock()
|
||||
defer stateMu.Unlock()
|
||||
|
||||
currentState.licensed = true
|
||||
currentState.features = make(map[Feature]bool)
|
||||
st := state{
|
||||
Licensed: true,
|
||||
Features: make(map[Feature]bool),
|
||||
MaxUsers: resp.MaxUsers,
|
||||
ExpiresAt: resp.ExpiresAt,
|
||||
LastCheckFailed: false,
|
||||
}
|
||||
for _, f := range resp.Features {
|
||||
if f == FeatureUnknown {
|
||||
continue
|
||||
}
|
||||
currentState.features[f] = true
|
||||
st.Features[f] = true
|
||||
}
|
||||
currentState.maxUsers = resp.MaxUsers
|
||||
currentState.expiresAt = resp.ExpiresAt
|
||||
currentState.lastCheckFailed = false
|
||||
saveState(st)
|
||||
}
|
||||
|
||||
func applyFromCache(cached *Status) error {
|
||||
@@ -323,17 +414,29 @@ func applyFromCache(cached *Status) error {
|
||||
}
|
||||
|
||||
func degradeToFree(reason string) {
|
||||
currentState.mu.Lock()
|
||||
defer currentState.mu.Unlock()
|
||||
stateMu.Lock()
|
||||
defer stateMu.Unlock()
|
||||
|
||||
currentState.licensed = false
|
||||
currentState.features = make(map[Feature]bool)
|
||||
currentState.maxUsers = 0
|
||||
currentState.lastCheckFailed = true
|
||||
saveState(state{
|
||||
Licensed: false,
|
||||
Features: make(map[Feature]bool),
|
||||
MaxUsers: 0,
|
||||
LastCheckFailed: true,
|
||||
})
|
||||
|
||||
log.Warningf("%s Pro features have been disabled.", reason)
|
||||
}
|
||||
|
||||
// markCheckFailed flips LastCheckFailed while preserving other fields so cached-valid replicas still serve requests.
|
||||
func markCheckFailed() {
|
||||
stateMu.Lock()
|
||||
defer stateMu.Unlock()
|
||||
|
||||
st := loadState()
|
||||
st.LastCheckFailed = true
|
||||
saveState(st)
|
||||
}
|
||||
|
||||
func cacheResponse(resp *Response) error {
|
||||
raw, err := serializeResponse(resp)
|
||||
if err != nil {
|
||||
@@ -343,7 +446,6 @@ func cacheResponse(resp *Response) error {
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
// Update the existing row
|
||||
_, err = s.Where("instance_id = ?", instanceID).Update(&Status{
|
||||
Response: raw,
|
||||
ValidatedAt: time.Now(),
|
||||
@@ -358,13 +460,13 @@ func cacheResponse(resp *Response) error {
|
||||
func backgroundLoop(key string) {
|
||||
for {
|
||||
interval := 24 * time.Hour
|
||||
currentState.mu.RLock()
|
||||
if currentState.lastCheckFailed {
|
||||
st := loadState()
|
||||
switch {
|
||||
case st.LastCheckFailed:
|
||||
interval = 1 * time.Hour
|
||||
} else if !currentState.expiresAt.IsZero() && time.Until(currentState.expiresAt) < 72*time.Hour {
|
||||
case !st.ExpiresAt.IsZero() && time.Until(st.ExpiresAt) < 72*time.Hour:
|
||||
interval = 1 * time.Hour
|
||||
}
|
||||
currentState.mu.RUnlock()
|
||||
|
||||
select {
|
||||
case <-stopCh:
|
||||
@@ -375,23 +477,19 @@ func backgroundLoop(key string) {
|
||||
log.Debugf("Running background license check...")
|
||||
resp, err := checkLicense(key)
|
||||
if err != nil {
|
||||
// Servers unreachable
|
||||
log.Debugf("Background license check failed: %s", err)
|
||||
cached, cacheErr := loadCachedStatus()
|
||||
if cacheErr != nil || cached == nil || time.Since(cached.ValidatedAt) >= 72*time.Hour {
|
||||
degradeToFree("License cache expired and no license server is reachable.")
|
||||
log.Warningf("Next retry in 1 hour.")
|
||||
} else {
|
||||
currentState.mu.Lock()
|
||||
currentState.lastCheckFailed = true
|
||||
currentState.mu.Unlock()
|
||||
markCheckFailed()
|
||||
log.Warningf("License check failed, using cached validation from %s. Next retry in 1 hour.", cached.ValidatedAt.Format(time.RFC3339))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !resp.Valid {
|
||||
// Clear cache
|
||||
if err := clearCache(); err != nil {
|
||||
log.Errorf("Error clearing license cache: %s", err)
|
||||
}
|
||||
@@ -399,11 +497,8 @@ func backgroundLoop(key string) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Success
|
||||
wasFailure := false
|
||||
currentState.mu.RLock()
|
||||
wasFailure = currentState.lastCheckFailed || !currentState.licensed
|
||||
currentState.mu.RUnlock()
|
||||
prev := loadState()
|
||||
wasFailure := prev.LastCheckFailed || !prev.Licensed
|
||||
|
||||
applyResponse(resp)
|
||||
if err := cacheResponse(resp); err != nil {
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"code.vikunja.io/api/pkg/config"
|
||||
"code.vikunja.io/api/pkg/db"
|
||||
"code.vikunja.io/api/pkg/events"
|
||||
"code.vikunja.io/api/pkg/license"
|
||||
"code.vikunja.io/api/pkg/log"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
@@ -100,6 +101,17 @@ func HandleTesting(c *echo.Context) error {
|
||||
})
|
||||
}
|
||||
|
||||
// License state is cached at startup; re-apply so tests take effect without a restart.
|
||||
if table == "license_status" {
|
||||
if err := license.ReloadFromCache(); err != nil {
|
||||
log.Errorf("Error reloading license from seeded cache: %v", err)
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||
"error": true,
|
||||
"message": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
data := []map[string]interface{}{}
|
||||
@@ -140,6 +152,11 @@ func HandleTestingTruncateAll(c *echo.Context) error {
|
||||
})
|
||||
}
|
||||
|
||||
// Reload after truncate; otherwise features enabled by a prior test outlive the now-empty license_status table.
|
||||
if err := license.ReloadFromCache(); err != nil {
|
||||
log.Errorf("Error reloading license after truncate: %v", err)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]string{
|
||||
"message": "ok",
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user