mirror of
https://github.com/go-vikunja/vikunja.git
synced 2026-03-11 17:48:44 -05:00
534 lines
17 KiB
Go
534 lines
17 KiB
Go
// Vikunja is a to-do list application to facilitate your life.
|
|
// Copyright 2018-present Vikunja and contributors. All rights reserved.
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Affero General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU Affero General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Affero General Public License
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package db
|
|
|
|
import (
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"code.vikunja.io/api/pkg/config"
|
|
"code.vikunja.io/api/pkg/log"
|
|
|
|
"xorm.io/builder"
|
|
"xorm.io/xorm"
|
|
"xorm.io/xorm/names"
|
|
"xorm.io/xorm/schemas"
|
|
|
|
_ "github.com/go-sql-driver/mysql" // Because.
|
|
_ "github.com/lib/pq" // Because.
|
|
_ "github.com/mattn/go-sqlite3" // Because.
|
|
)
|
|
|
|
var (
|
|
// We only want one instance of the engine, so we can create it once and reuse it
|
|
x *xorm.Engine
|
|
// paradedbInstalled marks whether the paradedb extension is available
|
|
// and can be used for full text search.
|
|
paradedbInstalled bool
|
|
)
|
|
|
|
// registeredTables accumulates every table bean handed to RegisterTables.
// It starts out as a nil slice.
var registeredTables []interface{}

