diff --git a/healthcheck/healthcheck.go b/healthcheck/healthcheck.go index f618803..a7f0b6a 100644 --- a/healthcheck/healthcheck.go +++ b/healthcheck/healthcheck.go @@ -37,33 +37,38 @@ func (s Health) String() string { // Config holds the health check configuration for a target type Config struct { - ID int `json:"id"` - Enabled bool `json:"hcEnabled"` - Path string `json:"hcPath"` - Scheme string `json:"hcScheme"` - Mode string `json:"hcMode"` - Hostname string `json:"hcHostname"` - Port int `json:"hcPort"` - Interval int `json:"hcInterval"` // in seconds - UnhealthyInterval int `json:"hcUnhealthyInterval"` // in seconds - Timeout int `json:"hcTimeout"` // in seconds - Headers map[string]string `json:"hcHeaders"` - Method string `json:"hcMethod"` - Status int `json:"hcStatus"` // HTTP status code - TLSServerName string `json:"hcTlsServerName"` + ID int `json:"id"` + Enabled bool `json:"hcEnabled"` + Path string `json:"hcPath"` + Scheme string `json:"hcScheme"` + Mode string `json:"hcMode"` + Hostname string `json:"hcHostname"` + Port int `json:"hcPort"` + Interval int `json:"hcInterval"` // in seconds + UnhealthyInterval int `json:"hcUnhealthyInterval"` // in seconds + Timeout int `json:"hcTimeout"` // in seconds + FollowRedirects bool `json:"hcFollowRedirects"` + Headers map[string]string `json:"hcHeaders"` + Method string `json:"hcMethod"` + Status int `json:"hcStatus"` // HTTP status code + TLSServerName string `json:"hcTlsServerName"` + HealthyThreshold int `json:"hcHealthyThreshold"` // consecutive successes required to become healthy + UnhealthyThreshold int `json:"hcUnhealthyThreshold"` // consecutive failures required to become unhealthy } // Target represents a health check target with its current status type Target struct { - Config Config `json:"config"` - Status Health `json:"status"` - LastCheck time.Time `json:"lastCheck"` - LastError string `json:"lastError,omitempty"` - CheckCount int `json:"checkCount"` - timer *time.Timer - ctx context.Context - cancel context.CancelFunc - client *http.Client + Config Config `json:"config"` + Status Health `json:"status"` + LastCheck time.Time `json:"lastCheck"` + LastError string `json:"lastError,omitempty"` + CheckCount int `json:"checkCount"` + timer *time.Timer + ctx context.Context + cancel context.CancelFunc + client *http.Client + consecutiveSuccesses int + consecutiveFailures int } // StatusChangeCallback is called when any target's status changes @@ -165,9 +170,16 @@ func (m *Monitor) addTargetUnsafe(config Config) error { if config.Timeout == 0 { config.Timeout = 5 } + if config.HealthyThreshold == 0 { + config.HealthyThreshold = 1 + } + if config.UnhealthyThreshold == 0 { + config.UnhealthyThreshold = 1 + } - logger.Debug("Target %d configuration: scheme=%s, method=%s, interval=%ds, timeout=%ds", - config.ID, config.Scheme, config.Method, config.Interval, config.Timeout) + logger.Debug("Target %d configuration: mode=%s, scheme=%s, method=%s, interval=%ds, timeout=%ds, healthyThreshold=%d, unhealthyThreshold=%d", + config.ID, config.Mode, config.Scheme, config.Method, config.Interval, config.Timeout, + config.HealthyThreshold, config.UnhealthyThreshold) // Parse headers if provided as string if len(config.Headers) == 0 && config.Path != "" { @@ -189,6 +201,14 @@ func (m *Monitor) addTargetUnsafe(config Config) error { ctx: ctx, cancel: cancel, client: &http.Client{ + CheckRedirect: func() func(*http.Request, []*http.Request) error { + if !config.FollowRedirects { + return func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + } + return nil + }(), Transport: &http.Transport{ TLSClientConfig: &tls.Config{ // Configure TLS settings based on certificate enforcement @@ -361,12 +381,69 @@ func (m *Monitor) monitorTarget(target *Target) { } } -// performHealthCheck performs a health check on a target +// performHealthCheck performs a health check on a target and applies threshold logic func (m *Monitor) performHealthCheck(target *Target) { target.CheckCount++ target.LastCheck = time.Now() - target.LastError = "" + var passed bool + var checkErr string + + switch strings.ToLower(target.Config.Mode) { + case "tcp": + passed, checkErr = m.performTCPCheck(target) + default: + // "http", "https", or anything else falls through to HTTP + passed, checkErr = m.performHTTPCheck(target) + } + + if passed { + target.consecutiveFailures = 0 + target.consecutiveSuccesses++ + + logger.Debug("Target %d: check passed (consecutive successes: %d / threshold: %d)", + target.Config.ID, target.consecutiveSuccesses, target.Config.HealthyThreshold) + + if target.consecutiveSuccesses >= target.Config.HealthyThreshold { + target.Status = StatusHealthy + target.LastError = "" + } + } else { + target.consecutiveSuccesses = 0 + target.consecutiveFailures++ + target.LastError = checkErr + + logger.Debug("Target %d: check failed (consecutive failures: %d / threshold: %d): %s", + target.Config.ID, target.consecutiveFailures, target.Config.UnhealthyThreshold, checkErr) + + if target.consecutiveFailures >= target.Config.UnhealthyThreshold { + target.Status = StatusUnhealthy + } + } +} + +// performTCPCheck dials the target's host:port over TCP and returns whether it succeeded +func (m *Monitor) performTCPCheck(target *Target) (bool, string) { + address := net.JoinHostPort(target.Config.Hostname, strconv.Itoa(target.Config.Port)) + timeout := time.Duration(target.Config.Timeout) * time.Second + + logger.Debug("Target %d: performing TCP health check to %s (timeout: %v)", + target.Config.ID, address, timeout) + + conn, err := net.DialTimeout("tcp", address, timeout) + if err != nil { + msg := fmt.Sprintf("TCP dial failed: %v", err) + logger.Warn("Target %d: %s", target.Config.ID, msg) + return false, msg + } + conn.Close() + + logger.Debug("Target %d: TCP health check passed", target.Config.ID) + return true, "" +} + +// performHTTPCheck performs an HTTP/HTTPS health check and returns whether it succeeded +func (m *Monitor) performHTTPCheck(target *Target) (bool, string) { // Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports) host := target.Config.Hostname if target.Config.Port > 0 { @@ -380,7 +457,7 @@ func (m *Monitor) performHealthCheck(target *Target) { url += target.Config.Path } - logger.Debug("Target %d: performing health check %d to %s", + logger.Debug("Target %d: performing HTTP health check %d to %s", target.Config.ID, target.CheckCount, url) if target.Config.Scheme == "https" { @@ -388,16 +465,15 @@ func (m *Monitor) performHealthCheck(target *Target) { target.Config.ID, m.enforceCert) } - // Create request + // Create request with timeout context ctx, cancel := context.WithTimeout(context.Background(), time.Duration(target.Config.Timeout)*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, target.Config.Method, url, nil) if err != nil { - target.Status = StatusUnhealthy - target.LastError = fmt.Sprintf("failed to create request: %v", err) - logger.Warn("Target %d: failed to create request: %v", target.Config.ID, err) - return + msg := fmt.Sprintf("failed to create request: %v", err) + logger.Warn("Target %d: %s", target.Config.ID, msg) + return false, msg } // Add headers @@ -413,43 +489,34 @@ func (m *Monitor) performHealthCheck(target *Target) { // Perform request resp, err := target.client.Do(req) if err != nil { - target.Status = StatusUnhealthy - target.LastError = fmt.Sprintf("request failed: %v", err) + msg := fmt.Sprintf("request failed: %v", err) logger.Warn("Target %d: health check failed: %v", target.Config.ID, err) - return + return false, msg } defer resp.Body.Close() // Check response status - var expectedStatus int if target.Config.Status > 0 { - expectedStatus = target.Config.Status - } else { - expectedStatus = 0 // Use range check for 200-299 + // Check for specific status code + logger.Debug("Target %d: checking status against expected code %d", target.Config.ID, target.Config.Status) + if resp.StatusCode == target.Config.Status { + logger.Debug("Target %d: health check passed (status: %d)", target.Config.ID, resp.StatusCode) + return true, "" + } + msg := fmt.Sprintf("unexpected status code: %d (expected: %d)", resp.StatusCode, target.Config.Status) + logger.Warn("Target %d: %s", target.Config.ID, msg) + return false, msg } - if expectedStatus > 0 { - logger.Debug("Target %d: checking health status against expected code %d", target.Config.ID, expectedStatus) - // Check for specific status code - if resp.StatusCode == expectedStatus { - target.Status = StatusHealthy - logger.Debug("Target %d: health check passed (status: %d, expected: %d)", target.Config.ID, resp.StatusCode, expectedStatus) - } else { - target.Status = StatusUnhealthy - target.LastError = fmt.Sprintf("unexpected status code: %d (expected: %d)", resp.StatusCode, expectedStatus) - logger.Warn("Target %d: health check failed with status code %d (expected: %d)", target.Config.ID, resp.StatusCode, expectedStatus) - } - } else { - // Check for 2xx range - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - target.Status = StatusHealthy - logger.Debug("Target %d: health check passed (status: %d)", target.Config.ID, resp.StatusCode) - } else { - target.Status = StatusUnhealthy - target.LastError = fmt.Sprintf("unhealthy status code: %d", resp.StatusCode) - logger.Warn("Target %d: health check failed with status code %d", target.Config.ID, resp.StatusCode) - } + // Default: check for 2xx range + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + logger.Debug("Target %d: health check passed (status: %d)", target.Config.ID, resp.StatusCode) + return true, "" } + + msg := fmt.Sprintf("unhealthy status code: %d", resp.StatusCode) + logger.Warn("Target %d: health check failed with status code %d", target.Config.ID, resp.StatusCode) + return false, msg } // Stop stops monitoring all targets