// Vikunja is a to-do list application to facilitate your life. // Copyright 2018-present Vikunja and contributors. All rights reserved. // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package db import ( "fmt" "net/url" "os" "path/filepath" "reflect" "runtime" "strconv" "strings" "time" "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/log" "xorm.io/builder" "xorm.io/xorm" "xorm.io/xorm/names" "xorm.io/xorm/schemas" _ "github.com/go-sql-driver/mysql" // Because. _ "github.com/lib/pq" // Because. _ "github.com/mattn/go-sqlite3" // Because. ) var ( // We only want one instance of the engine, so we can create it once and reuse it x *xorm.Engine // paradedbInstalled marks whether the paradedb extension is available // and can be used for full text search. paradedbInstalled bool ) // registeredTables holds all table beans registered by Vikunja packages. var registeredTables []interface{} // RegisterTables registers table beans so that Dump and WipeEverything // only operate on known Vikunja tables. func RegisterTables(tables []interface{}) { registeredTables = append(registeredTables, tables...) } // RegisteredTableNames returns the table names of all registered Vikunja tables. func RegisteredTableNames() []string { mapper := x.GetTableMapper() names := make([]string, 0, len(registeredTables)+1) for _, bean := range registeredTables { names = append(names, mapper.Obj2Table(reflect.Indirect(reflect.ValueOf(bean)).Type().Name())) } // The xormigrate migration tracking table is not registered via GetTables() names = append(names, "migration") return names } // CreateDBEngine initializes a db engine from the config func CreateDBEngine() (engine *xorm.Engine, err error) { if x != nil { return x, nil } // If the database type is not set, this likely means we need to initialize the config first if config.DatabaseType.GetString() == "" { config.InitConfig() } // Use Mysql if set switch config.DatabaseType.GetString() { case "mysql": engine, err = initMysqlEngine() if err != nil { return } case "postgres": engine, err = initPostgresEngine() if err != nil { return } case "sqlite": // Otherwise use sqlite engine, err = initSqliteEngine() if err != nil { return } default: log.Fatalf("Unknown database type %s", config.DatabaseType.GetString()) } engine.SetTZLocation(config.GetTimeZone()) // Vikunja's timezone loc, err := time.LoadLocation("GMT") // The db data timezone if err != nil { log.Fatalf("Error parsing time zone: %s", err) } engine.SetTZDatabase(loc) engine.SetMapper(names.GonicMapper{}) logger := log.NewXormLogger(config.LogEnabled.GetBool(), config.LogDatabase.GetString(), config.LogDatabaseLevel.GetString(), config.LogFormat.GetString()) engine.SetLogger(logger) x = engine return } func initMysqlEngine() (engine *xorm.Engine, err error) { // We're using utf8mb here instead of just utf8 because we want to use non-BMP characters. // See https://stackoverflow.com/a/30074553/10924593 for more info. host := fmt.Sprintf("tcp(%s)", config.DatabaseHost.GetString()) if config.DatabaseHost.GetString()[0] == '/' { // looks like a unix socket host = fmt.Sprintf("unix(%s)", config.DatabaseHost.GetString()) } connStr := fmt.Sprintf( "%s:%s@%s/%s?charset=utf8mb4&parseTime=true&tls=%s", config.DatabaseUser.GetString(), config.DatabasePassword.GetString(), host, config.DatabaseDatabase.GetString(), config.DatabaseTLS.GetString()) engine, err = xorm.NewEngine("mysql", connStr) if err != nil { return } engine.SetMaxOpenConns(config.DatabaseMaxOpenConnections.GetInt()) engine.SetMaxIdleConns(config.DatabaseMaxIdleConnections.GetInt()) maxLifetime, err := time.ParseDuration(strconv.Itoa(config.DatabaseMaxConnectionLifetime.GetInt()) + `ms`) if err != nil { return } engine.SetConnMaxLifetime(maxLifetime) return } // parsePostgreSQLHostPort parses given input in various forms defined in // https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING // and returns proper host and port number. func parsePostgreSQLHostPort(info string) (string, string) { host, port := "127.0.0.1", "5432" if strings.Contains(info, ":") && !strings.HasSuffix(info, "]") { idx := strings.LastIndex(info, ":") host = info[:idx] port = info[idx+1:] } else if len(info) > 0 { host = info } return host, port } // Copied and adopted from https://github.com/go-gitea/gitea/blob/f337c32e868381c6d2d948221aca0c59f8420c13/modules/setting/database.go#L176-L186 func getPostgreSQLConnectionString(dbHost, dbUser, dbPasswd, dbName, dbSslMode, dbSslCert, dbSslKey, dbSslRootCert string) (connStr string) { dbParam := "?" if strings.Contains(dbName, dbParam) { dbParam = "&" } host, port := parsePostgreSQLHostPort(dbHost) if host[0] == '/' { // looks like a unix socket connStr = fmt.Sprintf("postgres://%s:%s@:%s/%s%ssslmode=%s&sslcert=%s&sslkey=%s&sslrootcert=%s&host=%s", url.PathEscape(dbUser), url.PathEscape(dbPasswd), port, dbName, dbParam, dbSslMode, dbSslCert, dbSslKey, dbSslRootCert, host) } else { connStr = fmt.Sprintf("postgres://%s:%s@%s:%s/%s%ssslmode=%s&sslcert=%s&sslkey=%s&sslrootcert=%s", url.PathEscape(dbUser), url.PathEscape(dbPasswd), host, port, dbName, dbParam, dbSslMode, dbSslCert, dbSslKey, dbSslRootCert) } return connStr } func initPostgresEngine() (engine *xorm.Engine, err error) { connStr := getPostgreSQLConnectionString( config.DatabaseHost.GetString(), config.DatabaseUser.GetString(), config.DatabasePassword.GetString(), config.DatabaseDatabase.GetString(), config.DatabaseSslMode.GetString(), config.DatabaseSslCert.GetString(), config.DatabaseSslKey.GetString(), config.DatabaseSslRootCert.GetString(), ) engine, err = xorm.NewEngine("postgres", connStr) if err != nil { return } engine.SetSchema(config.DatabaseSchema.GetString()) engine.SetMaxOpenConns(config.DatabaseMaxOpenConnections.GetInt()) engine.SetMaxIdleConns(config.DatabaseMaxIdleConnections.GetInt()) maxLifetime, err := time.ParseDuration(strconv.Itoa(config.DatabaseMaxConnectionLifetime.GetInt()) + `ms`) if err != nil { return } engine.SetConnMaxLifetime(maxLifetime) checkParadeDB(engine) return } // DatabasePathConfig holds configuration for database path resolution. // This struct allows the path resolution logic to be tested independently // of the global config package. type DatabasePathConfig struct { ConfiguredPath string // The database.path config value RootPath string // The service.rootpath config value ExecutablePath string // Directory of the executable binary } // resolveDatabasePath resolves a database path configuration to an absolute path. // // Resolution rules: // 1. If ConfiguredPath is "memory", returns "memory" (special case for in-memory DB) // 2. If ConfiguredPath is already absolute, returns it as-is (cleaned) // 3. If ConfiguredPath is relative: // a. If RootPath differs from ExecutablePath (explicitly configured), // joins with RootPath // b. Otherwise, joins with platform-specific user data directory // // The getUserDataDir parameter allows injecting a mock for testing. func resolveDatabasePath(cfg DatabasePathConfig, getUserDataDir func() (string, error)) (string, error) { if cfg.ConfiguredPath == "memory" { return "memory", nil } var path string switch { case filepath.IsAbs(cfg.ConfiguredPath): path = filepath.Clean(cfg.ConfiguredPath) case cfg.RootPath != cfg.ExecutablePath: path = filepath.Join(cfg.RootPath, cfg.ConfiguredPath) default: dataDir, err := getUserDataDir() if err != nil { log.Warningf("Could not get user data directory, falling back to rootpath: %v", err) path = filepath.Join(cfg.RootPath, cfg.ConfiguredPath) } else { path = filepath.Join(dataDir, cfg.ConfiguredPath) } } return filepath.Abs(path) } func initSqliteEngine() (engine *xorm.Engine, err error) { rootPath := config.ServiceRootpath.GetString() executablePath := rootPath if execPath, err := os.Executable(); err == nil { executablePath = filepath.Dir(execPath) } cfg := DatabasePathConfig{ ConfiguredPath: config.DatabasePath.GetString(), RootPath: rootPath, ExecutablePath: executablePath, } path, err := resolveDatabasePath(cfg, getUserDataDir) if err != nil { return nil, fmt.Errorf("could not resolve database path: %w", err) } if path == "memory" { // Use a temp file with WAL mode instead of in-memory shared cache. // Shared cache (file::memory:?cache=shared) uses table-level locking // where _busy_timeout is ineffective (returns SQLITE_LOCKED, not // SQLITE_BUSY) and concurrent connections deadlock. A temp file with // WAL mode provides proper concurrency: readers never block writers, // and _busy_timeout handles write-write contention. tmpDir, mkErr := os.MkdirTemp("", "vikunja-*") if mkErr != nil { return nil, fmt.Errorf("could not create temp directory for ephemeral database: %w", mkErr) } dbPath := filepath.Join(tmpDir, "vikunja.db") engine, err = xorm.NewEngine("sqlite3", dbPath+"?_busy_timeout=5000&_journal_mode=WAL") if err != nil { return } log.Infof("Using ephemeral SQLite database at: %s", dbPath) return } // Log the resolved database path log.Infof("Using SQLite database at: %s", path) // Warn if the database is in a potentially problematic location if isSystemDirectory(path) { log.Warningf("Database path (%s) appears to be in a system directory. This may cause issues. Please use an absolute path or configure the database path to a user data directory.", path) } // Try opening the db file to return a better error message if that does not work var exists = true if _, err := os.Stat(path); err != nil { exists = !os.IsNotExist(err) } file, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0) if err != nil { return nil, fmt.Errorf("could not open database file [uid=%d, gid=%d]: %w", os.Getuid(), os.Getgid(), err) } _ = file.Close() // We directly close the file because we only want to check if it is writable. It will be reopened lazily later by xorm. if !exists { _ = os.Remove(path) // Remove the file to not prevent the db from creating another one } // WAL mode allows concurrent readers alongside a single writer without // blocking each other. busy_timeout makes concurrent writers wait (up to // 5 s) instead of failing immediately with SQLITE_BUSY. engine, err = xorm.NewEngine("sqlite3", path+"?_busy_timeout=5000&_journal_mode=WAL") return } // getUserDataDir returns the platform-appropriate directory for application data func getUserDataDir() (string, error) { var dataDir string switch runtime.GOOS { case "windows": // On Windows, use %LOCALAPPDATA%\Vikunja localAppData := os.Getenv("LOCALAPPDATA") if localAppData == "" { // Fallback to %USERPROFILE%\AppData\Local if LOCALAPPDATA is not set userProfile := os.Getenv("USERPROFILE") if userProfile == "" { return "", fmt.Errorf("neither LOCALAPPDATA nor USERPROFILE environment variables are set") } localAppData = filepath.Join(userProfile, "AppData", "Local") } dataDir = filepath.Join(localAppData, "Vikunja") case "darwin": // On macOS, use ~/Library/Application Support/Vikunja home, err := os.UserHomeDir() if err != nil { return "", err } dataDir = filepath.Join(home, "Library", "Application Support", "Vikunja") default: // On Linux and other Unix-like systems, use XDG_DATA_HOME or ~/.local/share/vikunja xdgDataHome := os.Getenv("XDG_DATA_HOME") if xdgDataHome != "" { dataDir = filepath.Join(xdgDataHome, "vikunja") } else { home, err := os.UserHomeDir() if err != nil { return "", err } dataDir = filepath.Join(home, ".local", "share", "vikunja") } } // Ensure the directory exists if err := os.MkdirAll(dataDir, 0o700); err != nil { return "", fmt.Errorf("could not create data directory %s: %w", dataDir, err) } return dataDir, nil } // isSystemDirectory checks if a path appears to be in a system directory // where users should not typically store application data func isSystemDirectory(path string) bool { // Clean and normalize the path path = filepath.Clean(path) lowerPath := strings.ToLower(path) // Windows system directories if runtime.GOOS == "windows" { // Convert to absolute path if possible for more accurate checking absPath := lowerPath if abs, err := filepath.Abs(path); err == nil { absPath = strings.ToLower(filepath.Clean(abs)) } // Check common Windows system directories using prefix matching // This prevents false positives like C:\myapp\windows\data windowsSystemPrefixes := []string{ "c:\\windows\\system32", "c:\\windows\\syswow64", "c:\\windows\\winsxs", "c:\\windows\\servicing", } for _, prefix := range windowsSystemPrefixes { if strings.HasPrefix(absPath, prefix) { return true } } // Also check for direct C:\Windows (not subdirectories like C:\myapp\windows) // by ensuring it starts with the drive and windows directory if absPath == "c:\\windows" || strings.HasPrefix(absPath, "c:\\windows\\") { // Exclude some safe subdirectories under C:\Windows safeDirs := []string{ "c:\\windows\\temp", } for _, safeDir := range safeDirs { if strings.HasPrefix(absPath, safeDir) { return false } } return true } } // Unix-like system directories - use prefix matching systemDirs := []string{ "/bin", "/sbin", "/usr/bin", "/usr/sbin", "/etc", "/sys", "/proc", "/dev", } for _, sysDir := range systemDirs { // Ensure we match exact directory boundaries if lowerPath == sysDir || strings.HasPrefix(lowerPath, sysDir+"/") { return true } } return false } // WipeEverything wipes all tables and their data. Use with caution... func WipeEverything() error { tables, err := x.DBMetas() if err != nil { return err } for _, t := range tables { if err := x.DropTables(t.Name); err != nil { return err } } return nil } // NewSession creates a new xorm session with an active transaction. // The caller must call s.Commit() on success or s.Rollback() on error. // s.Close() will auto-rollback any uncommitted transaction. func NewSession() *xorm.Session { s := x.NewSession() if err := s.Begin(); err != nil { log.Fatalf("Failed to begin database transaction: %s", err) } return s } // Type returns the db type of the currently configured db func Type() schemas.DBType { return x.Dialect().URI().DBType } func GetDialect() string { switch config.DatabaseType.GetString() { case "mysql": return builder.MYSQL case "postgres": return builder.POSTGRES default: return builder.SQLITE } } func checkParadeDB(engine *xorm.Engine) { if engine.Dialect().URI().DBType != schemas.POSTGRES { return } exists := false if _, err := engine.SQL("SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname='pg_search')").Get(&exists); err != nil { log.Errorf("could not check for paradedb extension: %v", err) return } if !exists { return } paradedbInstalled = true log.Debug("ParadeDB extension detected, using @@@ search operator") } func CreateParadeDBIndexes() error { if !paradedbInstalled { return nil } // ParadeDB only allows one bm25 index per table, so we create a single index covering both fields // Use optimized configuration with fast fields and field boosting for better performance indexSQL := `CREATE INDEX IF NOT EXISTS idx_tasks_paradedb ON tasks USING bm25 (id, title, description, project_id, done) WITH ( key_field='id', text_fields='{ "title": {"fast": true, "record": "freq"}, "description": {"fast": true, "record": "freq"} }', numeric_fields='{ "project_id": {"fast": true} }', boolean_fields='{ "done": {"fast": true} }' )` if _, err := x.Exec(indexSQL); err != nil { return fmt.Errorf("could not ensure paradedb task index: %w", err) } // Create ParadeDB index for projects table projectIndexSQL := `CREATE INDEX IF NOT EXISTS idx_projects_paradedb ON projects USING bm25 (id, title, description, identifier) WITH ( key_field='id', text_fields='{ "title": {"fast": true, "record": "freq"}, "description": {"fast": true, "record": "freq"}, "identifier": {"fast": true, "record": "freq"} }' )` if _, err := x.Exec(projectIndexSQL); err != nil { return fmt.Errorf("could not ensure paradedb project index: %w", err) } return nil }