// RegisterTables appends the given table beans to the set of known Vikunja
// tables; beans from successive calls accumulate in call order. The set
// exists so that Dump and WipeEverything can limit themselves to tables
// that belong to Vikunja.
func RegisterTables(tables []interface{}) {
	registeredTables = append(registeredTables, tables...)
}
|
|
|
|
// RegisteredTableNames returns the database table names of all beans passed
// to RegisterTables, followed by "migration" (the xormigrate bookkeeping
// table, which is never registered as a bean).
//
// A bean implementing TableName() string is reported under that name, the
// same way xorm resolves it when creating or querying the table. Every other
// bean falls back to the engine's table mapper applied to its struct type
// name. The shared engine x must already be initialised (see CreateDBEngine).
func RegisteredTableNames() []string {
	mapper := x.GetTableMapper()
	tableNames := make([]string, 0, len(registeredTables)+1)
	for _, bean := range registeredTables {
		// Without this, a bean whose TableName() returns e.g. "tasks" would be
		// reported under the mapper-derived name of its type (e.g. "task"),
		// which does not match the table that actually exists.
		if named, ok := bean.(interface{ TableName() string }); ok {
			tableNames = append(tableNames, named.TableName())
			continue
		}
		tableNames = append(tableNames, mapper.Obj2Table(reflect.Indirect(reflect.ValueOf(bean)).Type().Name()))
	}
	// The xormigrate migration tracking table is not registered via GetTables().
	tableNames = append(tableNames, "migration")
	return tableNames
}
|
|
|
|
// CreateDBEngine returns the shared xorm engine, building it from the config
// on first use. The configured database type ("mysql", "postgres" or
// "sqlite") picks the initialiser; any other type is fatal. A freshly built
// engine gets Vikunja's time zone, GMT as the database time zone, the Gonic
// name mapper and the database logger, and is then stored for reuse.
func CreateDBEngine() (engine *xorm.Engine, err error) {
	// An engine built by an earlier call is simply handed back.
	if x != nil {
		return x, nil
	}

	// An unset database type most likely means the config was not loaded yet.
	if config.DatabaseType.GetString() == "" {
		config.InitConfig()
	}

	dbType := config.DatabaseType.GetString()
	switch dbType {
	case "mysql":
		engine, err = initMysqlEngine()
	case "postgres":
		engine, err = initPostgresEngine()
	case "sqlite":
		engine, err = initSqliteEngine()
	default:
		log.Fatalf("Unknown database type %s", dbType)
	}
	if err != nil {
		return engine, err
	}

	// Values are presented in Vikunja's configured zone but stored as GMT.
	engine.SetTZLocation(config.GetTimeZone())
	dbLocation, err := time.LoadLocation("GMT")
	if err != nil {
		log.Fatalf("Error parsing time zone: %s", err)
	}
	engine.SetTZDatabase(dbLocation)

	engine.SetMapper(names.GonicMapper{})
	engine.SetLogger(log.NewXormLogger(
		config.LogEnabled.GetBool(),
		config.LogDatabase.GetString(),
		config.LogDatabaseLevel.GetString(),
		config.LogFormat.GetString(),
	))

	x = engine
	return engine, err
}
|
|
|
|
// initMysqlEngine opens an xorm engine for the configured MySQL/MariaDB
// database and applies the connection pool settings from the config.
//
// A database host starting with "/" is used as a unix socket path; any other
// value is used as a TCP address. An empty host yields "tcp()" and leaves the
// address to the MySQL driver's defaults (NOTE(review): driver behaviour, not
// shown here).
func initMysqlEngine() (engine *xorm.Engine, err error) {
	dbHost := config.DatabaseHost.GetString()

	// HasPrefix instead of indexing dbHost[0]: an unset host must not cause an
	// index-out-of-range panic.
	host := fmt.Sprintf("tcp(%s)", dbHost)
	if strings.HasPrefix(dbHost, "/") { // looks like a unix socket
		host = fmt.Sprintf("unix(%s)", dbHost)
	}

	// We're using utf8mb4 here instead of just utf8 because we want to use non-BMP characters.
	// See https://stackoverflow.com/a/30074553/10924593 for more info.
	connStr := fmt.Sprintf(
		"%s:%s@%s/%s?charset=utf8mb4&parseTime=true&tls=%s",
		config.DatabaseUser.GetString(),
		config.DatabasePassword.GetString(),
		host,
		config.DatabaseDatabase.GetString(),
		config.DatabaseTLS.GetString())
	engine, err = xorm.NewEngine("mysql", connStr)
	if err != nil {
		return
	}
	engine.SetMaxOpenConns(config.DatabaseMaxOpenConnections.GetInt())
	engine.SetMaxIdleConns(config.DatabaseMaxIdleConnections.GetInt())

	// The configured connection lifetime is given in milliseconds.
	maxLifetime, err := time.ParseDuration(strconv.Itoa(config.DatabaseMaxConnectionLifetime.GetInt()) + `ms`)
	if err != nil {
		return
	}
	engine.SetConnMaxLifetime(maxLifetime)
	return
}
|
|
|
|
// parsePostgreSQLHostPort splits a PostgreSQL host specification into host
// and port. It accepts the forms described in
// https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING:
//
//   - "host" or "/socket/dir"            (port defaults to 5432)
//   - "host:port" or "/socket/dir:port"
//   - "[ipv6]" or "[ipv6]:port"          (the brackets stay part of host)
//
// Any missing part falls back to the defaults 127.0.0.1 and 5432, so neither
// returned value is ever empty (e.g. ":5433" -> 127.0.0.1, 5433 and
// "db:" -> db, 5432).
func parsePostgreSQLHostPort(info string) (string, string) {
	host, port := "127.0.0.1", "5432"

	// A trailing "]" marks a bracketed IPv6 address without a port; its inner
	// colons must not be mistaken for a port separator.
	if idx := strings.LastIndex(info, ":"); idx >= 0 && !strings.HasSuffix(info, "]") {
		if h := info[:idx]; h != "" {
			host = h
		}
		if p := info[idx+1:]; p != "" {
			port = p
		}
	} else if info != "" {
		host = info
	}

	return host, port
}
|
|
|
|
// getPostgreSQLConnectionString builds a postgres:// connection URL from the
// given settings. Copied and adapted from
// https://github.com/go-gitea/gitea/blob/f337c32e868381c6d2d948221aca0c59f8420c13/modules/setting/database.go#L176-L186
//
// dbName may already contain query parameters ("name?opt=val"); the SSL
// options are then appended with "&" instead of "?". User and password are
// path-escaped. A host starting with "/" is treated as a unix socket directory:
// it is passed in the host query parameter and the URL authority carries only
// the port.
func getPostgreSQLConnectionString(dbHost, dbUser, dbPasswd, dbName, dbSslMode, dbSslCert, dbSslKey, dbSslRootCert string) (connStr string) {
	dbParam := "?"
	if strings.Contains(dbName, dbParam) {
		dbParam = "&"
	}
	host, port := parsePostgreSQLHostPort(dbHost)
	// HasPrefix instead of host[0], so an empty host cannot cause an
	// index-out-of-range panic.
	if strings.HasPrefix(host, "/") { // looks like a unix socket
		connStr = fmt.Sprintf("postgres://%s:%s@:%s/%s%ssslmode=%s&sslcert=%s&sslkey=%s&sslrootcert=%s&host=%s",
			url.PathEscape(dbUser), url.PathEscape(dbPasswd), port, dbName, dbParam, dbSslMode, dbSslCert, dbSslKey, dbSslRootCert, host)
	} else {
		connStr = fmt.Sprintf("postgres://%s:%s@%s:%s/%s%ssslmode=%s&sslcert=%s&sslkey=%s&sslrootcert=%s",
			url.PathEscape(dbUser), url.PathEscape(dbPasswd), host, port, dbName, dbParam, dbSslMode, dbSslCert, dbSslKey, dbSslRootCert)
	}
	return connStr
}
|
|
|
|
// initPostgresEngine opens an xorm engine for the configured PostgreSQL
// database. It selects the configured schema and applies the connection pool
// settings, then probes the server for the ParadeDB extension.
func initPostgresEngine() (engine *xorm.Engine, err error) {
	engine, err = xorm.NewEngine("postgres", getPostgreSQLConnectionString(
		config.DatabaseHost.GetString(),
		config.DatabaseUser.GetString(),
		config.DatabasePassword.GetString(),
		config.DatabaseDatabase.GetString(),
		config.DatabaseSslMode.GetString(),
		config.DatabaseSslCert.GetString(),
		config.DatabaseSslKey.GetString(),
		config.DatabaseSslRootCert.GetString(),
	))
	if err != nil {
		return engine, err
	}

	engine.SetSchema(config.DatabaseSchema.GetString())
	engine.SetMaxOpenConns(config.DatabaseMaxOpenConnections.GetInt())
	engine.SetMaxIdleConns(config.DatabaseMaxIdleConnections.GetInt())

	// The configured connection lifetime is a number of milliseconds.
	lifetimeMillis := strconv.Itoa(config.DatabaseMaxConnectionLifetime.GetInt())
	maxLifetime, err := time.ParseDuration(lifetimeMillis + `ms`)
	if err != nil {
		return engine, err
	}
	engine.SetConnMaxLifetime(maxLifetime)

	checkParadeDB(engine)
	return engine, nil
}
|
|
|
|
// DatabasePathConfig holds the inputs needed to resolve the SQLite database
// path. Keeping them in a plain struct lets resolveDatabasePath be tested
// independently of the global config package.
type DatabasePathConfig struct {
	ConfiguredPath string // The database.path config value; "memory" selects an ephemeral database
	RootPath       string // The service.rootpath config value
	ExecutablePath string // Directory of the executable binary
}
|
|
|
|
// resolveDatabasePath resolves a database path configuration to an absolute path.
|
|
//
|
|
// Resolution rules:
|
|
// 1. If ConfiguredPath is "memory", returns "memory" (special case for in-memory DB)
|
|
// 2. If ConfiguredPath is already absolute, returns it as-is (cleaned)
|
|
// 3. If ConfiguredPath is relative:
|
|
// a. If RootPath differs from ExecutablePath (explicitly configured),
|
|
// joins with RootPath
|
|
// b. Otherwise, joins with platform-specific user data directory
|
|
//
|
|
// The getUserDataDir parameter allows injecting a mock for testing.
|
|
func resolveDatabasePath(cfg DatabasePathConfig, getUserDataDir func() (string, error)) (string, error) {
|
|
if cfg.ConfiguredPath == "memory" {
|
|
return "memory", nil
|
|
}
|
|
|
|
var path string
|
|
|
|
switch {
|
|
case filepath.IsAbs(cfg.ConfiguredPath):
|
|
path = filepath.Clean(cfg.ConfiguredPath)
|
|
case cfg.RootPath != cfg.ExecutablePath:
|
|
path = filepath.Join(cfg.RootPath, cfg.ConfiguredPath)
|
|
default:
|
|
dataDir, err := getUserDataDir()
|
|
if err != nil {
|
|
log.Warningf("Could not get user data directory, falling back to rootpath: %v", err)
|
|
path = filepath.Join(cfg.RootPath, cfg.ConfiguredPath)
|
|
} else {
|
|
path = filepath.Join(dataDir, cfg.ConfiguredPath)
|
|
}
|
|
}
|
|
|
|
return filepath.Abs(path)
|
|
}
|
|
|
|
func initSqliteEngine() (engine *xorm.Engine, err error) {
|
|
rootPath := config.ServiceRootpath.GetString()
|
|
|
|
executablePath := rootPath
|
|
if execPath, err := os.Executable(); err == nil {
|
|
executablePath = filepath.Dir(execPath)
|
|
}
|
|
|
|
cfg := DatabasePathConfig{
|
|
ConfiguredPath: config.DatabasePath.GetString(),
|
|
RootPath: rootPath,
|
|
ExecutablePath: executablePath,
|
|
}
|
|
|
|
path, err := resolveDatabasePath(cfg, getUserDataDir)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not resolve database path: %w", err)
|
|
}
|
|
|
|
if path == "memory" {
|
|
// Use a temp file with WAL mode instead of in-memory shared cache.
|
|
// Shared cache (file::memory:?cache=shared) uses table-level locking
|
|
// where _busy_timeout is ineffective (returns SQLITE_LOCKED, not
|
|
// SQLITE_BUSY) and concurrent connections deadlock. A temp file with
|
|
// WAL mode provides proper concurrency: readers never block writers,
|
|
// and _busy_timeout handles write-write contention.
|
|
tmpDir, mkErr := os.MkdirTemp("", "vikunja-*")
|
|
if mkErr != nil {
|
|
return nil, fmt.Errorf("could not create temp directory for ephemeral database: %w", mkErr)
|
|
}
|
|
dbPath := filepath.Join(tmpDir, "vikunja.db")
|
|
engine, err = xorm.NewEngine("sqlite3", dbPath+"?_busy_timeout=5000&_journal_mode=WAL")
|
|
if err != nil {
|
|
return
|
|
}
|
|
log.Infof("Using ephemeral SQLite database at: %s", dbPath)
|
|
return
|
|
}
|
|
|
|
// Log the resolved database path
|
|
log.Infof("Using SQLite database at: %s", path)
|
|
|
|
// Warn if the database is in a potentially problematic location
|
|
if isSystemDirectory(path) {
|
|
log.Warningf("Database path (%s) appears to be in a system directory. This may cause issues. Please use an absolute path or configure the database path to a user data directory.", path)
|
|
}
|
|
|
|
// Try opening the db file to return a better error message if that does not work
|
|
var exists = true
|
|
if _, err := os.Stat(path); err != nil {
|
|
exists = !os.IsNotExist(err)
|
|
}
|
|
file, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not open database file [uid=%d, gid=%d]: %w", os.Getuid(), os.Getgid(), err)
|
|
}
|
|
_ = file.Close() // We directly close the file because we only want to check if it is writable. It will be reopened lazily later by xorm.
|
|
|
|
if !exists {
|
|
_ = os.Remove(path) // Remove the file to not prevent the db from creating another one
|
|
}
|
|
|
|
// WAL mode allows concurrent readers alongside a single writer without
|
|
// blocking each other. busy_timeout makes concurrent writers wait (up to
|
|
// 5 s) instead of failing immediately with SQLITE_BUSY.
|
|
engine, err = xorm.NewEngine("sqlite3", path+"?_busy_timeout=5000&_journal_mode=WAL")
|
|
return
|
|
}
|
|
|
|
// getUserDataDir returns the platform-appropriate directory for Vikunja's
// application data. It creates the directory (mode 0700) if needed.
//
//   - Windows: %LOCALAPPDATA%\Vikunja, falling back to
//     %USERPROFILE%\AppData\Local\Vikunja when LOCALAPPDATA is unset
//   - macOS:   ~/Library/Application Support/Vikunja
//   - others:  $XDG_DATA_HOME/vikunja, or ~/.local/share/vikunja when
//     XDG_DATA_HOME is unset, empty or not an absolute path
//
// An error is returned when no base directory can be determined or the
// directory cannot be created.
func getUserDataDir() (string, error) {
	var dataDir string

	switch runtime.GOOS {
	case "windows":
		localAppData := os.Getenv("LOCALAPPDATA")
		if localAppData == "" {
			// Fallback to %USERPROFILE%\AppData\Local if LOCALAPPDATA is not set
			userProfile := os.Getenv("USERPROFILE")
			if userProfile == "" {
				return "", fmt.Errorf("neither LOCALAPPDATA nor USERPROFILE environment variables are set")
			}
			localAppData = filepath.Join(userProfile, "AppData", "Local")
		}
		dataDir = filepath.Join(localAppData, "Vikunja")
	case "darwin":
		home, err := os.UserHomeDir()
		if err != nil {
			return "", err
		}
		dataDir = filepath.Join(home, "Library", "Application Support", "Vikunja")
	default:
		// The XDG Base Directory spec requires absolute paths and says relative
		// values must be ignored; honouring one would place the database
		// relative to whatever the current working directory happens to be.
		if xdgDataHome := os.Getenv("XDG_DATA_HOME"); filepath.IsAbs(xdgDataHome) {
			dataDir = filepath.Join(xdgDataHome, "vikunja")
		} else {
			home, err := os.UserHomeDir()
			if err != nil {
				return "", err
			}
			dataDir = filepath.Join(home, ".local", "share", "vikunja")
		}
	}

	// Ensure the directory exists.
	if err := os.MkdirAll(dataDir, 0o700); err != nil {
		return "", fmt.Errorf("could not create data directory %s: %w", dataDir, err)
	}

	return dataDir, nil
}
|
|
|
|
// isSystemDirectory reports whether path lies inside a well-known operating
// system directory where application data should not normally be stored.
// Matching is case-insensitive and on whole path components, so
// "/etc/vikunja.db" is inside "/etc" while "/etcetera/vikunja.db" is not.
//
// On Windows, everything at or below C:\Windows counts, except C:\Windows\Temp
// and its subdirectories. On every platform the Unix directories /bin, /sbin,
// /usr/bin, /usr/sbin, /etc, /sys, /proc and /dev are checked as well.
func isSystemDirectory(path string) bool {
	path = filepath.Clean(path)
	lowerPath := strings.ToLower(path)

	// isUnder reports whether p equals dir or is a descendant of it, using sep
	// as the separator so that only whole path components match.
	isUnder := func(p, dir, sep string) bool {
		return p == dir || strings.HasPrefix(p, dir+sep)
	}

	if runtime.GOOS == "windows" {
		// Convert to an absolute path if possible for more accurate checking.
		absPath := lowerPath
		if abs, err := filepath.Abs(path); err == nil {
			absPath = strings.ToLower(filepath.Clean(abs))
		}

		// Everything below C:\Windows (System32, SysWOW64, WinSxS, servicing, ...)
		// is a system location except the temp directory. Component-wise
		// matching keeps C:\Windows\Temporary from passing as Temp and
		// C:\myapp\windows from being mistaken for C:\Windows.
		if isUnder(absPath, `c:\windows`, `\`) {
			return !isUnder(absPath, `c:\windows\temp`, `\`)
		}
	}

	// Unix-like system directories.
	for _, sysDir := range []string{
		"/bin", "/sbin", "/usr/bin", "/usr/sbin",
		"/etc", "/sys", "/proc", "/dev",
	} {
		if isUnder(lowerPath, sysDir, "/") {
			return true
		}
	}

	return false
}
|
|
|
|
// WipeEverything wipes all tables and their data. Use with caution...
|
|
func WipeEverything() error {
|
|
|
|
tables, err := x.DBMetas()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, t := range tables {
|
|
if err := x.DropTables(t.Name); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// NewSession creates a new xorm session with an active transaction.
|
|
// The caller must call s.Commit() on success or s.Rollback() on error.
|
|
// s.Close() will auto-rollback any uncommitted transaction.
|
|
func NewSession() *xorm.Session {
|
|
s := x.NewSession()
|
|
if err := s.Begin(); err != nil {
|
|
log.Fatalf("Failed to begin database transaction: %s", err)
|
|
}
|
|
return s
|
|
}
|
|
|
|
// Type returns the db type of the currently configured db
|
|
func Type() schemas.DBType {
|
|
return x.Dialect().URI().DBType
|
|
}
|
|
|
|
func GetDialect() string {
|
|
switch config.DatabaseType.GetString() {
|
|
case "mysql":
|
|
return builder.MYSQL
|
|
case "postgres":
|
|
return builder.POSTGRES
|
|
default:
|
|
return builder.SQLITE
|
|
}
|
|
}
|
|
|
|
func checkParadeDB(engine *xorm.Engine) {
|
|
if engine.Dialect().URI().DBType != schemas.POSTGRES {
|
|
return
|
|
}
|
|
|
|
exists := false
|
|
if _, err := engine.SQL("SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname='pg_search')").Get(&exists); err != nil {
|
|
log.Errorf("could not check for paradedb extension: %v", err)
|
|
return
|
|
}
|
|
|
|
if !exists {
|
|
return
|
|
}
|
|
|
|
paradedbInstalled = true
|
|
log.Debug("ParadeDB extension detected, using @@@ search operator")
|
|
}
|
|
|
|
func CreateParadeDBIndexes() error {
|
|
if !paradedbInstalled {
|
|
return nil
|
|
}
|
|
// ParadeDB only allows one bm25 index per table, so we create a single index covering both fields
|
|
// Use optimized configuration with fast fields and field boosting for better performance
|
|
indexSQL := `CREATE INDEX IF NOT EXISTS idx_tasks_paradedb ON tasks USING bm25 (id, title, description, project_id, done)
|
|
WITH (
|
|
key_field='id',
|
|
text_fields='{
|
|
"title": {"fast": true, "record": "freq"},
|
|
"description": {"fast": true, "record": "freq"}
|
|
}',
|
|
numeric_fields='{
|
|
"project_id": {"fast": true}
|
|
}',
|
|
boolean_fields='{
|
|
"done": {"fast": true}
|
|
}'
|
|
)`
|
|
if _, err := x.Exec(indexSQL); err != nil {
|
|
return fmt.Errorf("could not ensure paradedb task index: %w", err)
|
|
}
|
|
|
|
// Create ParadeDB index for projects table
|
|
projectIndexSQL := `CREATE INDEX IF NOT EXISTS idx_projects_paradedb ON projects USING bm25 (id, title, description, identifier)
|
|
WITH (
|
|
key_field='id',
|
|
text_fields='{
|
|
"title": {"fast": true, "record": "freq"},
|
|
"description": {"fast": true, "record": "freq"},
|
|
"identifier": {"fast": true, "record": "freq"}
|
|
}'
|
|
)`
|
|
if _, err := x.Exec(projectIndexSQL); err != nil {
|
|
return fmt.Errorf("could not ensure paradedb project index: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|