f2b/fail2ban/helpers.go
Ismo Vuorinen 605f2b9580 refactor: linting, simplification and fixes (#119)
* refactor: consolidate test helpers and reduce code duplication

- Fix prealloc lint issue in cmd_logswatch_test.go
- Add validateIPAndJails helper to consolidate IP/jail validation
- Add WithTestRunner/WithTestSudoChecker helpers for cleaner test setup
- Replace setupBasicMockResponses duplicates with StandardMockSetup
- Add SetupStandardResponses/SetupJailResponses to MockRunner
- Delegate cmd context helpers to fail2ban implementations
- Document context wrapper pattern in context_helpers.go

* refactor: consolidate duplicate code patterns across cmd and fail2ban packages

Add helper functions to reduce code duplication found by dupl:

- safeCloseFile/safeCloseReader: centralize file close error logging
- createTimeoutContext: consolidate timeout context creation pattern
- withContextCheck: wrap context cancellation checks
- recordOperationMetrics: unify metrics recording for commands/clients

Also includes Phase 1 consolidations:
- copyBuckets helper for metrics snapshots
- Table-driven context extraction in logging
- processWithValidation helper for IP processors

* refactor: consolidate LoggerInterface by embedding LoggerEntry

Both interfaces had identical method signatures. LoggerInterface now
embeds LoggerEntry to eliminate code duplication.
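
As a rough illustration of the embedding described above, the sketch below uses a reduced, hypothetical method set (Info and WithField only) and a hypothetical memoryLogger type. The real interfaces in this repository have more methods and may differ in signature.

```go
package main

import "fmt"

// LoggerEntry is a reduced, hypothetical version of the shared method set.
type LoggerEntry interface {
	Info(args ...interface{})
	WithField(key string, value interface{}) LoggerEntry
}

// LoggerInterface embeds LoggerEntry instead of redeclaring its methods.
type LoggerInterface interface {
	LoggerEntry
}

// memoryLogger is a hypothetical implementation that records messages.
type memoryLogger struct {
	prefix string
	lines  *[]string
}

func (l memoryLogger) Info(args ...interface{}) {
	*l.lines = append(*l.lines, l.prefix+fmt.Sprint(args...))
}

func (l memoryLogger) WithField(key string, value interface{}) LoggerEntry {
	return memoryLogger{prefix: l.prefix + fmt.Sprintf("%s=%v ", key, value), lines: l.lines}
}

// logVia accepts the wider interface; any LoggerEntry implementation satisfies it.
func logVia(l LoggerInterface, msg string) {
	l.WithField("jail", "sshd").Info(msg)
}

// demo returns the lines recorded by a single logVia call.
func demo() []string {
	var lines []string
	logVia(memoryLogger{lines: &lines}, "banned")
	return lines
}

func main() {
	fmt.Println(demo())
}
```

Because the embedded interface contributes its whole method set, adding a method to LoggerEntry automatically extends LoggerInterface as well.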

* refactor: consolidate test framework helpers and fix test patterns

- Add checkJSONFieldValue and failMissingJSONField helpers to reduce
  duplication in JSON assertion methods
- Add ParallelTimeout to default test config
- Fix test to use WithTestRunner inside test loop for proper mock scoping

* refactor: unify ban/unban operations with OperationType pattern

Introduce OperationType struct to consolidate duplicate ban/unban logic:
- Add ProcessOperation and ProcessOperationWithContext generic functions
- Add ProcessOperationParallel and ProcessOperationParallelWithContext
- Existing ProcessBan*/ProcessUnban* functions now delegate to generic versions
- Reduces ~120 lines of duplicate code between ban and unban operations
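
One way to picture the OperationType pattern is sketched below. The client type, the struct fields (Message, Run), and the result format are assumptions made for illustration and are not taken from the repository.

```go
package main

import "fmt"

// client is a hypothetical stand-in for the fail2ban client.
type client struct{ banned map[string]bool }

func newClient() *client { return &client{banned: map[string]bool{}} }

func (c *client) BanIP(ip, jail string) error {
	c.banned[jail+"/"+ip] = true
	return nil
}

func (c *client) UnbanIP(ip, jail string) error {
	delete(c.banned, jail+"/"+ip)
	return nil
}

// OperationType captures the only things that differ between ban and unban.
type OperationType struct {
	Message string
	Run     func(c *client, ip, jail string) error
}

var (
	banOperation   = OperationType{Message: "banned", Run: (*client).BanIP}
	unbanOperation = OperationType{Message: "unbanned", Run: (*client).UnbanIP}
)

// ProcessOperation runs op for ip in every jail and reports one line per jail.
func ProcessOperation(c *client, op OperationType, ip string, jails []string) []string {
	results := make([]string, 0, len(jails))
	for _, jail := range jails {
		status := op.Message
		if err := op.Run(c, ip, jail); err != nil {
			status = "error: " + err.Error()
		}
		results = append(results, fmt.Sprintf("%s in %s: %s", ip, jail, status))
	}
	return results
}

// ProcessBan and ProcessUnban stay as thin wrappers over the generic function.
func ProcessBan(c *client, ip string, jails []string) []string {
	return ProcessOperation(c, banOperation, ip, jails)
}

func ProcessUnban(c *client, ip string, jails []string) []string {
	return ProcessOperation(c, unbanOperation, ip, jails)
}

func main() {
	c := newClient()
	fmt.Println(ProcessBan(c, "192.0.2.1", []string{"sshd"}))
	fmt.Println(ProcessUnban(c, "192.0.2.1", []string{"sshd"}))
}
```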

* refactor: consolidate time parsing cache pattern

Add ParseWithLayout method to BoundedTimeCache that consolidates the
cache-lookup-parse-store pattern. FastTimeCache and TimeParsingCache
now delegate to this method instead of duplicating the logic.
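
A minimal sketch of the cache-lookup-parse-store pattern follows. The struct layout, the misses counter, and the clear-everything eviction are simplifications assumed for this example; the real BoundedTimeCache may be organized differently.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// BoundedTimeCache here is a simplified, hypothetical stand-in.
type BoundedTimeCache struct {
	mu      sync.Mutex
	entries map[string]time.Time
	maxSize int
	misses  int // counts real parses, for demonstration only
}

func NewBoundedTimeCache(maxSize int) *BoundedTimeCache {
	return &BoundedTimeCache{entries: make(map[string]time.Time, maxSize), maxSize: maxSize}
}

// ParseWithLayout returns a cached result when available; otherwise it
// parses the value, stores the result, and returns it.
func (c *BoundedTimeCache) ParseWithLayout(layout, value string) (time.Time, error) {
	key := layout + "\x00" + value

	c.mu.Lock()
	defer c.mu.Unlock()

	if t, ok := c.entries[key]; ok {
		return t, nil // cache hit
	}

	c.misses++
	t, err := time.Parse(layout, value)
	if err != nil {
		return time.Time{}, err // failures are not cached
	}

	if len(c.entries) >= c.maxSize {
		// Crude bound for the sketch: drop everything once full.
		c.entries = make(map[string]time.Time, c.maxSize)
	}
	c.entries[key] = t
	return t, nil
}

func main() {
	cache := NewBoundedTimeCache(128)
	for i := 0; i < 2; i++ {
		t, err := cache.ParseWithLayout("2006-01-02 15:04:05", "2026-01-25 19:07:45")
		fmt.Println(t.Year(), err, cache.misses)
	}
}
```

Callers such as FastTimeCache and TimeParsingCache would then only choose the layout and forward to this one method.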

* refactor: consolidate command execution patterns in fail2ban

- Add validateCommandExecution helper for command/argument validation
- Add runWithTimerContext helper for timed runner operations
- Add executeIPActionWithContext to unify BanIP/UnbanIP implementations
- Reduces duplicate validation and execution boilerplate

* refactor: consolidate logrus adapter with embedded loggerCore

Introduce loggerCore type that provides the 8 standard logging methods
(Debug, Info, Warn, Error, Debugf, Infof, Warnf, Errorf). Both
logrusAdapter and logrusEntryAdapter now embed this type, eliminating
16 duplicate method implementations.

* refactor: consolidate path validation patterns

- Add validateConfigPathWithFallback helper in cmd/config_utils.go
  for the validate-or-fallback-with-logging pattern
- Add validateClientPath helper in fail2ban/helpers.go for client
  path validation delegation

* fix: add context cancellation checks to wrapper functions

- wrapWithContext0/1/2 now check ctx.Err() before invoking wrapped function
- WithCommand now validates and trims empty command strings
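
The cancellation check can be sketched as below, assuming a generic one-argument wrapper. The name wrapWithContext1 is taken from the message above, but its signature here is a guess made for illustration.

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// wrapWithContext1 adapts a one-argument function to accept a context and
// skips the call entirely when the context is already done.
func wrapWithContext1[A, R any](fn func(A) (R, error)) func(context.Context, A) (R, error) {
	return func(ctx context.Context, arg A) (R, error) {
		if err := ctx.Err(); err != nil {
			var zero R
			return zero, err
		}
		return fn(arg)
	}
}

// runCancelled calls a wrapped function with an already-cancelled context.
// It reports whether the inner function ran and whether the returned error
// is context.Canceled.
func runCancelled() (called, cancelled bool) {
	inner := func(jail string) (int, error) {
		called = true
		return len(jail), nil
	}
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	_, err := wrapWithContext1(inner)(ctx, "sshd")
	return called, errors.Is(err, context.Canceled)
}

// runLive calls a wrapped function with a live context.
func runLive() (int, error) {
	inner := func(jail string) (int, error) { return len(jail), nil }
	return wrapWithContext1(inner)(context.Background(), "sshd")
}

func main() {
	fmt.Println(runCancelled())
	fmt.Println(runLive())
}
```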

* refactor: extract formatLatencyBuckets for deterministic metrics output

Add formatLatencyBuckets helper that writes latency bucket distribution
with sorted keys for deterministic output, eliminating duplicate
formatting code for command and client latency buckets.
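
Sorting the keys before printing is what makes the output deterministic, since Go randomizes map iteration order. A sketch, with an assumed output format and assumed bucket names:

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// formatLatencyBuckets writes a title line followed by one line per bucket,
// in sorted key order so repeated runs produce identical output.
func formatLatencyBuckets(sb *strings.Builder, title string, buckets map[string]int64) {
	keys := make([]string, 0, len(buckets))
	for k := range buckets {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	fmt.Fprintf(sb, "%s\n", title)
	for _, k := range keys {
		fmt.Fprintf(sb, "  %s: %d\n", k, buckets[k])
	}
}

// render formats both bucket sets with the same helper.
func render(commands, clients map[string]int64) string {
	var sb strings.Builder
	formatLatencyBuckets(&sb, "command latency", commands)
	formatLatencyBuckets(&sb, "client latency", clients)
	return sb.String()
}

func main() {
	fmt.Print(render(
		map[string]int64{"le_100ms": 7, "le_010ms": 3},
		map[string]int64{"le_010ms": 1},
	))
}
```

The order is lexicographic, so bucket names need zero padding (or a custom comparison) if they are meant to sort numerically.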

* refactor: add generic setNestedMapValue helper for mock configuration

Add setNestedMapValue[T] generic helper that consolidates the repeated
pattern of mutex-protected nested map initialization and value setting
used by SetBanError, SetBanResult, SetUnbanError, and SetUnbanResult.
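
A sketch of such a generic helper follows. The mockClient type, its field names, and the jail/IP key order are hypothetical and chosen only to show the helper being shared by two setters.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// setNestedMapValue locks mu, creates the inner map on first use, and stores
// value under m[outer][inner]. The outer map must already be initialized.
func setNestedMapValue[T any](mu *sync.Mutex, m map[string]map[string]T, outer, inner string, value T) {
	mu.Lock()
	defer mu.Unlock()
	if m[outer] == nil {
		m[outer] = make(map[string]T)
	}
	m[outer][inner] = value
}

// mockClient is a hypothetical mock with per-jail, per-IP canned responses.
type mockClient struct {
	mu         sync.Mutex
	banResults map[string]map[string]int
	banErrors  map[string]map[string]error
}

func newMockClient() *mockClient {
	return &mockClient{
		banResults: make(map[string]map[string]int),
		banErrors:  make(map[string]map[string]error),
	}
}

func (m *mockClient) SetBanResult(jail, ip string, result int) {
	setNestedMapValue(&m.mu, m.banResults, jail, ip, result)
}

func (m *mockClient) SetBanError(jail, ip string, err error) {
	setNestedMapValue(&m.mu, m.banErrors, jail, ip, err)
}

func main() {
	m := newMockClient()
	m.SetBanResult("sshd", "192.0.2.1", 1)
	m.SetBanError("sshd", "192.0.2.2", errors.New("boom"))
	fmt.Println(m.banResults["sshd"]["192.0.2.1"], m.banErrors["sshd"]["192.0.2.2"])
}
```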

* fix: use cmd.Context() for signal propagation and correct mock status

- ExecuteIPCommand now uses cmd.Context() instead of context.Background()
  to inherit Cobra's signal cancellation
- MockRunner.SetupJailResponses uses shared.Fail2BanStatusSuccess ("0")
  instead of literal "1" for proper success path simulation

* fix: restore operation-specific log messages in ProcessOperationWithContext

Add back Logger.WithFields().Info(opType.Message) call that was lost
during refactoring. This restores the distinction between ban and unban
operation messages (shared.MsgBanResult vs shared.MsgUnbanResult).

* fix: return aggregated errors from parallel operations

Previously, errors from individual parallel operations were silently
swallowed - converted to status strings but never returned to callers.

Now processOperations collects all errors and returns them aggregated
via errors.Join, allowing callers to distinguish partial failures from
complete success while still receiving all results.
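
The aggregation can be sketched as follows, assuming Go 1.20 or newer for errors.Join. The function signature, the "ok"/"failed" status strings, and the failOn helper are illustrative and not the repository's actual code.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

var errNoSuchJail = errors.New("no such jail")

// processOperations runs op for every jail in parallel. It always returns a
// result per jail, plus all failures combined with errors.Join (nil if none).
func processOperations(jails []string, op func(jail string) error) ([]string, error) {
	results := make([]string, len(jails))
	errs := make([]error, len(jails))

	var wg sync.WaitGroup
	for i, jail := range jails {
		wg.Add(1)
		go func(i int, jail string) {
			defer wg.Done()
			if err := op(jail); err != nil {
				results[i] = "failed"
				errs[i] = fmt.Errorf("%s: %w", jail, err)
				return
			}
			results[i] = "ok"
		}(i, jail)
	}
	wg.Wait()

	// errors.Join drops nil entries and returns nil when every entry is nil.
	return results, errors.Join(errs...)
}

// failOn returns an operation that fails only for the named jail.
func failOn(bad string) func(string) error {
	return func(jail string) error {
		if jail == bad {
			return errNoSuchJail
		}
		return nil
	}
}

func main() {
	results, err := processOperations([]string{"sshd", "nginx"}, failOn("nginx"))
	fmt.Println(results, err)
	fmt.Println(errors.Is(err, errNoSuchJail))
}
```

Each goroutine writes only to its own index, so the slices need no extra locking, and errors.Is still matches the wrapped cause through the joined error.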

* fix: add input validation to processOperations before parallel execution

Validate IP and jail inputs at the start of processOperations() using
fail2ban.CachedValidateIP and CachedValidateJail. This prevents invalid
or malicious inputs (empty values, path traversal attempts, malformed
IPs) from reaching the operation functions. All validation errors are
aggregated and returned before any operations execute.
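
The validate-first flow might look like the sketch below. validateInputs is a simplified stand-in for the CachedValidateIP and CachedValidateJail calls named above, and the sequential loop stands in for the parallel execution; both are assumptions made to keep the example short.

```go
package main

import (
	"errors"
	"fmt"
	"net"
	"strings"
)

// validateInputs collects every validation failure instead of stopping at
// the first one. Error text deliberately avoids echoing the raw input.
func validateInputs(ip string, jails []string) error {
	var errs []error
	if net.ParseIP(ip) == nil {
		errs = append(errs, errors.New("invalid IP address"))
	}
	for i, jail := range jails {
		if jail == "" || strings.Contains(jail, "..") || strings.ContainsAny(jail, "/\\") {
			errs = append(errs, fmt.Errorf("jail %d: invalid name", i))
		}
	}
	return errors.Join(errs...)
}

// processOperations validates all inputs up front and runs nothing when any
// input is rejected. It returns the number of operations that succeeded.
func processOperations(ip string, jails []string, op func(ip, jail string) error) (int, error) {
	if err := validateInputs(ip, jails); err != nil {
		return 0, err
	}
	succeeded := 0
	for _, jail := range jails {
		if err := op(ip, jail); err == nil {
			succeeded++
		}
	}
	return succeeded, nil
}

func main() {
	noop := func(ip, jail string) error { return nil }
	fmt.Println(processOperations("not-an-ip", []string{"sshd", "../etc"}, noop))
	fmt.Println(processOperations("192.0.2.1", []string{"sshd"}, noop))
}
```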
2026-01-25 19:07:45 +02:00

841 lines
25 KiB
Go

package fail2ban
import (
"context"
"fmt"
"net"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"unicode"
"github.com/hashicorp/go-version"
"github.com/ivuorinen/f2b/shared"
)
func init() {
// Configure logging for CI/test environments to reduce noise
// This now comes from the logging_env module
}
// Validation constants
// Validation helpers
// ValidateIP validates an IP address string and returns an error if invalid
func ValidateIP(ip string) error {
if ip == "" {
return ErrIPRequiredError
}
// Check for valid IPv4 or IPv6 address
parsed := net.ParseIP(ip)
if parsed == nil {
// Don't include potentially malicious input in error message
if containsCommandInjectionPatterns(ip) || len(ip) > shared.MaxIPAddressLength {
return fmt.Errorf("invalid IP address format")
}
return NewInvalidIPError(ip)
}
return nil
}
// ValidateJail validates a jail name and returns an error if invalid
func ValidateJail(jail string) error {
if jail == "" {
return ErrJailRequiredError
}
// Jail names should be reasonable length
if len(jail) > shared.MaxJailNameLength {
// Don't include potentially malicious input in error message
if containsCommandInjectionPatterns(jail) {
return fmt.Errorf(shared.ErrInvalidJailFormat)
}
return NewInvalidJailError(jail + " (too long)")
}
// First character should be alphanumeric (jail is known to be non-empty here).
// Decode the first rune rather than the first byte so multi-byte characters are classified correctly.
first := []rune(jail)[0]
if !unicode.IsLetter(first) && !unicode.IsDigit(first) {
// Don't include potentially malicious input in error message
if containsCommandInjectionPatterns(jail) {
return fmt.Errorf(shared.ErrInvalidJailFormat)
}
return NewInvalidJailError(jail + " (invalid format)")
}
// Rest can be alphanumeric, dash, underscore, or dot
for _, r := range jail {
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '-' && r != '_' && r != '.' {
// Don't include potentially malicious input in error message
if containsCommandInjectionPatterns(jail) {
return fmt.Errorf(shared.ErrInvalidJailFormat)
}
return NewInvalidJailError(jail + " (invalid character)")
}
}
return nil
}
// ValidateFilter validates a filter name and returns an error if invalid
func ValidateFilter(filter string) error {
if filter == "" {
return ErrFilterRequiredError
}
// Check length limits to bound the size of untrusted input
if len(filter) > shared.MaxFilterNameLength {
return NewInvalidFilterError(filter + " (too long)")
}
// Check for null bytes
if strings.Contains(filter, "\x00") {
return NewInvalidFilterError(filter + " (contains null bytes)")
}
// Enhanced path traversal detection
if ContainsPathTraversal(filter) {
return NewInvalidFilterError(filter + " (path traversal)")
}
// Check for command injection patterns (defense in depth)
if containsCommandInjectionPatterns(filter) {
return NewInvalidFilterError(filter + " (injection patterns)")
}
// Character validation - only allow safe characters
for _, r := range filter {
if !isValidFilterChar(r) {
return NewInvalidFilterError(filter + " (invalid characters)")
}
}
// Additional validation: ensure filter doesn't start/end with dangerous patterns
if strings.HasPrefix(filter, ".") || strings.HasSuffix(filter, ".") {
// Allow single extension like ".conf" but not ".." or "..."
if strings.Contains(filter, "..") {
return NewInvalidFilterError(filter + " (invalid dot patterns)")
}
}
return nil
}
// ValidateJailExists checks if a jail exists in the given list
func ValidateJailExists(jail string, jails []string) error {
for _, j := range jails {
if j == jail {
return nil
}
}
return NewJailNotFoundError(jail)
}
// Command execution helpers
// Parsing helpers
// ParseJailList parses the jail list output from fail2ban-client status
func ParseJailList(output string) ([]string, error) {
// Optimized: Find "Jail list:" position directly instead of splitting all lines
jailListPos := strings.Index(output, "Jail list:")
if jailListPos == -1 {
return nil, fmt.Errorf(shared.ErrFailedToParseJails)
}
// Find the start of the jail list content (after "Jail list:")
colonPos := strings.Index(output[jailListPos:], ":")
if colonPos == -1 {
return nil, fmt.Errorf(shared.ErrFailedToParseJails)
}
// Find the end of the line
start := jailListPos + colonPos + 1
end := strings.Index(output[start:], "\n")
if end == -1 {
end = len(output) - start
}
jailList := strings.TrimSpace(output[start : start+end])
if jailList == "" {
return []string{}, nil // Return empty list for no jails
}
// Treat commas as whitespace so strings.Fields can split the list
if strings.Contains(jailList, ",") {
jailList = strings.ReplaceAll(jailList, ",", " ")
}
return strings.Fields(jailList), nil
}
// ParseBracketedList parses bracketed output like "[jail1, jail2]"
func ParseBracketedList(output string) []string {
// Strip one pair of surrounding brackets, if both are present
s := output
if len(s) >= 2 && s[0] == '[' && s[len(s)-1] == ']' {
s = s[1 : len(s)-1]
}
if s == "" {
return []string{}
}
// Optimized: Remove quotes first, then split to avoid multiple string operations
if strings.Contains(s, "\"") {
s = strings.ReplaceAll(s, "\"", "")
}
parts := strings.Split(s, ",")
// Optimized: Trim in-place to avoid additional allocations
for i, part := range parts {
parts[i] = strings.TrimSpace(part)
}
return parts
}
// Utility helpers
// Patterns used by ExtractFail2BanVersion to locate the version number
var (
fail2banVersionPattern = regexp.MustCompile(`(?i)fail2ban(?:-client)?[\s-]*v?([0-9]+(?:\.[0-9]+)*)(?:[-+].*)?`)
versionNumberPattern = regexp.MustCompile(`^v?([0-9]+(?:\.[0-9]+)*)(?:[-+].*)?$`)
)
// CompareVersions compares two version strings
func CompareVersions(v1, v2 string) int {
version1, err1 := version.NewVersion(v1)
version2, err2 := version.NewVersion(v2)
// If either version is invalid, fall back to string comparison
if err1 != nil || err2 != nil {
return strings.Compare(v1, v2)
}
return version1.Compare(version2)
}
// ExtractFail2BanVersion extracts the semantic version from fail2ban-client -V output
func ExtractFail2BanVersion(output string) (string, error) {
trimmed := strings.TrimSpace(output)
if trimmed == "" {
return "", fmt.Errorf("empty version output")
}
if match := fail2banVersionPattern.FindStringSubmatch(trimmed); len(match) == 2 {
return match[1], nil
}
if match := versionNumberPattern.FindStringSubmatch(trimmed); len(match) == 2 {
return match[1], nil
}
return "", fmt.Errorf("unable to parse version from %q", trimmed)
}
// FormatDuration formats a duration in seconds as a zero-padded DD:HH:MM:SS string
func FormatDuration(sec int64) string {
days := sec / shared.SecondsPerDay
h := (sec % shared.SecondsPerDay) / shared.SecondsPerHour
m := (sec % shared.SecondsPerHour) / shared.SecondsPerMinute
s := sec % shared.SecondsPerMinute
return fmt.Sprintf("%02d:%02d:%02d:%02d", days, h, m, s)
}
// ValidateCommand validates that a command is in the allowlist for security
func ValidateCommand(command string) error {
// Allowlist of commands that f2b is permitted to execute
allowedCommands := map[string]bool{
shared.Fail2BanClientCommand: true,
shared.Fail2BanRegexCommand: true,
shared.Fail2BanServerCommand: true,
"service": true,
"systemctl": true,
"sudo": true, // Only when used internally
}
if command == "" {
return NewInvalidCommandError("command cannot be empty")
}
// Check for null bytes (command injection attempt)
if strings.ContainsRune(command, '\x00') {
// Don't include potentially malicious input in error message
return fmt.Errorf(shared.ErrInvalidCommandFormat)
}
// Check for dangerous patterns first (before including command in error messages)
dangerousPatterns := GetDangerousCommandPatterns()
cmdLower := strings.ToLower(command)
for _, pattern := range dangerousPatterns {
if strings.Contains(cmdLower, strings.ToLower(pattern)) {
// Don't include potentially dangerous command in error message
return fmt.Errorf(shared.ErrInvalidCommandFormat)
}
}
// Check for path traversal in command name
if ContainsPathTraversal(command) {
// Note: unlike the checks above, this error echoes the command back to the caller
return NewInvalidCommandError(command + " (path traversal)")
}
// Additional security checks for command injection patterns
if containsCommandInjectionPatterns(command) {
// Don't include potentially malicious input in error message
return fmt.Errorf(shared.ErrInvalidCommandFormat)
}
// Command must be a bare executable name (no paths or whitespace)
if strings.ContainsAny(command, "/\\ \t") {
return fmt.Errorf(shared.ErrInvalidCommandFormat)
}
// Validate against allowlist (safe to include command name for allowed commands)
if !allowedCommands[command] {
return NewCommandNotAllowedError(command)
}
return nil
}
// ValidateArguments validates command arguments for security
func ValidateArguments(args []string) error {
return ValidateArgumentsWithContext(context.Background(), args)
}
// ValidateArgumentsWithContext validates command arguments for security with context support
func ValidateArgumentsWithContext(ctx context.Context, args []string) error {
for i, arg := range args {
if err := validateSingleArgument(ctx, arg, i); err != nil {
return fmt.Errorf("argument %d invalid: %w", i, err)
}
}
return nil
}
// validateSingleArgument validates a single command argument
func validateSingleArgument(ctx context.Context, arg string, _ int) error {
// Check for null bytes
if strings.ContainsRune(arg, '\x00') {
return NewInvalidArgumentError(arg + " (contains null byte)")
}
// Check length to bound the size of untrusted input
if len(arg) > shared.MaxArgumentLength {
return NewInvalidArgumentError(fmt.Sprintf("%s (too long: %d chars)", arg, len(arg)))
}
// Check for command injection patterns
if containsCommandInjectionPatterns(arg) {
return NewInvalidArgumentError(arg + " (injection patterns)")
}
// For IP arguments, validate IP format
if isLikelyIPArgument(arg) {
if err := CachedValidateIP(ctx, arg); err != nil {
return fmt.Errorf("invalid IP format: %w", err)
}
}
return nil
}
// containsCommandInjectionPatterns detects common command injection patterns
func containsCommandInjectionPatterns(input string) bool {
// Optimized: Check single characters first (fastest)
for _, r := range input {
switch r {
case ';', '&', '|', '`', '$', '<', '>', '\n', '\r', '\t':
return true
}
}
// Optimized: Convert to lower case only once and check multi-character patterns
inputLower := strings.ToLower(input)
// Multi-character patterns - be specific to avoid false positives
multiCharPatterns := []string{
"$(", "${", "&&", "||", ">>", "<<",
"exec ", "system(", "eval(",
}
for _, pattern := range multiCharPatterns {
if strings.Contains(inputLower, pattern) {
return true
}
}
return false
}
// isLikelyIPArgument heuristically determines if an argument looks like an IP address
func isLikelyIPArgument(arg string) bool {
// Simple heuristic: contains dots and digits
return strings.Contains(arg, ".") && strings.ContainsAny(arg, "0123456789")
}
// Internal helper functions
// isValidFilterChar checks if a character is allowed in filter names
func isValidFilterChar(r rune) bool {
// Allow letters, digits, and safe punctuation
return unicode.IsLetter(r) ||
unicode.IsDigit(r) ||
r == '-' ||
r == '_' ||
r == '.' ||
r == '@' || // Allow @ for email-like patterns
r == '+' || // Allow + for variations
r == '~' // Allow ~ for common naming
}
// Timing infrastructure for performance monitoring
// TimedOperation represents a timed operation with metadata
type TimedOperation struct {
Name string
Command string
Args []string
StartTime time.Time
}
// NewTimedOperation creates a new timed operation and starts timing
func NewTimedOperation(name, command string, args ...string) *TimedOperation {
return &TimedOperation{
Name: name,
Command: command,
Args: args,
StartTime: time.Now(),
}
}
// Finish completes the timed operation and logs the duration using the package logger
func (t *TimedOperation) Finish(err error) {
duration := time.Since(t.StartTime)
fields := Fields{
"operation": t.Name,
"command": t.Command,
"duration": duration,
"args": strings.Join(t.Args, " "),
}
if err != nil {
getLogger().WithFields(fields).
WithField(shared.LogFieldError, err.Error()).
Warnf(shared.ErrOperationFailed, duration)
} else if duration > time.Second {
// Log slow operations as warnings for visibility
getLogger().WithFields(fields).Warnf(shared.ErrSlowOperation, duration)
} else {
// Log fast operations at debug level to reduce noise
getLogger().WithFields(fields).Debugf(shared.MsgOperationCompleted, duration)
}
}
// FinishWithContext completes the timed operation and logs the duration using the logger carried by ctx
func (t *TimedOperation) FinishWithContext(ctx context.Context, err error) {
duration := time.Since(t.StartTime)
// Get logger with context fields
logger := LoggerFromContext(ctx)
// Add timing-specific fields
fields := Fields{
"operation": t.Name,
"command": t.Command,
"duration": duration,
"args": strings.Join(t.Args, " "),
}
logger = logger.WithFields(fields)
if err != nil {
logger.WithField(shared.LogFieldError, err.Error()).Warnf(shared.ErrOperationFailed, duration)
} else if duration > time.Second {
// Log slow operations as warnings for visibility
logger.Warnf(shared.ErrSlowOperation, duration)
} else {
// Log fast operations at debug level to reduce noise
logger.Debugf(shared.MsgOperationCompleted, duration)
}
}
// Path helper functions for centralized path validation
// PathSecurityConfig holds configuration for path security validation
type PathSecurityConfig struct {
AllowedBasePaths []string // List of allowed base directories
MaxPathLength int // Maximum allowed path length (0 = unlimited)
AllowSymlinks bool // Whether to allow symlinks
ResolveSymlinks bool // Whether to resolve symlinks before validation
}
// GetLogAllowedPaths returns allowed paths for log directories
func GetLogAllowedPaths() []string {
paths := []string{"/var/log", "/opt", "/usr/local", "/home"}
paths = appendDevPathsIfAllowed(paths)
return expandAllowedPaths(paths)
}
// GetFilterAllowedPaths returns allowed paths for filter directories
func GetFilterAllowedPaths() []string {
paths := []string{"/etc/fail2ban", "/usr/local/etc/fail2ban", "/opt/fail2ban", "/home"}
paths = appendDevPathsIfAllowed(paths)
return expandAllowedPaths(paths)
}
// appendDevPathsIfAllowed adds development paths if ALLOW_DEV_PATHS is set
func appendDevPathsIfAllowed(paths []string) []string {
if os.Getenv("ALLOW_DEV_PATHS") != "" {
return append(paths, "/tmp", "/var/folders") // macOS temp dirs
}
return paths
}
// expandAllowedPaths adds resolved equivalents for allowed paths and removes duplicates
func expandAllowedPaths(paths []string) []string {
seen := make(map[string]struct{}, len(paths)*2)
expanded := make([]string, 0, len(paths)*2)
for _, p := range paths {
if p == "" {
continue
}
if _, ok := seen[p]; !ok {
expanded = append(expanded, p)
seen[p] = struct{}{}
}
if resolved, err := resolveAncestorSymlinks(p, true); err == nil && resolved != "" && resolved != p {
if _, ok := seen[resolved]; !ok {
expanded = append(expanded, resolved)
seen[resolved] = struct{}{}
}
}
}
return expanded
}
// CreateLogPathConfig creates a standard PathSecurityConfig for log directories
func CreateLogPathConfig() PathSecurityConfig {
return PathSecurityConfig{
AllowedBasePaths: GetLogAllowedPaths(),
MaxPathLength: 4096,
AllowSymlinks: true,
ResolveSymlinks: true,
}
}
// CreateFilterPathConfig creates a standard PathSecurityConfig for filter directories
func CreateFilterPathConfig() PathSecurityConfig {
return PathSecurityConfig{
AllowedBasePaths: GetFilterAllowedPaths(),
MaxPathLength: 4096,
AllowSymlinks: true,
ResolveSymlinks: true,
}
}
// CreateSingleDirPathConfig creates a path config for a single directory (like log file validation)
func CreateSingleDirPathConfig(baseDir string) PathSecurityConfig {
return PathSecurityConfig{
AllowedBasePaths: []string{baseDir},
MaxPathLength: 4096,
AllowSymlinks: false,
ResolveSymlinks: true,
}
}
// ValidatePathWithSecurity performs comprehensive path security validation
func ValidatePathWithSecurity(path string, config PathSecurityConfig) (string, error) {
if path == "" {
return "", fmt.Errorf("empty path not allowed")
}
// Check path length limits (initial check)
if config.MaxPathLength > 0 && len(path) > config.MaxPathLength {
return "", fmt.Errorf("path too long: %d characters (max: %d)", len(path), config.MaxPathLength)
}
// Detect and prevent null byte injection (initial check)
if strings.Contains(path, "\x00") {
return "", fmt.Errorf("path contains null byte")
}
// Decode URL-encoded path traversal attempts (path semantics)
if decodedPath, err := url.PathUnescape(path); err == nil && decodedPath != path {
getLogger().Debug("Detected URL-encoded path; using decoded version for validation")
path = decodedPath
}
// Normalize unicode characters to prevent bypass attempts
path = normalizeUnicode(path)
// Re-validate after decoding and normalization to prevent bypass
if config.MaxPathLength > 0 && len(path) > config.MaxPathLength {
return "", fmt.Errorf("path too long after decoding: %d characters (max: %d)", len(path), config.MaxPathLength)
}
// Re-check for null bytes after decoding and normalization
if strings.Contains(path, "\x00") {
return "", fmt.Errorf("path contains null byte after decoding")
}
// Basic path traversal detection (before cleaning)
if hasPathTraversal(path) {
return "", fmt.Errorf("path contains path traversal patterns")
}
// Clean and resolve the path
cleanPath, err := filepath.Abs(filepath.Clean(path))
if err != nil {
return "", fmt.Errorf("invalid path: %w", err)
}
// Additional check after cleaning (double-check for sophisticated attacks)
if hasPathTraversal(cleanPath) {
return "", fmt.Errorf("path contains path traversal patterns after normalization")
}
// Handle symlinks according to configuration
finalPath, err := handleSymlinks(cleanPath, config)
if err != nil {
return "", err
}
// Validate against allowed base paths using Rel, not prefix
if err := validateBasePath(finalPath, config.AllowedBasePaths); err != nil {
return "", err
}
// Check if path points to a device file or other dangerous file types
if err := validateFileType(finalPath); err != nil {
return "", err
}
return finalPath, nil
}
// hasPathTraversal detects various path traversal patterns
func hasPathTraversal(path string) bool {
// Check for various path traversal patterns
dangerousPatterns := []string{
"..",
"./",
".\\",
"//",
"\\\\",
"/../",
"\\..\\",
"%2e%2e", // URL encoded ..
"%2f", // URL encoded /
"%5c", // URL encoded \
"\u2024\u2024", // Two ONE DOT LEADER characters (can look like ..)
"\uff0e\uff0e", // Two FULLWIDTH FULL STOP characters
}
pathLower := strings.ToLower(path)
for _, pattern := range dangerousPatterns {
if strings.Contains(pathLower, strings.ToLower(pattern)) {
return true
}
}
return false
}
// normalizeUnicode normalizes unicode characters to prevent bypass attempts
func normalizeUnicode(path string) string {
// Replace Unicode look-alikes of dots and slashes with their ASCII forms.
// The ASCII characters themselves need no entry: "\u002e", "\u002f", and "\u005c" are just ".", "/", and "\".
replacements := map[string]string{
"\u2024": ".", // ONE DOT LEADER
"\uff0e": ".", // FULLWIDTH FULL STOP
"\u2044": "/", // FRACTION SLASH
"\uff0f": "/", // FULLWIDTH SOLIDUS
"\uff3c": "\\", // FULLWIDTH REVERSE SOLIDUS
}
result := path
// Loop variables are named from/to so they do not shadow the imported unicode package
for from, to := range replacements {
result = strings.ReplaceAll(result, from, to)
}
return result
}
// handleSymlinks resolves or validates symlinks according to configuration
func handleSymlinks(path string, config PathSecurityConfig) (string, error) {
// Check if the path is a symlink
if info, err := os.Lstat(path); err == nil {
if info.Mode()&os.ModeSymlink != 0 {
if !config.AllowSymlinks {
return "", fmt.Errorf("symlinks not allowed: %s", path)
}
if config.ResolveSymlinks {
resolved, err := filepath.EvalSymlinks(path)
if err != nil {
return "", fmt.Errorf(shared.ErrFailedToResolveSymlink, err)
}
return resolved, nil
}
}
} else if !os.IsNotExist(err) {
return "", fmt.Errorf("failed to check file info: %w", err)
}
// If leaf doesn't exist, resolve symlinks in the deepest existing ancestor
if config.ResolveSymlinks {
return resolveAncestorSymlinks(path, config.AllowSymlinks)
}
return path, nil
}
// resolveAncestorSymlinks resolves symlinks in existing ancestor directories
func resolveAncestorSymlinks(path string, allowSymlinks bool) (string, error) {
dir := path
var tail []string
for {
d := filepath.Dir(dir)
if d == dir {
break
}
if _, err := os.Lstat(dir); err == nil {
break
}
tail = append([]string{filepath.Base(dir)}, tail...)
dir = d
}
if fi, err := os.Lstat(dir); err == nil && fi.Mode()&os.ModeSymlink != 0 {
if !allowSymlinks {
return "", fmt.Errorf("symlinks not allowed in path: %s", dir)
}
resolved, err := filepath.EvalSymlinks(dir)
if err != nil {
return "", fmt.Errorf(shared.ErrFailedToResolveSymlink, err)
}
return filepath.Join(append([]string{resolved}, tail...)...), nil
}
return path, nil
}
// validateBasePath ensures the path is within allowed base directories
func validateBasePath(path string, allowedBasePaths []string) error {
if len(allowedBasePaths) == 0 {
return nil // No restrictions if no base paths configured
}
for _, basePath := range allowedBasePaths {
cleanBasePath, err := filepath.Abs(filepath.Clean(basePath))
if err != nil {
continue
}
rel, err := filepath.Rel(cleanBasePath, path)
if err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
return nil
}
}
return fmt.Errorf("path outside allowed directories: %s", path)
}
// validateFileType checks for dangerous file types (devices, named pipes, etc.)
func validateFileType(path string) error {
// Check if file exists
info, err := os.Stat(path)
if os.IsNotExist(err) {
return nil // File doesn't exist yet, allow it
}
if err != nil {
return fmt.Errorf("failed to stat file: %w", err)
}
mode := info.Mode()
// Block device files
if mode&os.ModeDevice != 0 {
return fmt.Errorf("device files not allowed: %s", path)
}
// Block named pipes (FIFOs)
if mode&os.ModeNamedPipe != 0 {
return fmt.Errorf("named pipes not allowed: %s", path)
}
// Block socket files
if mode&os.ModeSocket != 0 {
return fmt.Errorf("socket files not allowed: %s", path)
}
// Block irregular files (anything that's not a regular file or directory)
if !mode.IsRegular() && !mode.IsDir() {
return fmt.Errorf("irregular file type not allowed: %s", path)
}
return nil
}
// ValidateLogPath validates and sanitizes a log file path using standard log directory config
// Context parameter accepted for API consistency but not currently used
func ValidateLogPath(ctx context.Context, path string, logDir string) (string, error) {
_ = ctx // Context not currently used by ValidatePathWithSecurity
config := CreateSingleDirPathConfig(logDir)
return ValidatePathWithSecurity(path, config)
}
// validateClientPath is a generic helper for client path validation.
// It reduces duplication between ValidateClientLogPath and ValidateClientFilterPath.
func validateClientPath(ctx context.Context, path string, configFn func() PathSecurityConfig) (string, error) {
_ = ctx // Context not currently used by ValidatePathWithSecurity
return ValidatePathWithSecurity(path, configFn())
}
// ValidateClientLogPath validates log directory path for client initialization
// Context parameter accepted for API consistency but not currently used
func ValidateClientLogPath(ctx context.Context, logDir string) (string, error) {
return validateClientPath(ctx, logDir, CreateLogPathConfig)
}
// ValidateClientFilterPath validates filter directory path for client initialization
// Context parameter accepted for API consistency but not currently used
func ValidateClientFilterPath(ctx context.Context, filterDir string) (string, error) {
return validateClientPath(ctx, filterDir, CreateFilterPathConfig)
}
// validFilterNamePattern matches names made only of letters, digits, dash, and underscore.
// It is compiled once at package initialization rather than on every call.
var validFilterNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
// ValidateFilterName validates a filter name for path traversal prevention.
// Rejects: "..", "/", "\", absolute paths, drive letters
// Allows: letters, digits, dash, underscore only
func ValidateFilterName(filter string) error {
filter = strings.TrimSpace(filter)
if filter == "" {
return fmt.Errorf("filter name cannot be empty")
}
// Check for path traversal
if ContainsPathTraversal(filter) {
return fmt.Errorf("filter name contains path traversal")
}
// Check for absolute paths
if filepath.IsAbs(filter) {
return fmt.Errorf("filter name cannot be an absolute path")
}
// Only allow safe characters (alphanumeric, dash, underscore)
if !validFilterNamePattern.MatchString(filter) {
return fmt.Errorf("filter name contains invalid characters")
}
return nil
}