mirror of
https://github.com/fosrl/olm.git
synced 2025-12-05 19:17:35 -06:00
632 lines
15 KiB
Go
632 lines
15 KiB
Go
//go:build windows
|
|
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/fosrl/newt/logger"
|
|
"golang.org/x/sys/windows/svc"
|
|
"golang.org/x/sys/windows/svc/debug"
|
|
"golang.org/x/sys/windows/svc/eventlog"
|
|
"golang.org/x/sys/windows/svc/mgr"
|
|
)
|
|
|
|
const (
|
|
serviceName = "OlmWireguardService"
|
|
serviceDisplayName = "Olm WireGuard VPN Service"
|
|
serviceDescription = "Olm WireGuard VPN client service for secure network connectivity"
|
|
)
|
|
|
|
// Global variable to store service arguments
|
|
var serviceArgs []string
|
|
|
|
// getServiceArgsPath returns the path where service arguments are stored
|
|
func getServiceArgsPath() string {
|
|
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm")
|
|
return filepath.Join(logDir, "service_args.json")
|
|
}
|
|
|
|
// saveServiceArgs saves the service arguments to a file
|
|
func saveServiceArgs(args []string) error {
|
|
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm")
|
|
err := os.MkdirAll(logDir, 0755)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create config directory: %v", err)
|
|
}
|
|
|
|
argsPath := getServiceArgsPath()
|
|
data, err := json.Marshal(args)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal service args: %v", err)
|
|
}
|
|
|
|
err = os.WriteFile(argsPath, data, 0644)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write service args: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// loadServiceArgs loads the service arguments from a file
|
|
func loadServiceArgs() ([]string, error) {
|
|
argsPath := getServiceArgsPath()
|
|
data, err := os.ReadFile(argsPath)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return []string{}, nil // Return empty args if file doesn't exist
|
|
}
|
|
return nil, fmt.Errorf("failed to read service args: %v", err)
|
|
}
|
|
|
|
var args []string
|
|
err = json.Unmarshal(data, &args)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal service args: %v", err)
|
|
}
|
|
|
|
return args, nil
|
|
}
|
|
|
|
type olmService struct {
|
|
elog debug.Log
|
|
ctx context.Context
|
|
stop context.CancelFunc
|
|
args []string
|
|
}
|
|
|
|
func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) {
|
|
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
|
|
changes <- svc.Status{State: svc.StartPending}
|
|
|
|
s.elog.Info(1, fmt.Sprintf("Service Execute called with args: %v", args))
|
|
|
|
// Load saved service arguments
|
|
savedArgs, err := loadServiceArgs()
|
|
if err != nil {
|
|
s.elog.Error(1, fmt.Sprintf("Failed to load service args: %v", err))
|
|
// Continue with empty args if loading fails
|
|
savedArgs = []string{}
|
|
}
|
|
|
|
// Combine service start args with saved args, giving priority to service start args
|
|
finalArgs := []string{}
|
|
if len(args) > 0 {
|
|
// Skip the first arg which is typically the service name
|
|
if len(args) > 1 {
|
|
finalArgs = append(finalArgs, args[1:]...)
|
|
}
|
|
s.elog.Info(1, fmt.Sprintf("Using service start parameters: %v", finalArgs))
|
|
}
|
|
|
|
// If no service start parameters, use saved args
|
|
if len(finalArgs) == 0 && len(savedArgs) > 0 {
|
|
finalArgs = savedArgs
|
|
s.elog.Info(1, fmt.Sprintf("Using saved service args: %v", finalArgs))
|
|
}
|
|
|
|
s.args = finalArgs
|
|
|
|
// Start the main olm functionality
|
|
olmDone := make(chan struct{})
|
|
go func() {
|
|
s.runOlm()
|
|
close(olmDone)
|
|
}()
|
|
|
|
changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
|
|
s.elog.Info(1, "Service status set to Running")
|
|
|
|
for {
|
|
select {
|
|
case c := <-r:
|
|
switch c.Cmd {
|
|
case svc.Interrogate:
|
|
changes <- c.CurrentStatus
|
|
case svc.Stop, svc.Shutdown:
|
|
s.elog.Info(1, "Service stopping")
|
|
changes <- svc.Status{State: svc.StopPending}
|
|
if s.stop != nil {
|
|
s.stop()
|
|
}
|
|
// Wait for main logic to finish or timeout
|
|
select {
|
|
case <-olmDone:
|
|
s.elog.Info(1, "Main logic finished gracefully")
|
|
case <-time.After(10 * time.Second):
|
|
s.elog.Info(1, "Timeout waiting for main logic to finish")
|
|
}
|
|
return false, 0
|
|
default:
|
|
s.elog.Error(1, fmt.Sprintf("Unexpected control request #%d", c))
|
|
}
|
|
case <-olmDone:
|
|
s.elog.Info(1, "Main olm logic completed, stopping service")
|
|
changes <- svc.Status{State: svc.StopPending}
|
|
return false, 0
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *olmService) runOlm() {
|
|
// Create a context that can be cancelled when the service stops
|
|
s.ctx, s.stop = context.WithCancel(context.Background())
|
|
|
|
// Setup logging for service mode
|
|
s.elog.Info(1, "Starting Olm main logic")
|
|
|
|
// Run the main olm logic and wait for it to complete
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
s.elog.Error(1, fmt.Sprintf("Olm panic: %v", r))
|
|
}
|
|
close(done)
|
|
}()
|
|
|
|
// Call the main olm function with stored arguments
|
|
runOlmMainWithArgs(s.ctx, s.args)
|
|
}()
|
|
|
|
// Wait for either context cancellation or main logic completion
|
|
select {
|
|
case <-s.ctx.Done():
|
|
s.elog.Info(1, "Olm service context cancelled")
|
|
case <-done:
|
|
s.elog.Info(1, "Olm main logic completed")
|
|
}
|
|
}
|
|
|
|
func runService(name string, isDebug bool, args []string) {
|
|
var err error
|
|
var elog debug.Log
|
|
|
|
if isDebug {
|
|
elog = debug.New(name)
|
|
fmt.Printf("Starting %s service in debug mode\n", name)
|
|
} else {
|
|
elog, err = eventlog.Open(name)
|
|
if err != nil {
|
|
fmt.Printf("Failed to open event log: %v\n", err)
|
|
return
|
|
}
|
|
}
|
|
defer elog.Close()
|
|
|
|
elog.Info(1, fmt.Sprintf("Starting %s service", name))
|
|
run := svc.Run
|
|
if isDebug {
|
|
run = debug.Run
|
|
}
|
|
|
|
service := &olmService{elog: elog, args: args}
|
|
err = run(name, service)
|
|
if err != nil {
|
|
elog.Error(1, fmt.Sprintf("%s service failed: %v", name, err))
|
|
if isDebug {
|
|
fmt.Printf("Service failed: %v\n", err)
|
|
}
|
|
return
|
|
}
|
|
elog.Info(1, fmt.Sprintf("%s service stopped", name))
|
|
if isDebug {
|
|
fmt.Printf("%s service stopped\n", name)
|
|
}
|
|
}
|
|
|
|
func installService() error {
|
|
exepath, err := os.Executable()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get executable path: %v", err)
|
|
}
|
|
|
|
m, err := mgr.Connect()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to service manager: %v", err)
|
|
}
|
|
defer m.Disconnect()
|
|
|
|
s, err := m.OpenService(serviceName)
|
|
if err == nil {
|
|
s.Close()
|
|
return fmt.Errorf("service %s already exists", serviceName)
|
|
}
|
|
|
|
config := mgr.Config{
|
|
ServiceType: 0x10, // SERVICE_WIN32_OWN_PROCESS
|
|
StartType: mgr.StartManual,
|
|
ErrorControl: mgr.ErrorNormal,
|
|
DisplayName: serviceDisplayName,
|
|
Description: serviceDescription,
|
|
BinaryPathName: exepath,
|
|
}
|
|
|
|
s, err = m.CreateService(serviceName, exepath, config)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create service: %v", err)
|
|
}
|
|
defer s.Close()
|
|
|
|
err = eventlog.InstallAsEventCreate(serviceName, eventlog.Error|eventlog.Warning|eventlog.Info)
|
|
if err != nil {
|
|
s.Delete()
|
|
return fmt.Errorf("failed to install event log: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func removeService() error {
|
|
m, err := mgr.Connect()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to service manager: %v", err)
|
|
}
|
|
defer m.Disconnect()
|
|
|
|
s, err := m.OpenService(serviceName)
|
|
if err != nil {
|
|
return fmt.Errorf("service %s is not installed", serviceName)
|
|
}
|
|
defer s.Close()
|
|
|
|
// Stop the service if it's running
|
|
status, err := s.Query()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to query service status: %v", err)
|
|
}
|
|
|
|
if status.State != svc.Stopped {
|
|
_, err = s.Control(svc.Stop)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to stop service: %v", err)
|
|
}
|
|
|
|
// Wait for service to stop
|
|
timeout := time.Now().Add(30 * time.Second)
|
|
for status.State != svc.Stopped {
|
|
if timeout.Before(time.Now()) {
|
|
return fmt.Errorf("timeout waiting for service to stop")
|
|
}
|
|
time.Sleep(300 * time.Millisecond)
|
|
status, err = s.Query()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to query service status: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
err = s.Delete()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete service: %v", err)
|
|
}
|
|
|
|
err = eventlog.Remove(serviceName)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to remove event log: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func startService(args []string) error {
|
|
// Save the service arguments as backup
|
|
if len(args) > 0 {
|
|
err := saveServiceArgs(args)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save service args: %v", err)
|
|
}
|
|
}
|
|
|
|
m, err := mgr.Connect()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to service manager: %v", err)
|
|
}
|
|
defer m.Disconnect()
|
|
|
|
s, err := m.OpenService(serviceName)
|
|
if err != nil {
|
|
return fmt.Errorf("service %s is not installed", serviceName)
|
|
}
|
|
defer s.Close()
|
|
|
|
// Pass arguments directly to the service start call
|
|
err = s.Start(args...)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start service: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func stopService() error {
|
|
m, err := mgr.Connect()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to service manager: %v", err)
|
|
}
|
|
defer m.Disconnect()
|
|
|
|
s, err := m.OpenService(serviceName)
|
|
if err != nil {
|
|
return fmt.Errorf("service %s is not installed", serviceName)
|
|
}
|
|
defer s.Close()
|
|
|
|
status, err := s.Control(svc.Stop)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to stop service: %v", err)
|
|
}
|
|
|
|
timeout := time.Now().Add(30 * time.Second)
|
|
for status.State != svc.Stopped {
|
|
if timeout.Before(time.Now()) {
|
|
return fmt.Errorf("timeout waiting for service to stop")
|
|
}
|
|
time.Sleep(300 * time.Millisecond)
|
|
status, err = s.Query()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to query service status: %v", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func debugService(args []string) error {
|
|
// Save the service arguments before starting
|
|
if len(args) > 0 {
|
|
err := saveServiceArgs(args)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save service args: %v", err)
|
|
}
|
|
}
|
|
|
|
// Start the service with the provided arguments
|
|
err := startService(args)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start service: %v", err)
|
|
}
|
|
|
|
// Watch the log file
|
|
return watchLogFile(true)
|
|
}
|
|
|
|
func watchLogFile(end bool) error {
|
|
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs")
|
|
logPath := filepath.Join(logDir, "olm.log")
|
|
|
|
// Ensure the log directory exists
|
|
err := os.MkdirAll(logDir, 0755)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create log directory: %v", err)
|
|
}
|
|
|
|
// Wait for the log file to be created if it doesn't exist
|
|
var file *os.File
|
|
for i := 0; i < 30; i++ { // Wait up to 15 seconds
|
|
file, err = os.Open(logPath)
|
|
if err == nil {
|
|
break
|
|
}
|
|
if i == 0 {
|
|
fmt.Printf("Waiting for log file to be created...\n")
|
|
}
|
|
time.Sleep(500 * time.Millisecond)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open log file after waiting: %v", err)
|
|
}
|
|
defer file.Close()
|
|
|
|
// Seek to the end of the file to only show new logs
|
|
_, err = file.Seek(0, 2)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to seek to end of file: %v", err)
|
|
}
|
|
|
|
// Set up signal handling for graceful exit
|
|
sigCh := make(chan os.Signal, 1)
|
|
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
|
|
|
|
// Create a ticker to check for new content
|
|
ticker := time.NewTicker(500 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
|
|
buffer := make([]byte, 4096)
|
|
|
|
for {
|
|
select {
|
|
case <-sigCh:
|
|
fmt.Printf("\n\nStopping log watch...\n")
|
|
// stop the service if needed
|
|
if end {
|
|
if err := stopService(); err != nil {
|
|
fmt.Printf("Failed to stop service: %v\n", err)
|
|
}
|
|
}
|
|
fmt.Printf("Log watch stopped.\n")
|
|
return nil
|
|
case <-ticker.C:
|
|
// Read new content
|
|
n, err := file.Read(buffer)
|
|
if err != nil && err != io.EOF {
|
|
// Try to reopen the file in case it was recreated
|
|
file.Close()
|
|
file, err = os.Open(logPath)
|
|
if err != nil {
|
|
return fmt.Errorf("error reopening log file: %v", err)
|
|
}
|
|
continue
|
|
}
|
|
|
|
if n > 0 {
|
|
// Print the new content
|
|
fmt.Print(string(buffer[:n]))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func getServiceStatus() (string, error) {
|
|
m, err := mgr.Connect()
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to connect to service manager: %v", err)
|
|
}
|
|
defer m.Disconnect()
|
|
|
|
s, err := m.OpenService(serviceName)
|
|
if err != nil {
|
|
return "Not Installed", nil
|
|
}
|
|
defer s.Close()
|
|
|
|
status, err := s.Query()
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to query service status: %v", err)
|
|
}
|
|
|
|
switch status.State {
|
|
case svc.Stopped:
|
|
return "Stopped", nil
|
|
case svc.StartPending:
|
|
return "Starting", nil
|
|
case svc.StopPending:
|
|
return "Stopping", nil
|
|
case svc.Running:
|
|
return "Running", nil
|
|
case svc.ContinuePending:
|
|
return "Continue Pending", nil
|
|
case svc.PausePending:
|
|
return "Pause Pending", nil
|
|
case svc.Paused:
|
|
return "Paused", nil
|
|
default:
|
|
return "Unknown", nil
|
|
}
|
|
}
|
|
|
|
// showServiceConfig displays current saved service configuration
|
|
func showServiceConfig() {
|
|
configPath := getServiceArgsPath()
|
|
fmt.Printf("Service configuration file: %s\n", configPath)
|
|
|
|
args, err := loadServiceArgs()
|
|
if err != nil {
|
|
fmt.Printf("No saved configuration found or error loading: %v\n", err)
|
|
return
|
|
}
|
|
|
|
if len(args) == 0 {
|
|
fmt.Println("No saved service arguments found")
|
|
} else {
|
|
fmt.Printf("Saved service arguments: %v\n", args)
|
|
}
|
|
}
|
|
|
|
func isWindowsService() bool {
|
|
isWindowsService, err := svc.IsWindowsService()
|
|
return err == nil && isWindowsService
|
|
}
|
|
|
|
// rotateLogFile handles daily log rotation
|
|
func rotateLogFile(logDir string, logFile string) error {
|
|
// Get current log file info
|
|
info, err := os.Stat(logFile)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return nil // No current log file to rotate
|
|
}
|
|
return fmt.Errorf("failed to stat log file: %v", err)
|
|
}
|
|
|
|
// Check if log file is from today
|
|
now := time.Now()
|
|
fileTime := info.ModTime()
|
|
|
|
// If the log file is from today, no rotation needed
|
|
if now.Year() == fileTime.Year() && now.YearDay() == fileTime.YearDay() {
|
|
return nil
|
|
}
|
|
|
|
// Create rotated filename with date
|
|
rotatedName := fmt.Sprintf("olm-%s.log", fileTime.Format("2006-01-02"))
|
|
rotatedPath := filepath.Join(logDir, rotatedName)
|
|
|
|
// Rename current log file to dated filename
|
|
err = os.Rename(logFile, rotatedPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to rotate log file: %v", err)
|
|
}
|
|
|
|
// Clean up old log files (keep last 30 days)
|
|
cleanupOldLogFiles(logDir, 30)
|
|
|
|
return nil
|
|
}
|
|
|
|
// cleanupOldLogFiles removes log files older than specified days
|
|
func cleanupOldLogFiles(logDir string, daysToKeep int) {
|
|
cutoff := time.Now().AddDate(0, 0, -daysToKeep)
|
|
|
|
files, err := os.ReadDir(logDir)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
for _, file := range files {
|
|
if !file.IsDir() && strings.HasPrefix(file.Name(), "olm-") && strings.HasSuffix(file.Name(), ".log") {
|
|
filePath := filepath.Join(logDir, file.Name())
|
|
info, err := file.Info()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
if info.ModTime().Before(cutoff) {
|
|
os.Remove(filePath)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func setupWindowsEventLog() {
|
|
// Create log directory if it doesn't exist
|
|
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs")
|
|
err := os.MkdirAll(logDir, 0755)
|
|
if err != nil {
|
|
fmt.Printf("Failed to create log directory: %v\n", err)
|
|
return
|
|
}
|
|
|
|
logFile := filepath.Join(logDir, "olm.log")
|
|
|
|
// Rotate log file if needed
|
|
err = rotateLogFile(logDir, logFile)
|
|
if err != nil {
|
|
fmt.Printf("Failed to rotate log file: %v\n", err)
|
|
// Continue anyway to create new log file
|
|
}
|
|
|
|
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
|
if err != nil {
|
|
fmt.Printf("Failed to open log file: %v\n", err)
|
|
return
|
|
}
|
|
|
|
// Set the custom logger output
|
|
logger.GetLogger().SetOutput(file)
|
|
|
|
log.Printf("Olm service logging initialized - log file: %s", logFile)
|
|
}
|