diff --git a/cmd/cmd_logswatch_test.go b/cmd/cmd_logswatch_test.go index f91349d..5ad7d0a 100644 --- a/cmd/cmd_logswatch_test.go +++ b/cmd/cmd_logswatch_test.go @@ -301,7 +301,7 @@ func (m *MockLogsWatchClient) GetLogLines(jail, ip string) ([]string, error) { logs = m.initialLogs } else { // Simulate new logs being added - logs = make([]string, len(m.initialLogs)) + logs = make([]string, len(m.initialLogs), len(m.initialLogs)+1) copy(logs, m.initialLogs) logs = append(logs, fmt.Sprintf("new log line %d", m.callCount)) } diff --git a/cmd/cmd_parallel_operations_test.go b/cmd/cmd_parallel_operations_test.go index 4d86439..d4008cc 100644 --- a/cmd/cmd_parallel_operations_test.go +++ b/cmd/cmd_parallel_operations_test.go @@ -88,7 +88,7 @@ func TestParallelOperationProcessor_EmptyJailsList(t *testing.T) { } func TestParallelOperationProcessor_ErrorHandling(t *testing.T) { - // Test error handling doesn't cause index issues + // Test error handling returns aggregated errors while still populating results processor := NewParallelOperationProcessor(2) // Mock client for testing @@ -99,15 +99,16 @@ func TestParallelOperationProcessor_ErrorHandling(t *testing.T) { results, err := processor.ProcessBanOperationParallel(mockClient, "192.168.1.100", jails) - if err != nil { - t.Fatalf("ProcessBanOperationParallel failed: %v", err) + // Errors should now be returned (aggregated) + if err == nil { + t.Error("Expected error for non-existent jails, got nil") } if len(results) != 2 { t.Errorf("Expected 2 results, got %d", len(results)) } - // All results should have errors for non-existent jails + // All results should still be populated with error status for i, result := range results { if result.Jail == "" { t.Errorf("Result %d has empty jail", i) diff --git a/cmd/command_test_framework.go b/cmd/command_test_framework.go index 3fa1990..881f97b 100644 --- a/cmd/command_test_framework.go +++ b/cmd/command_test_framework.go @@ -149,9 +149,10 @@ func NewCommandTest(t *testing.T, commandName string) *CommandTestBuilder { command: commandName, args: make([]string, 0), config: &Config{ - Format: PlainFormat, - CommandTimeout: shared.DefaultCommandTimeout, - FileTimeout: shared.DefaultFileTimeout, + Format: PlainFormat, + CommandTimeout: shared.DefaultCommandTimeout, + FileTimeout: shared.DefaultFileTimeout, + ParallelTimeout: shared.DefaultParallelTimeout, }, } } @@ -418,6 +419,20 @@ func (result *CommandTestResult) AssertExactOutput(expected string) *CommandTest return result } +// checkJSONFieldValue validates that a JSON field value matches the expected string. +func (result *CommandTestResult) checkJSONFieldValue(val interface{}, fieldName, expected string) { + result.t.Helper() + if fmt.Sprintf("%v", val) != expected { + result.t.Fatalf(shared.ErrTestJSONFieldMismatch, result.name, fieldName, expected, val) + } +} + +// failMissingJSONField reports a missing JSON field with context. +func (result *CommandTestResult) failMissingJSONField(fieldName, context string) { + result.t.Helper() + result.t.Fatalf("%s: JSON field %q not found%s: %s", result.name, fieldName, context, result.Output) +} + // AssertJSONField validates a specific field in JSON output func (result *CommandTestResult) AssertJSONField(fieldPath, expected string) *CommandTestResult { result.t.Helper() @@ -434,22 +449,18 @@ func (result *CommandTestResult) AssertJSONField(fieldPath, expected string) *Co switch v := data.(type) { case map[string]interface{}: if val, ok := v[fieldName]; ok { - if fmt.Sprintf("%v", val) != expected { - result.t.Fatalf(shared.ErrTestJSONFieldMismatch, result.name, fieldName, expected, val) - } + result.checkJSONFieldValue(val, fieldName, expected) } else { - result.t.Fatalf("%s: JSON field %q not found in output: %s", result.name, fieldName, result.Output) + result.failMissingJSONField(fieldName, " in output") } case []interface{}: // Handle array case - look in first element if len(v) > 0 { if firstItem, ok := v[0].(map[string]interface{}); ok { if val, ok := firstItem[fieldName]; ok { - if fmt.Sprintf("%v", val) != expected { - result.t.Fatalf(shared.ErrTestJSONFieldMismatch, result.name, fieldName, expected, val) - } + result.checkJSONFieldValue(val, fieldName, expected) } else { - result.t.Fatalf("%s: JSON field %q not found in first array element: %s", result.name, fieldName, result.Output) + result.failMissingJSONField(fieldName, " in first array element") } } else { result.t.Fatalf("%s: first array element is not an object in output: %s", result.name, result.Output) diff --git a/cmd/commands_coverage_test.go b/cmd/commands_coverage_test.go index f8dad3d..86b4278 100644 --- a/cmd/commands_coverage_test.go +++ b/cmd/commands_coverage_test.go @@ -12,13 +12,9 @@ import ( // TestTestFilterCmdCreation tests TestFilterCmd command creation func TestTestFilterCmdCreation(t *testing.T) { - // Save and restore original runner - originalRunner := fail2ban.GetRunner() - defer fail2ban.SetRunner(originalRunner) - mockRunner := fail2ban.NewMockRunner() - setupBasicMockResponses(mockRunner) - fail2ban.SetRunner(mockRunner) + defer fail2ban.WithTestRunner(t, mockRunner)() + fail2ban.StandardMockSetup(mockRunner) client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d") require.NoError(t, err) @@ -39,10 +35,6 @@ func TestTestFilterCmdCreation(t *testing.T) { // TestTestFilterCmdExecution tests TestFilterCmd execution func TestTestFilterCmdExecution(t *testing.T) { - // Save and restore original runner - originalRunner := fail2ban.GetRunner() - defer fail2ban.SetRunner(originalRunner) - tests := []struct { name string setupMock func(*fail2ban.MockRunner) @@ -52,7 +44,7 @@ func TestTestFilterCmdExecution(t *testing.T) { { name: "successful filter test", setupMock: func(m *fail2ban.MockRunner) { - setupBasicMockResponses(m) + fail2ban.StandardMockSetup(m) m.SetResponse("fail2ban-client get sshd logpath", []byte("/var/log/auth.log")) m.SetResponse("sudo fail2ban-client get sshd logpath", []byte("/var/log/auth.log")) }, @@ -62,7 +54,7 @@ func TestTestFilterCmdExecution(t *testing.T) { { name: "no filter provided - lists available", setupMock: func(m *fail2ban.MockRunner) { - setupBasicMockResponses(m) + fail2ban.StandardMockSetup(m) // Mock ListFiltersWithContext response }, args: []string{}, @@ -71,7 +63,7 @@ func TestTestFilterCmdExecution(t *testing.T) { { name: "invalid filter name", setupMock: func(m *fail2ban.MockRunner) { - setupBasicMockResponses(m) + fail2ban.StandardMockSetup(m) }, args: []string{"../../../etc/passwd"}, expectError: true, @@ -81,8 +73,8 @@ func TestTestFilterCmdExecution(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockRunner := fail2ban.NewMockRunner() + defer fail2ban.WithTestRunner(t, mockRunner)() tt.setupMock(mockRunner) - fail2ban.SetRunner(mockRunner) client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d") require.NoError(t, err) diff --git a/cmd/config_utils.go b/cmd/config_utils.go index f7738bd..cb769d3 100644 --- a/cmd/config_utils.go +++ b/cmd/config_utils.go @@ -164,6 +164,17 @@ func validateConfigPath(path, pathType string) (string, error) { return absPath, nil } +// validateConfigPathWithFallback validates a config path and returns the fallback if validation fails. +// This consolidates the common pattern of validate-or-fallback-with-logging used for config paths. +func validateConfigPathWithFallback(path, pathType, defaultPath, errorMsg string) string { + validated, err := validateConfigPath(path, pathType) + if err != nil { + Logger.WithError(err).WithField(shared.LogFieldPath, path).Error(errorMsg) + return defaultPath + } + return validated +} + // isReasonableSystemPath checks if a path is in a reasonable system location func isReasonableSystemPath(path, pathType string) bool { // Allow common system directories based on path type @@ -195,28 +206,20 @@ func NewConfigFromEnv() Config { if logDir == "" { logDir = shared.DefaultLogDir } - - validatedLogDir, err := validateConfigPath(logDir, shared.PathTypeLog) - if err != nil { - Logger.WithError(err).WithField(shared.LogFieldPath, logDir).Error("Invalid log directory from environment") - validatedLogDir = shared.DefaultLogDir // Fallback to safe default - } - cfg.LogDir = validatedLogDir + cfg.LogDir = validateConfigPathWithFallback( + logDir, shared.PathTypeLog, shared.DefaultLogDir, + "Invalid log directory from environment", + ) // Get and validate filter directory filterDir := os.Getenv("F2B_FILTER_DIR") if filterDir == "" { filterDir = shared.DefaultFilterDir } - - validatedFilterDir, err := validateConfigPath(filterDir, shared.PathTypeFilter) - if err != nil { - Logger.WithError(err). - WithField(shared.LogFieldPath, filterDir). - Error("Invalid filter directory from environment") - validatedFilterDir = shared.DefaultFilterDir // Fallback to safe default - } - cfg.FilterDir = validatedFilterDir + cfg.FilterDir = validateConfigPathWithFallback( + filterDir, shared.PathTypeFilter, shared.DefaultFilterDir, + "Invalid filter directory from environment", + ) // Configure timeouts from environment variables cfg.CommandTimeout = parseTimeoutFromEnv("F2B_COMMAND_TIMEOUT", shared.DefaultCommandTimeout) diff --git a/cmd/helpers.go b/cmd/helpers.go index 67979f0..2bfe98f 100644 --- a/cmd/helpers.go +++ b/cmd/helpers.go @@ -17,6 +17,20 @@ import ( "github.com/ivuorinen/f2b/fail2ban" ) +// createTimeoutContext creates a context with the configured command timeout. +// This helper consolidates the duplicate timeout handling pattern. +// If base is nil, context.Background() is used. +func createTimeoutContext(base context.Context, config *Config) (context.Context, context.CancelFunc) { + if base == nil { + base = context.Background() + } + timeout := shared.DefaultCommandTimeout + if config != nil && config.CommandTimeout > 0 { + timeout = config.CommandTimeout + } + return context.WithTimeout(base, timeout) +} + // IsCI detects if we're running in a CI environment func IsCI() bool { return fail2ban.IsCI() @@ -50,17 +64,8 @@ func NewContextualCommand( // Get the contextual logger logger := GetContextualLogger() - // Base on Cobra's context so signals/cancellations propagate - base := cmd.Context() - if base == nil { - base = context.Background() - } - // Create timeout context for the entire operation - timeout := shared.DefaultCommandTimeout - if config != nil && config.CommandTimeout > 0 { - timeout = config.CommandTimeout - } - ctx, cancel := context.WithTimeout(base, timeout) + // Create timeout context based on Cobra's context so signals/cancellations propagate + ctx, cancel := createTimeoutContext(cmd.Context(), config) defer cancel() // Extract command name from use string (first word) @@ -388,22 +393,63 @@ type OperationResult struct { Status string `json:"status"` } -// ProcessBanOperation processes ban operations across multiple jails -func ProcessBanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) { +// OperationType defines a ban or unban operation with its associated metadata +type OperationType struct { + // MetricsType is the metrics key for this operation (e.g., shared.MetricsBan) + MetricsType string + // Message is the log message for this operation (e.g., shared.MsgBanResult) + Message string + // Operation is the function to execute without context + Operation func(client fail2ban.Client, ip, jail string) (int, error) + // OperationCtx is the function to execute with context + OperationCtx func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error) +} + +// BanOperationType defines the ban operation +var BanOperationType = OperationType{ + MetricsType: shared.MetricsBan, + Message: shared.MsgBanResult, + Operation: func(c fail2ban.Client, ip, jail string) (int, error) { + return c.BanIP(ip, jail) + }, + OperationCtx: func(ctx context.Context, c fail2ban.Client, ip, jail string) (int, error) { + return c.BanIPWithContext(ctx, ip, jail) + }, +} + +// UnbanOperationType defines the unban operation +var UnbanOperationType = OperationType{ + MetricsType: shared.MetricsUnban, + Message: shared.MsgUnbanResult, + Operation: func(c fail2ban.Client, ip, jail string) (int, error) { + return c.UnbanIP(ip, jail) + }, + OperationCtx: func(ctx context.Context, c fail2ban.Client, ip, jail string) (int, error) { + return c.UnbanIPWithContext(ctx, ip, jail) + }, +} + +// ProcessOperation processes operations across multiple jails using the specified operation type +func ProcessOperation( + client fail2ban.Client, + ip string, + jails []string, + opType OperationType, +) ([]OperationResult, error) { results := make([]OperationResult, 0, len(jails)) for _, jail := range jails { - code, err := client.BanIP(ip, jail) + code, err := opType.Operation(client, ip, jail) if err != nil { return nil, err } - status := InterpretBanStatus(code, shared.MetricsBan) + status := InterpretBanStatus(code, opType.MetricsType) Logger.WithFields(map[string]interface{}{ "ip": ip, "jail": jail, "status": status, - }).Info(shared.MsgBanResult) + }).Info(opType.Message) results = append(results, OperationResult{ IP: ip, @@ -415,6 +461,59 @@ func ProcessBanOperation(client fail2ban.Client, ip string, jails []string) ([]O return results, nil } +// ProcessOperationWithContext processes operations across multiple jails with timeout context +func ProcessOperationWithContext( + ctx context.Context, + client fail2ban.Client, + ip string, + jails []string, + opType OperationType, +) ([]OperationResult, error) { + logger := GetContextualLogger() + results := make([]OperationResult, 0, len(jails)) + + for _, jail := range jails { + // Add jail to context for this operation + jailCtx := WithJail(ctx, jail) + + // Time the operation + start := time.Now() + code, err := opType.OperationCtx(jailCtx, client, ip, jail) + duration := time.Since(start) + + if err != nil { + // Log the failed operation with timing + logger.LogBanOperation(jailCtx, opType.MetricsType, ip, jail, false, duration) + return nil, err + } + + status := InterpretBanStatus(code, opType.MetricsType) + + // Log the successful operation with timing + logger.LogBanOperation(jailCtx, opType.MetricsType, ip, jail, true, duration) + + // Log the operation-specific message (ban vs unban) + Logger.WithFields(map[string]interface{}{ + "ip": ip, + "jail": jail, + "status": status, + }).Info(opType.Message) + + results = append(results, OperationResult{ + IP: ip, + Jail: jail, + Status: status, + }) + } + + return results, nil +} + +// ProcessBanOperation processes ban operations across multiple jails +func ProcessBanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) { + return ProcessOperation(client, ip, jails, BanOperationType) +} + // ProcessBanOperationWithContext processes ban operations across multiple jails with timeout context func ProcessBanOperationWithContext( ctx context.Context, @@ -422,70 +521,12 @@ func ProcessBanOperationWithContext( ip string, jails []string, ) ([]OperationResult, error) { - logger := GetContextualLogger() - results := make([]OperationResult, 0, len(jails)) - - for _, jail := range jails { - // Add jail to context for this operation - jailCtx := WithJail(ctx, jail) - - // Time the ban operation - start := time.Now() - code, err := client.BanIPWithContext(jailCtx, ip, jail) - duration := time.Since(start) - - if err != nil { - // Log the failed operation with timing - logger.LogBanOperation(jailCtx, shared.MetricsBan, ip, jail, false, duration) - return nil, err - } - - status := InterpretBanStatus(code, shared.MetricsBan) - - // Log the successful operation with timing - logger.LogBanOperation(jailCtx, shared.MetricsBan, ip, jail, true, duration) - - Logger.WithFields(map[string]interface{}{ - "ip": ip, - "jail": jail, - "status": status, - }).Info(shared.MsgBanResult) - - results = append(results, OperationResult{ - IP: ip, - Jail: jail, - Status: status, - }) - } - - return results, nil + return ProcessOperationWithContext(ctx, client, ip, jails, BanOperationType) } // ProcessUnbanOperation processes unban operations across multiple jails func ProcessUnbanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) { - results := make([]OperationResult, 0, len(jails)) - - for _, jail := range jails { - code, err := client.UnbanIP(ip, jail) - if err != nil { - return nil, err - } - - status := InterpretBanStatus(code, shared.MetricsUnban) - Logger.WithFields(map[string]interface{}{ - "ip": ip, - "jail": jail, - "status": status, - }).Info(shared.MsgUnbanResult) - - results = append(results, OperationResult{ - IP: ip, - Jail: jail, - Status: status, - }) - } - - return results, nil + return ProcessOperation(client, ip, jails, UnbanOperationType) } // ProcessUnbanOperationWithContext processes unban operations across multiple jails with timeout context @@ -495,43 +536,7 @@ func ProcessUnbanOperationWithContext( ip string, jails []string, ) ([]OperationResult, error) { - logger := GetContextualLogger() - results := make([]OperationResult, 0, len(jails)) - - for _, jail := range jails { - // Add jail to context for this operation - jailCtx := WithJail(ctx, jail) - - // Time the unban operation - start := time.Now() - code, err := client.UnbanIPWithContext(jailCtx, ip, jail) - duration := time.Since(start) - - if err != nil { - // Log the failed operation with timing - logger.LogBanOperation(jailCtx, shared.MetricsUnban, ip, jail, false, duration) - return nil, err - } - - status := InterpretBanStatus(code, shared.MetricsUnban) - - // Log the successful operation with timing - logger.LogBanOperation(jailCtx, shared.MetricsUnban, ip, jail, true, duration) - - Logger.WithFields(map[string]interface{}{ - "ip": ip, - "jail": jail, - "status": status, - }).Info(shared.MsgUnbanResult) - - results = append(results, OperationResult{ - IP: ip, - Jail: jail, - Status: status, - }) - } - - return results, nil + return ProcessOperationWithContext(ctx, client, ip, jails, UnbanOperationType) } // Argument validation helpers diff --git a/cmd/helpers_config_test.go b/cmd/helpers_config_test.go index 0a50774..d6920a6 100644 --- a/cmd/helpers_config_test.go +++ b/cmd/helpers_config_test.go @@ -12,9 +12,7 @@ import ( // TestProcessBanOperation tests the ProcessBanOperation function func TestProcessBanOperation(t *testing.T) { - // Save and restore original runner - originalRunner := fail2ban.GetRunner() - defer fail2ban.SetRunner(originalRunner) + defer fail2ban.WithTestRunner(t, fail2ban.GetRunner())() tests := []struct { name string @@ -27,7 +25,7 @@ func TestProcessBanOperation(t *testing.T) { { name: "successful ban single jail", setupMock: func(m *fail2ban.MockRunner) { - setupBasicMockResponses(m) + fail2ban.StandardMockSetup(m) m.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1")) m.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1")) }, @@ -39,7 +37,7 @@ func TestProcessBanOperation(t *testing.T) { { name: "successful ban multiple jails", setupMock: func(m *fail2ban.MockRunner) { - setupBasicMockResponses(m) + fail2ban.StandardMockSetup(m) m.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1")) m.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1")) m.SetResponse("fail2ban-client set apache banip 192.168.1.1", []byte("1")) @@ -53,7 +51,7 @@ func TestProcessBanOperation(t *testing.T) { { name: "invalid IP address", setupMock: func(m *fail2ban.MockRunner) { - setupBasicMockResponses(m) + fail2ban.StandardMockSetup(m) }, ip: "invalid.ip", jails: []string{"sshd"}, @@ -147,13 +145,3 @@ func TestParseTimeoutFromEnv(t *testing.T) { }) } } - -// setupBasicMockResponses is a helper for setting up version check and ping responses -func setupBasicMockResponses(m *fail2ban.MockRunner) { - m.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0")) - m.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0")) - m.SetResponse("fail2ban-client ping", []byte("Server replied: pong")) - m.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong")) - m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 2\n`- Jail list: sshd, apache")) - m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 2\n`- Jail list: sshd, apache")) -} diff --git a/cmd/ip_command_pattern.go b/cmd/ip_command_pattern.go index ad0695a..ce72e9e 100644 --- a/cmd/ip_command_pattern.go +++ b/cmd/ip_command_pattern.go @@ -87,14 +87,9 @@ func ExecuteIPCommand( // Get the contextual logger logger := GetContextualLogger() - // Safe timeout handling with nil check - timeout := shared.DefaultCommandTimeout - if config != nil && config.CommandTimeout > 0 { - timeout = config.CommandTimeout - } - // Create timeout context for the entire operation - ctx, cancel := context.WithTimeout(context.Background(), timeout) + // Use cmd.Context() to inherit Cobra's signal cancellation + ctx, cancel := createTimeoutContext(cmd.Context(), config) defer cancel() // Add command context diff --git a/cmd/ip_processors.go b/cmd/ip_processors.go index 0f36115..5bd0e57 100644 --- a/cmd/ip_processors.go +++ b/cmd/ip_processors.go @@ -9,6 +9,43 @@ import ( "github.com/ivuorinen/f2b/fail2ban" ) +// multiJailOperationFunc defines the signature for IP operation functions that process multiple jails +type multiJailOperationFunc func( + ctx context.Context, + client fail2ban.Client, + ip string, + jails []string, +) ([]OperationResult, error) + +// validateIPAndJails validates an IP address and a list of jail names. +// Returns an error if any validation fails. +func validateIPAndJails(ip string, jails []string) error { + if err := fail2ban.ValidateIP(ip); err != nil { + return err + } + for _, jail := range jails { + if err := fail2ban.ValidateJail(jail); err != nil { + return err + } + } + return nil +} + +// processWithValidation validates inputs and executes the provided operation function. +// This consolidates the common validation-then-execute pattern. +func processWithValidation( + ctx context.Context, + client fail2ban.Client, + ip string, + jails []string, + opFunc multiJailOperationFunc, +) ([]OperationResult, error) { + if err := validateIPAndJails(ip, jails); err != nil { + return nil, err + } + return opFunc(ctx, client, ip, jails) +} + // BanProcessor handles ban operations type BanProcessor struct{} @@ -19,19 +56,7 @@ func (p *BanProcessor) ProcessSingle( ip string, jails []string, ) ([]OperationResult, error) { - // Validate IP address before privilege escalation - if err := fail2ban.ValidateIP(ip); err != nil { - return nil, err - } - - // Validate each jail name before privilege escalation - for _, jail := range jails { - if err := fail2ban.ValidateJail(jail); err != nil { - return nil, err - } - } - - return ProcessBanOperationWithContext(ctx, client, ip, jails) + return processWithValidation(ctx, client, ip, jails, ProcessBanOperationWithContext) } // ProcessParallel processes ban operations for multiple jails in parallel @@ -41,19 +66,7 @@ func (p *BanProcessor) ProcessParallel( ip string, jails []string, ) ([]OperationResult, error) { - // Validate IP address before privilege escalation - if err := fail2ban.ValidateIP(ip); err != nil { - return nil, err - } - - // Validate each jail name before privilege escalation - for _, jail := range jails { - if err := fail2ban.ValidateJail(jail); err != nil { - return nil, err - } - } - - return ProcessBanOperationParallelWithContext(ctx, client, ip, jails) + return processWithValidation(ctx, client, ip, jails, ProcessBanOperationParallelWithContext) } // UnbanProcessor handles unban operations @@ -66,19 +79,7 @@ func (p *UnbanProcessor) ProcessSingle( ip string, jails []string, ) ([]OperationResult, error) { - // Validate IP address before privilege escalation - if err := fail2ban.ValidateIP(ip); err != nil { - return nil, err - } - - // Validate each jail name before privilege escalation - for _, jail := range jails { - if err := fail2ban.ValidateJail(jail); err != nil { - return nil, err - } - } - - return ProcessUnbanOperationWithContext(ctx, client, ip, jails) + return processWithValidation(ctx, client, ip, jails, ProcessUnbanOperationWithContext) } // ProcessParallel processes unban operations for multiple jails in parallel @@ -88,17 +89,5 @@ func (p *UnbanProcessor) ProcessParallel( ip string, jails []string, ) ([]OperationResult, error) { - // Validate IP address before privilege escalation - if err := fail2ban.ValidateIP(ip); err != nil { - return nil, err - } - - // Validate each jail name before privilege escalation - for _, jail := range jails { - if err := fail2ban.ValidateJail(jail); err != nil { - return nil, err - } - } - - return ProcessUnbanOperationParallelWithContext(ctx, client, ip, jails) + return processWithValidation(ctx, client, ip, jails, ProcessUnbanOperationParallelWithContext) } diff --git a/cmd/logging.go b/cmd/logging.go index 3734e35..654c67d 100644 --- a/cmd/logging.go +++ b/cmd/logging.go @@ -5,10 +5,12 @@ package cmd import ( "context" + "strings" "time" "github.com/sirupsen/logrus" + "github.com/ivuorinen/f2b/fail2ban" "github.com/ivuorinen/f2b/shared" ) @@ -56,57 +58,68 @@ func getVersion() string { return version } +// contextKeyEntry defines a context key and its log field name +type contextKeyEntry struct { + key any // The context key to look up + fieldName string // The log field name to use +} + +// contextKeys lists all context keys to extract for logging +var contextKeys = []contextKeyEntry{ + {shared.ContextKeyRequestID, string(shared.ContextKeyRequestID)}, + {shared.ContextKeyOperation, string(shared.ContextKeyOperation)}, + {shared.ContextKeyIP, string(shared.ContextKeyIP)}, + {shared.ContextKeyJail, string(shared.ContextKeyJail)}, + {shared.ContextKeyCommand, string(shared.ContextKeyCommand)}, +} + // WithContext creates a logger entry with context values func (cl *ContextualLogger) WithContext(ctx context.Context) *logrus.Entry { entry := cl.WithFields(cl.defaultFields) - // Extract context values and add as fields (using consistent constants) - if requestID := ctx.Value(shared.ContextKeyRequestID); requestID != nil { - entry = entry.WithField(string(shared.ContextKeyRequestID), requestID) - } - - if operation := ctx.Value(shared.ContextKeyOperation); operation != nil { - entry = entry.WithField(string(shared.ContextKeyOperation), operation) - } - - if ip := ctx.Value(shared.ContextKeyIP); ip != nil { - entry = entry.WithField(string(shared.ContextKeyIP), ip) - } - - if jail := ctx.Value(shared.ContextKeyJail); jail != nil { - entry = entry.WithField(string(shared.ContextKeyJail), jail) - } - - if command := ctx.Value(shared.ContextKeyCommand); command != nil { - entry = entry.WithField(string(shared.ContextKeyCommand), command) + // Extract context values and add as fields using table-driven approach + for _, ck := range contextKeys { + if val := ctx.Value(ck.key); val != nil { + entry = entry.WithField(ck.fieldName, val) + } } return entry } -// WithOperation adds operation context and returns a new context +// WithOperation adds operation context and returns a new context. +// Delegates to fail2ban.WithOperation for consistent validation. func WithOperation(ctx context.Context, operation string) context.Context { - return context.WithValue(ctx, shared.ContextKeyOperation, operation) + return fail2ban.WithOperation(ctx, operation) } -// WithIP adds IP context and returns a new context +// WithIP adds IP context and returns a new context. +// Delegates to fail2ban.WithIP for consistent IP validation. func WithIP(ctx context.Context, ip string) context.Context { - return context.WithValue(ctx, shared.ContextKeyIP, ip) + return fail2ban.WithIP(ctx, ip) } -// WithJail adds jail context and returns a new context +// WithJail adds jail context and returns a new context. +// Delegates to fail2ban.WithJail for consistent jail name validation. func WithJail(ctx context.Context, jail string) context.Context { - return context.WithValue(ctx, shared.ContextKeyJail, jail) + return fail2ban.WithJail(ctx, jail) } -// WithCommand adds command context and returns a new context +// WithCommand adds command context and returns a new context. +// This is cmd-specific as fail2ban doesn't need command tracking. +// Empty commands are not stored in context. func WithCommand(ctx context.Context, command string) context.Context { + command = strings.TrimSpace(command) + if command == "" { + return ctx + } return context.WithValue(ctx, shared.ContextKeyCommand, command) } -// WithRequestID adds request ID context and returns a new context +// WithRequestID adds request ID context and returns a new context. +// Delegates to fail2ban.WithRequestID for consistent validation. func WithRequestID(ctx context.Context, requestID string) context.Context { - return context.WithValue(ctx, shared.ContextKeyRequestID, requestID) + return fail2ban.WithRequestID(ctx, requestID) } // LogOperation logs the start and end of an operation with timing and metrics diff --git a/cmd/metrics.go b/cmd/metrics.go index cec07db..67a854f 100644 --- a/cmd/metrics.go +++ b/cmd/metrics.go @@ -68,17 +68,34 @@ func NewMetrics() *Metrics { } } +// recordOperationMetrics records metrics for any operation type. +// This helper consolidates the duplicate metrics recording pattern. +func (m *Metrics) recordOperationMetrics( + execCounter, durationCounter, failureCounter *int64, + buckets map[string]*LatencyBucket, + operation string, + duration time.Duration, + success bool, +) { + atomic.AddInt64(execCounter, 1) + atomic.AddInt64(durationCounter, duration.Milliseconds()) + if !success { + atomic.AddInt64(failureCounter, 1) + } + m.recordLatencyBucket(buckets, operation, duration) +} + // RecordCommandExecution records metrics for command execution func (m *Metrics) RecordCommandExecution(command string, duration time.Duration, success bool) { - atomic.AddInt64(&m.CommandExecutions, 1) - atomic.AddInt64(&m.CommandTotalDuration, duration.Milliseconds()) - - if !success { - atomic.AddInt64(&m.CommandFailures, 1) - } - - // Record latency bucket - m.recordLatencyBucket(m.commandLatencyBuckets, command, duration) + m.recordOperationMetrics( + &m.CommandExecutions, + &m.CommandTotalDuration, + &m.CommandFailures, + m.commandLatencyBuckets, + command, + duration, + success, + ) } // RecordBanOperation records metrics for ban operations @@ -99,15 +116,15 @@ func (m *Metrics) RecordBanOperation(operation string, _ time.Duration, success // RecordClientOperation records metrics for client operations func (m *Metrics) RecordClientOperation(operation string, duration time.Duration, success bool) { - atomic.AddInt64(&m.ClientOperations, 1) - atomic.AddInt64(&m.ClientTotalDuration, duration.Milliseconds()) - - if !success { - atomic.AddInt64(&m.ClientFailures, 1) - } - - // Record latency bucket - m.recordLatencyBucket(m.clientLatencyBuckets, operation, duration) + m.recordOperationMetrics( + &m.ClientOperations, + &m.ClientTotalDuration, + &m.ClientFailures, + m.clientLatencyBuckets, + operation, + duration, + success, + ) } // RecordValidationCacheHit records validation cache hits @@ -143,6 +160,25 @@ func (m *Metrics) UpdateGoroutineCount(count int64) { atomic.StoreInt64(&m.GoroutineCount, count) } +// copyBuckets creates a snapshot copy of latency buckets +// This helper consolidates the duplicate bucket copying logic +func copyBuckets(buckets map[string]*LatencyBucket) map[string]LatencyBucketSnapshot { + result := make(map[string]LatencyBucketSnapshot, len(buckets)) + for op, bucket := range buckets { + result[op] = LatencyBucketSnapshot{ + Under1ms: atomic.LoadInt64(&bucket.Under1ms), + Under10ms: atomic.LoadInt64(&bucket.Under10ms), + Under100ms: atomic.LoadInt64(&bucket.Under100ms), + Under1s: atomic.LoadInt64(&bucket.Under1s), + Under10s: atomic.LoadInt64(&bucket.Under10s), + Over10s: atomic.LoadInt64(&bucket.Over10s), + Total: atomic.LoadInt64(&bucket.Total), + TotalTime: atomic.LoadInt64(&bucket.TotalTime), + } + } + return result +} + // recordLatencyBucket records latency in appropriate bucket func (m *Metrics) recordLatencyBucket(buckets map[string]*LatencyBucket, operation string, duration time.Duration) { m.mu.Lock() @@ -177,37 +213,8 @@ func (m *Metrics) recordLatencyBucket(buckets map[string]*LatencyBucket, operati // GetSnapshot returns a snapshot of current metrics func (m *Metrics) GetSnapshot() MetricsSnapshot { m.mu.RLock() - - // Copy command latency buckets - commandBuckets := make(map[string]LatencyBucketSnapshot) - for op, bucket := range m.commandLatencyBuckets { - commandBuckets[op] = LatencyBucketSnapshot{ - Under1ms: atomic.LoadInt64(&bucket.Under1ms), - Under10ms: atomic.LoadInt64(&bucket.Under10ms), - Under100ms: atomic.LoadInt64(&bucket.Under100ms), - Under1s: atomic.LoadInt64(&bucket.Under1s), - Under10s: atomic.LoadInt64(&bucket.Under10s), - Over10s: atomic.LoadInt64(&bucket.Over10s), - Total: atomic.LoadInt64(&bucket.Total), - TotalTime: atomic.LoadInt64(&bucket.TotalTime), - } - } - - // Copy client latency buckets - clientBuckets := make(map[string]LatencyBucketSnapshot) - for op, bucket := range m.clientLatencyBuckets { - clientBuckets[op] = LatencyBucketSnapshot{ - Under1ms: atomic.LoadInt64(&bucket.Under1ms), - Under10ms: atomic.LoadInt64(&bucket.Under10ms), - Under100ms: atomic.LoadInt64(&bucket.Under100ms), - Under1s: atomic.LoadInt64(&bucket.Under1s), - Under10s: atomic.LoadInt64(&bucket.Under10s), - Over10s: atomic.LoadInt64(&bucket.Over10s), - Total: atomic.LoadInt64(&bucket.Total), - TotalTime: atomic.LoadInt64(&bucket.TotalTime), - } - } - + commandBuckets := copyBuckets(m.commandLatencyBuckets) + clientBuckets := copyBuckets(m.clientLatencyBuckets) m.mu.RUnlock() return MetricsSnapshot{ diff --git a/cmd/metrics_cmd.go b/cmd/metrics_cmd.go index d99c4ce..f9e6001 100644 --- a/cmd/metrics_cmd.go +++ b/cmd/metrics_cmd.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io" + "sort" "strings" "github.com/spf13/cobra" @@ -97,35 +98,40 @@ func printMetricsPlain(output io.Writer, snapshot MetricsSnapshot) error { // Command latency distribution if len(snapshot.CommandLatencyBuckets) > 0 { sb.WriteString("Command Latency Distribution:\n") - for cmd, bucket := range snapshot.CommandLatencyBuckets { - sb.WriteString(fmt.Sprintf(shared.MetricsFmtOperationHeader, cmd)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder1ms, bucket.Under1ms)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder10ms, bucket.Under10ms)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder100ms, bucket.Under100ms)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder1s, bucket.Under1s)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder10s, bucket.Under10s)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyOver10s, bucket.Over10s)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtAverageLatency, bucket.GetAverageLatency())) - } + formatLatencyBuckets(&sb, snapshot.CommandLatencyBuckets) sb.WriteString("\n") } // Client latency distribution if len(snapshot.ClientLatencyBuckets) > 0 { sb.WriteString("Client Operation Latency Distribution:\n") - for op, bucket := range snapshot.ClientLatencyBuckets { - sb.WriteString(fmt.Sprintf(shared.MetricsFmtOperationHeader, op)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder1ms, bucket.Under1ms)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder10ms, bucket.Under10ms)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder100ms, bucket.Under100ms)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder1s, bucket.Under1s)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder10s, bucket.Under10s)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyOver10s, bucket.Over10s)) - sb.WriteString(fmt.Sprintf(shared.MetricsFmtAverageLatency, bucket.GetAverageLatency())) - } + formatLatencyBuckets(&sb, snapshot.ClientLatencyBuckets) } // Write the entire string at once _, err := output.Write([]byte(sb.String())) return err } + +// formatLatencyBuckets writes latency bucket distribution to the builder. +// Keys are sorted for deterministic output. +func formatLatencyBuckets(sb *strings.Builder, buckets map[string]LatencyBucketSnapshot) { + // Sort keys for deterministic output + keys := make([]string, 0, len(buckets)) + for name := range buckets { + keys = append(keys, name) + } + sort.Strings(keys) + + for _, name := range keys { + bucket := buckets[name] + fmt.Fprintf(sb, shared.MetricsFmtOperationHeader, name) + fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder1ms, bucket.Under1ms) + fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder10ms, bucket.Under10ms) + fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder100ms, bucket.Under100ms) + fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder1s, bucket.Under1s) + fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder10s, bucket.Under10s) + fmt.Fprintf(sb, shared.MetricsFmtLatencyOver10s, bucket.Over10s) + fmt.Fprintf(sb, shared.MetricsFmtAverageLatency, bucket.GetAverageLatency()) + } +} diff --git a/cmd/parallel_operations.go b/cmd/parallel_operations.go index a16bd6a..de458a8 100644 --- a/cmd/parallel_operations.go +++ b/cmd/parallel_operations.go @@ -2,11 +2,11 @@ package cmd import ( "context" + "errors" "runtime" "sync" "github.com/ivuorinen/f2b/fail2ban" - "github.com/ivuorinen/f2b/shared" ) // ParallelOperationProcessor handles parallel ban/unban operations across multiple jails @@ -24,15 +24,16 @@ func NewParallelOperationProcessor(workerCount int) *ParallelOperationProcessor } } -// ProcessBanOperationParallel processes ban operations across multiple jails in parallel -func (pop *ParallelOperationProcessor) ProcessBanOperationParallel( +// ProcessOperationParallel processes operations across multiple jails in parallel +func (pop *ParallelOperationProcessor) ProcessOperationParallel( client fail2ban.Client, ip string, jails []string, + opType OperationType, ) ([]OperationResult, error) { if len(jails) <= 1 { // For single jail, use sequential processing to avoid overhead - return ProcessBanOperation(client, ip, jails) + return ProcessOperation(client, ip, jails, opType) } return pop.processOperations( @@ -40,13 +41,43 @@ func (pop *ParallelOperationProcessor) ProcessBanOperationParallel( client, ip, jails, - func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error) { - return client.BanIPWithContext(ctx, ip, jail) - }, - shared.MetricsBan, + opType.OperationCtx, + opType.MetricsType, ) } +// ProcessOperationParallelWithContext processes operations across multiple jails in parallel with context +func (pop *ParallelOperationProcessor) ProcessOperationParallelWithContext( + ctx context.Context, + client fail2ban.Client, + ip string, + jails []string, + opType OperationType, +) ([]OperationResult, error) { + if len(jails) <= 1 { + // For single jail, use sequential processing to avoid overhead + return ProcessOperationWithContext(ctx, client, ip, jails, opType) + } + + return pop.processOperations( + ctx, + client, + ip, + jails, + opType.OperationCtx, + opType.MetricsType, + ) +} + +// ProcessBanOperationParallel processes ban operations across multiple jails in parallel +func (pop *ParallelOperationProcessor) ProcessBanOperationParallel( + client fail2ban.Client, + ip string, + jails []string, +) ([]OperationResult, error) { + return pop.ProcessOperationParallel(client, ip, jails, BanOperationType) +} + // ProcessBanOperationParallelWithContext processes ban operations across // multiple jails in parallel with timeout context func (pop *ParallelOperationProcessor) ProcessBanOperationParallelWithContext( @@ -55,21 +86,7 @@ func (pop *ParallelOperationProcessor) ProcessBanOperationParallelWithContext( ip string, jails []string, ) ([]OperationResult, error) { - if len(jails) <= 1 { - // For single jail, use sequential processing to avoid overhead - return ProcessBanOperationWithContext(ctx, client, ip, jails) - } - - return pop.processOperations( - ctx, - client, - ip, - jails, - func(opCtx context.Context, client fail2ban.Client, ip, jail string) (int, error) { - return client.BanIPWithContext(opCtx, ip, jail) - }, - shared.MetricsBan, - ) + return pop.ProcessOperationParallelWithContext(ctx, client, ip, jails, BanOperationType) } // ProcessUnbanOperationParallel processes unban operations across multiple jails in parallel @@ -78,21 +95,7 @@ func (pop *ParallelOperationProcessor) ProcessUnbanOperationParallel( ip string, jails []string, ) ([]OperationResult, error) { - if len(jails) <= 1 { - // For single jail, use sequential processing to avoid overhead - return ProcessUnbanOperation(client, ip, jails) - } - - return pop.processOperations( - context.Background(), - client, - ip, - jails, - func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error) { - return client.UnbanIPWithContext(ctx, ip, jail) - }, - shared.MetricsUnban, - ) + return pop.ProcessOperationParallel(client, ip, jails, UnbanOperationType) } // ProcessUnbanOperationParallelWithContext processes unban operations across @@ -103,26 +106,35 @@ func (pop *ParallelOperationProcessor) ProcessUnbanOperationParallelWithContext( ip string, jails []string, ) ([]OperationResult, error) { - if len(jails) <= 1 { - // For single jail, use sequential processing to avoid overhead - return ProcessUnbanOperationWithContext(ctx, client, ip, jails) - } - - return pop.processOperations( - ctx, - client, - ip, - jails, - func(opCtx context.Context, client fail2ban.Client, ip, jail string) (int, error) { - return client.UnbanIPWithContext(opCtx, ip, jail) - }, - shared.MetricsUnban, - ) + return pop.ProcessOperationParallelWithContext(ctx, client, ip, jails, UnbanOperationType) } // operationFunc represents a ban or unban operation with context type operationFunc func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error) +// validateOperationInputs validates IP and jail inputs before parallel processing. +// Returns an aggregated error if any inputs are invalid. +func validateOperationInputs(ctx context.Context, ip string, jails []string) error { + var errs []error + + // Validate IP address + if err := fail2ban.CachedValidateIP(ctx, ip); err != nil { + errs = append(errs, err) + } + + // Validate each jail name + for _, jail := range jails { + if err := fail2ban.CachedValidateJail(ctx, jail); err != nil { + errs = append(errs, err) + } + } + + if len(errs) > 0 { + return errors.Join(errs...) + } + return nil +} + // processOperations handles the parallel processing of operations func (pop *ParallelOperationProcessor) processOperations( ctx context.Context, @@ -132,6 +144,11 @@ func (pop *ParallelOperationProcessor) processOperations( operation operationFunc, operationType string, ) ([]OperationResult, error) { + // Validate inputs before processing + if err := validateOperationInputs(ctx, ip, jails); err != nil { + return nil, err + } + results := make([]OperationResult, len(jails)) resultCh := make(chan operationResult, len(jails)) @@ -167,13 +184,20 @@ func (pop *ParallelOperationProcessor) processOperations( close(resultCh) }() - // Collect results + // Collect results and errors + var errs []error for result := range resultCh { if result.index >= 0 && result.index < len(results) { results[result.index] = result.result } + if result.err != nil { + errs = append(errs, result.err) + } } + if len(errs) > 0 { + return results, errors.Join(errs...) + } return results, nil } @@ -187,6 +211,7 @@ type jailWork struct { type operationResult struct { result OperationResult index int + err error } // worker processes jail operations @@ -215,16 +240,15 @@ func (pop *ParallelOperationProcessor) worker( "status": status, }).Info("Operation result") - result := operationResult{ + resultCh <- operationResult{ result: OperationResult{ IP: ip, Jail: work.jail, Status: status, }, index: work.index, + err: err, } - - resultCh <- result } } diff --git a/cmd/processors_test.go b/cmd/processors_test.go index 3282870..52352a9 100644 --- a/cmd/processors_test.go +++ b/cmd/processors_test.go @@ -12,17 +12,13 @@ import ( // TestUnbanProcessorProcessParallel tests the ProcessParallel method func TestUnbanProcessorProcessParallel(t *testing.T) { - // Save and restore original runner - originalRunner := fail2ban.GetRunner() - defer fail2ban.SetRunner(originalRunner) - mockRunner := fail2ban.NewMockRunner() - setupBasicMockResponses(mockRunner) + defer fail2ban.WithTestRunner(t, mockRunner)() + fail2ban.StandardMockSetup(mockRunner) mockRunner.SetResponse("fail2ban-client set sshd unbanip 192.168.1.1", []byte("1")) mockRunner.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.1", []byte("1")) mockRunner.SetResponse("fail2ban-client set apache unbanip 192.168.1.1", []byte("1")) mockRunner.SetResponse("sudo fail2ban-client set apache unbanip 192.168.1.1", []byte("1")) - fail2ban.SetRunner(mockRunner) client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d") require.NoError(t, err) diff --git a/cmd/remaining_coverage_test.go b/cmd/remaining_coverage_test.go index f42139a..f635b54 100644 --- a/cmd/remaining_coverage_test.go +++ b/cmd/remaining_coverage_test.go @@ -12,17 +12,13 @@ import ( // TestProcessBanOperationParallel tests the ProcessBanOperationParallel wrapper function func TestProcessBanOperationParallel(t *testing.T) { - // Save and restore original runner - originalRunner := fail2ban.GetRunner() - defer fail2ban.SetRunner(originalRunner) - mockRunner := fail2ban.NewMockRunner() - setupBasicMockResponses(mockRunner) + defer fail2ban.WithTestRunner(t, mockRunner)() + fail2ban.StandardMockSetup(mockRunner) mockRunner.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1")) mockRunner.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1")) mockRunner.SetResponse("fail2ban-client set apache banip 192.168.1.1", []byte("1")) mockRunner.SetResponse("sudo fail2ban-client set apache banip 192.168.1.1", []byte("1")) - fail2ban.SetRunner(mockRunner) client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d") require.NoError(t, err) @@ -34,15 +30,11 @@ func TestProcessBanOperationParallel(t *testing.T) { // TestProcessUnbanOperationParallel tests the ProcessUnbanOperationParallel wrapper function func TestProcessUnbanOperationParallel(t *testing.T) { - // Save and restore original runner - originalRunner := fail2ban.GetRunner() - defer fail2ban.SetRunner(originalRunner) - mockRunner := fail2ban.NewMockRunner() - setupBasicMockResponses(mockRunner) + defer fail2ban.WithTestRunner(t, mockRunner)() + fail2ban.StandardMockSetup(mockRunner) mockRunner.SetResponse("fail2ban-client set sshd unbanip 192.168.1.1", []byte("1")) mockRunner.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.1", []byte("1")) - fail2ban.SetRunner(mockRunner) client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d") require.NoError(t, err) @@ -54,15 +46,11 @@ func TestProcessUnbanOperationParallel(t *testing.T) { // TestProcessBanOperationParallelWithContext tests the wrapper with context func TestProcessBanOperationParallelWithContext(t *testing.T) { - // Save and restore original runner - originalRunner := fail2ban.GetRunner() - defer fail2ban.SetRunner(originalRunner) - mockRunner := fail2ban.NewMockRunner() - setupBasicMockResponses(mockRunner) + defer fail2ban.WithTestRunner(t, mockRunner)() + fail2ban.StandardMockSetup(mockRunner) mockRunner.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1")) mockRunner.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1")) - fail2ban.SetRunner(mockRunner) client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d") require.NoError(t, err) @@ -75,15 +63,11 @@ func TestProcessBanOperationParallelWithContext(t *testing.T) { // TestProcessUnbanOperationParallelWithContext tests the wrapper with context func TestProcessUnbanOperationParallelWithContext(t *testing.T) { - // Save and restore original runner - originalRunner := fail2ban.GetRunner() - defer fail2ban.SetRunner(originalRunner) - mockRunner := fail2ban.NewMockRunner() - setupBasicMockResponses(mockRunner) + defer fail2ban.WithTestRunner(t, mockRunner)() + fail2ban.StandardMockSetup(mockRunner) mockRunner.SetResponse("fail2ban-client set sshd unbanip 192.168.1.1", []byte("1")) mockRunner.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.1", []byte("1")) - fail2ban.SetRunner(mockRunner) client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d") require.NoError(t, err) diff --git a/cmd/service.go b/cmd/service.go index 17bf0e4..f38e0d5 100644 --- a/cmd/service.go +++ b/cmd/service.go @@ -15,7 +15,11 @@ func ServiceCmd(config *Config) *cobra.Command { nil, func(_ *cobra.Command, args []string) error { // Validate service action argument - if err := RequireArguments(args, 1, "action required: start|stop|restart|status|reload|enable|disable"); err != nil { + if err := RequireArguments( + args, + 1, + "action required: start|stop|restart|status|reload|enable|disable", + ); err != nil { return HandleValidationError(err) } diff --git a/cmd/test_framework_additional_test.go b/cmd/test_framework_additional_test.go index 2f30f55..644cc71 100644 --- a/cmd/test_framework_additional_test.go +++ b/cmd/test_framework_additional_test.go @@ -110,18 +110,14 @@ func TestValidateConfigPath(t *testing.T) { // TestLogsWatchCmdCreation tests LogsWatchCmd creation func TestLogsWatchCmdCreation(t *testing.T) { - // Save and restore original runner - originalRunner := fail2ban.GetRunner() - defer fail2ban.SetRunner(originalRunner) - mockRunner := fail2ban.NewMockRunner() + defer fail2ban.WithTestRunner(t, mockRunner)() mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0")) mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0")) mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong")) mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong")) mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) - fail2ban.SetRunner(mockRunner) client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d") require.NoError(t, err) @@ -142,18 +138,14 @@ func TestLogsWatchCmdCreation(t *testing.T) { // TestGetLogLinesWithLimitAndContext_Function tests the function func TestGetLogLinesWithLimitAndContext_Function(t *testing.T) { - // Save and restore original runner - originalRunner := fail2ban.GetRunner() - defer fail2ban.SetRunner(originalRunner) - mockRunner := fail2ban.NewMockRunner() + defer fail2ban.WithTestRunner(t, mockRunner)() mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0")) mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0")) mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong")) mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong")) mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) - fail2ban.SetRunner(mockRunner) tmpDir := t.TempDir() oldLogDir := fail2ban.GetLogDir() diff --git a/fail2ban/ban_record_parser.go b/fail2ban/ban_record_parser.go index bfe55cf..fdcd095 100644 --- a/fail2ban/ban_record_parser.go +++ b/fail2ban/ban_record_parser.go @@ -95,6 +95,23 @@ func (btc *BoundedTimeCache) Size() int { return len(btc.cache) } +// ParseWithLayout parses a time string using the specified layout with caching. +// This method consolidates the cache-lookup-parse-store pattern used across +// different time parsing caches in the codebase. +func (btc *BoundedTimeCache) ParseWithLayout(timeStr, layout string) (time.Time, error) { + // Fast path: check cache + if cached, ok := btc.Load(timeStr); ok { + return cached, nil + } + + // Parse and cache - only cache successful parses + t, err := time.Parse(layout, timeStr) + if err == nil { + btc.Store(timeStr, t) + } + return t, err +} + // BanRecordParser provides high-performance parsing of ban records type BanRecordParser struct { // Pools for zero-allocation parsing (goroutine-safe) @@ -167,17 +184,7 @@ func NewFastTimeCache(layout string) (*FastTimeCache, error) { // ParseTimeOptimized parses time with minimal allocations func (ftc *FastTimeCache) ParseTimeOptimized(timeStr string) (time.Time, error) { - // Fast path: check cache - if cached, ok := ftc.parseCache.Load(timeStr); ok { - return cached, nil - } - - // Parse and cache - only cache successful parses - t, err := time.Parse(ftc.layout, timeStr) - if err == nil { - ftc.parseCache.Store(timeStr, t) - } - return t, err + return ftc.parseCache.ParseWithLayout(timeStr, ftc.layout) } // BuildTimeStringOptimized builds time string with zero allocations using byte buffer diff --git a/fail2ban/client_withcontext_test.go b/fail2ban/client_withcontext_test.go index c0a18e6..ee26295 100644 --- a/fail2ban/client_withcontext_test.go +++ b/fail2ban/client_withcontext_test.go @@ -12,17 +12,6 @@ import ( "github.com/stretchr/testify/require" ) -// setupBasicMockResponses sets up the basic responses needed for client initialization -func setupBasicMockResponses(m *MockRunner) { - m.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0")) - m.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0")) - m.SetResponse("fail2ban-client ping", []byte("Server replied: pong")) - m.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong")) - // NewClient calls fetchJailsWithContext which runs status - m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 2\n`- Jail list: sshd, apache")) - m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 2\n`- Jail list: sshd, apache")) -} - // TestListJailsWithContext tests jail listing with context func TestListJailsWithContext(t *testing.T) { tests := []struct { @@ -35,11 +24,11 @@ func TestListJailsWithContext(t *testing.T) { { name: "successful jail listing", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) }, timeout: 5 * time.Second, expectError: false, - expectJails: []string{"sshd", "apache"}, // From setupBasicMockResponses + expectJails: []string{"sshd", "apache"}, // From StandardMockSetup }, } @@ -83,7 +72,7 @@ func TestStatusAllWithContext(t *testing.T) { { name: "successful status all", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) }, @@ -94,7 +83,7 @@ func TestStatusAllWithContext(t *testing.T) { { name: "context timeout", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) }, @@ -145,7 +134,7 @@ func TestStatusJailWithContext(t *testing.T) { name: "successful status jail", jail: "sshd", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) m.SetResponse( "fail2ban-client status sshd", []byte("Status for the jail: sshd\n|- Filter\n`- Currently banned: 0"), @@ -163,7 +152,7 @@ func TestStatusJailWithContext(t *testing.T) { name: "invalid jail name", jail: "invalid@jail", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) // Validation will fail before command execution }, timeout: 5 * time.Second, @@ -173,7 +162,7 @@ func TestStatusJailWithContext(t *testing.T) { name: "context timeout", jail: "sshd", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) m.SetResponse( "fail2ban-client status sshd", []byte("Status for the jail: sshd\n|- Filter\n`- Currently banned: 0"), @@ -234,7 +223,7 @@ func TestUnbanIPWithContext(t *testing.T) { ip: "192.168.1.100", jail: "sshd", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) m.SetResponse("fail2ban-client set sshd unbanip 192.168.1.100", []byte("0")) m.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.100", []byte("0")) }, @@ -247,7 +236,7 @@ func TestUnbanIPWithContext(t *testing.T) { ip: "192.168.1.100", jail: "sshd", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) m.SetResponse("fail2ban-client set sshd unbanip 192.168.1.100", []byte("1")) m.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.100", []byte("1")) }, @@ -260,7 +249,7 @@ func TestUnbanIPWithContext(t *testing.T) { ip: "invalid-ip", jail: "sshd", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) // Validation will fail before command execution }, timeout: 5 * time.Second, @@ -271,7 +260,7 @@ func TestUnbanIPWithContext(t *testing.T) { ip: "192.168.1.100", jail: "invalid@jail", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) // Validation will fail before command execution }, timeout: 5 * time.Second, @@ -282,7 +271,7 @@ func TestUnbanIPWithContext(t *testing.T) { ip: "192.168.1.100", jail: "sshd", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) m.SetResponse("fail2ban-client set sshd unbanip 192.168.1.100", []byte("0")) m.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.100", []byte("0")) }, @@ -332,7 +321,7 @@ func TestListFiltersWithContext(t *testing.T) { { name: "successful filter listing", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) // Mock responses not needed - uses file system }, setupEnv: func() { @@ -345,7 +334,7 @@ func TestListFiltersWithContext(t *testing.T) { { name: "context timeout", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) // Not applicable for file system operation }, setupEnv: func() { @@ -412,7 +401,7 @@ logpath = /var/log/auth.log name: "successful filter test", filter: "sshd", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) m.SetResponse( "fail2ban-regex /var/log/auth.log "+filepath.Join(tmpDir, "sshd.conf"), []byte("Success: 0 matches"), @@ -429,7 +418,7 @@ logpath = /var/log/auth.log name: "invalid filter name", filter: "invalid@filter", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) // Validation will fail before command execution }, timeout: 5 * time.Second, @@ -439,7 +428,7 @@ logpath = /var/log/auth.log name: "context timeout", filter: "sshd", setupMock: func(m *MockRunner) { - setupBasicMockResponses(m) + StandardMockSetup(m) m.SetResponse( "fail2ban-regex /var/log/auth.log "+filepath.Join(tmpDir, "sshd.conf"), []byte("Success: 0 matches"), @@ -485,7 +474,7 @@ logpath = /var/log/auth.log // TestWithContextCancellation tests that all WithContext functions respect cancellation func TestWithContextCancellation(t *testing.T) { mock := NewMockRunner() - setupBasicMockResponses(mock) + StandardMockSetup(mock) SetRunner(mock) client, err := NewClient("/var/log", "/etc/fail2ban/filter.d") @@ -520,7 +509,7 @@ func TestWithContextCancellation(t *testing.T) { // TestWithContextDeadline tests that all WithContext functions respect deadlines func TestWithContextDeadline(t *testing.T) { mock := NewMockRunner() - setupBasicMockResponses(mock) + StandardMockSetup(mock) SetRunner(mock) client, err := NewClient("/var/log", "/etc/fail2ban/filter.d") @@ -576,7 +565,7 @@ func TestWithContextDeadline(t *testing.T) { // TestWithContextValidation tests that validation happens before context usage func TestWithContextValidation(t *testing.T) { mock := NewMockRunner() - setupBasicMockResponses(mock) + StandardMockSetup(mock) SetRunner(mock) client, err := NewClient("/var/log", "/etc/fail2ban/filter.d") diff --git a/fail2ban/context_helpers.go b/fail2ban/context_helpers.go index c57171e..6eca013 100644 --- a/fail2ban/context_helpers.go +++ b/fail2ban/context_helpers.go @@ -2,30 +2,60 @@ package fail2ban import "context" -// ContextWrappers provides a helper to automatically generate WithContext method wrappers. -// This eliminates the need for duplicate WithContext implementations across different Client types. -// Usage: embed this in your Client struct and call DefineContextWrappers to get automatic context support. -type ContextWrappers struct{} +// Context Wrapper Pattern +// +// This package provides generic helper functions for creating WithContext method wrappers. +// These helpers eliminate boilerplate for methods that simply need context propagation +// without additional logic. +// +// For simple methods (no validation, direct delegation to non-context version): +// +// func (c *Client) ListJailsWithContext(ctx context.Context) ([]string, error) { +// return wrapWithContext0(c.ListJails)(ctx) +// } +// +// For complex methods (validation, custom logic, different execution paths): +// Implement the WithContext version directly with the required logic. +// +// Note: Code generation was evaluated but determined unnecessary because: +// - Only 2 methods use the simple wrapper pattern +// - Most WithContext methods require custom validation/logic +// - The generic helpers below already solve the simple cases cleanly // Helper functions to reduce boilerplate in WithContext implementations // wrapWithContext0 wraps a function with no parameters to accept a context parameter. +// It checks for context cancellation before invoking the underlying function. func wrapWithContext0[T any](fn func() (T, error)) func(context.Context) (T, error) { - return func(_ context.Context) (T, error) { + return func(ctx context.Context) (T, error) { + if err := ctx.Err(); err != nil { + var zero T + return zero, err + } return fn() } } // wrapWithContext1 wraps a function with one parameter to accept a context parameter. +// It checks for context cancellation before invoking the underlying function. func wrapWithContext1[T any, A any](fn func(A) (T, error)) func(context.Context, A) (T, error) { - return func(_ context.Context, a A) (T, error) { + return func(ctx context.Context, a A) (T, error) { + if err := ctx.Err(); err != nil { + var zero T + return zero, err + } return fn(a) } } // wrapWithContext2 wraps a function with two parameters to accept a context parameter. +// It checks for context cancellation before invoking the underlying function. func wrapWithContext2[T any, A any, B any](fn func(A, B) (T, error)) func(context.Context, A, B) (T, error) { - return func(_ context.Context, a A, b B) (T, error) { + return func(ctx context.Context, a A, b B) (T, error) { + if err := ctx.Err(); err != nil { + var zero T + return zero, err + } return fn(a, b) } } diff --git a/fail2ban/coverage_boost_test.go b/fail2ban/coverage_boost_test.go index b3d5a59..234c339 100644 --- a/fail2ban/coverage_boost_test.go +++ b/fail2ban/coverage_boost_test.go @@ -54,8 +54,7 @@ func TestRunnerFunctions(t *testing.T) { // Set up mock runner for testing mockRunner := NewMockRunner() mockRunner.SetResponse("test-cmd arg1", []byte("test output")) - SetRunner(mockRunner) - defer SetRunner(&OSRunner{}) // Restore real runner + defer WithTestRunner(t, mockRunner)() // Test RunnerCombinedOutput output, err := RunnerCombinedOutput("test-cmd", "arg1") diff --git a/fail2ban/fail2ban.go b/fail2ban/fail2ban.go index 5d5fe5b..a970cac 100644 --- a/fail2ban/fail2ban.go +++ b/fail2ban/fail2ban.go @@ -52,6 +52,18 @@ func SetFilterDir(dir string) { // OSRunner runs commands locally. type OSRunner struct{} +// validateCommandExecution validates command name and arguments before execution. +// This helper consolidates the duplicate validation pattern used in command execution methods. +func validateCommandExecution(ctx context.Context, name string, args []string) error { + if err := CachedValidateCommand(ctx, name); err != nil { + return fmt.Errorf(shared.ErrCommandValidationFailed, err) + } + if err := ValidateArgumentsWithContext(ctx, args); err != nil { + return fmt.Errorf(shared.ErrArgumentValidationFailed, err) + } + return nil +} + // CombinedOutput executes a command without sudo. func (r *OSRunner) CombinedOutput(name string, args ...string) ([]byte, error) { return r.CombinedOutputWithContext(context.Background(), name, args...) @@ -59,13 +71,8 @@ func (r *OSRunner) CombinedOutput(name string, args ...string) ([]byte, error) { // CombinedOutputWithContext executes a command without sudo with context support. func (r *OSRunner) CombinedOutputWithContext(ctx context.Context, name string, args ...string) ([]byte, error) { - // Validate command for security - if err := CachedValidateCommand(ctx, name); err != nil { - return nil, fmt.Errorf(shared.ErrCommandValidationFailed, err) - } - // Validate arguments for security - if err := ValidateArgumentsWithContext(ctx, args); err != nil { - return nil, fmt.Errorf(shared.ErrArgumentValidationFailed, err) + if err := validateCommandExecution(ctx, name, args); err != nil { + return nil, err } return exec.CommandContext(ctx, name, args...).CombinedOutput() } @@ -77,13 +84,8 @@ func (r *OSRunner) CombinedOutputWithSudo(name string, args ...string) ([]byte, // CombinedOutputWithSudoContext executes a command with sudo if needed, with context support. func (r *OSRunner) CombinedOutputWithSudoContext(ctx context.Context, name string, args ...string) ([]byte, error) { - // Validate command for security - if err := CachedValidateCommand(ctx, name); err != nil { - return nil, fmt.Errorf(shared.ErrCommandValidationFailed, err) - } - // Validate arguments for security - if err := ValidateArgumentsWithContext(ctx, args); err != nil { - return nil, fmt.Errorf(shared.ErrArgumentValidationFailed, err) + if err := validateCommandExecution(ctx, name, args); err != nil { + return nil, err } checker := GetSudoChecker() @@ -158,30 +160,38 @@ func RunnerCombinedOutputWithSudo(name string, args ...string) ([]byte, error) { return output, err } +// runWithTimerContext is a helper that consolidates the common pattern of +// creating a timer, getting the runner, executing a command, and finishing the timer. +// This reduces code duplication between RunnerCombinedOutputWithContext and RunnerCombinedOutputWithSudoContext. +func runWithTimerContext( + ctx context.Context, + opName, name string, + args []string, + runFn func(Runner, context.Context, string, ...string) ([]byte, error), +) ([]byte, error) { + timer := NewTimedOperation(opName, name, args...) + runner := GetRunner() + output, err := runFn(runner, ctx, name, args...) + timer.FinishWithContext(ctx, err) + return output, err +} + // RunnerCombinedOutputWithContext invokes the runner for a command with context support. // RunnerCombinedOutputWithContext executes a command with context using the global runner. func RunnerCombinedOutputWithContext(ctx context.Context, name string, args ...string) ([]byte, error) { - timer := NewTimedOperation("RunnerCombinedOutputWithContext", name, args...) - - runner := GetRunner() - - output, err := runner.CombinedOutputWithContext(ctx, name, args...) - timer.FinishWithContext(ctx, err) - - return output, err + return runWithTimerContext(ctx, "RunnerCombinedOutputWithContext", name, args, + func(r Runner, c context.Context, n string, a ...string) ([]byte, error) { + return r.CombinedOutputWithContext(c, n, a...) + }) } // RunnerCombinedOutputWithSudoContext invokes the runner for a command with sudo and context support. // RunnerCombinedOutputWithSudoContext executes a command with sudo privileges and context using the global runner. func RunnerCombinedOutputWithSudoContext(ctx context.Context, name string, args ...string) ([]byte, error) { - timer := NewTimedOperation("RunnerCombinedOutputWithSudoContext", name, args...) - - runner := GetRunner() - - output, err := runner.CombinedOutputWithSudoContext(ctx, name, args...) - timer.FinishWithContext(ctx, err) - - return output, err + return runWithTimerContext(ctx, "RunnerCombinedOutputWithSudoContext", name, args, + func(r Runner, c context.Context, n string, a ...string) ([]byte, error) { + return r.CombinedOutputWithSudoContext(c, n, a...) + }) } // MockRunner is a simple mock for Runner, used in unit tests. @@ -274,6 +284,17 @@ func (m *MockRunner) CombinedOutputWithSudo(name string, args ...string) ([]byte return m.CombinedOutput(name, args...) } +// withContextCheck wraps an operation with context cancellation check. +// This helper consolidates the duplicate context cancellation pattern. +func withContextCheck(ctx context.Context, fn func() ([]byte, error)) ([]byte, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + return fn() +} + // SetResponse sets a response for a command. func (m *MockRunner) SetResponse(cmd string, response []byte) { m.mu.Lock() @@ -298,30 +319,50 @@ func (m *MockRunner) GetCalls() []string { return calls } +// SetupStandardResponses configures comprehensive standard responses for testing. +// This eliminates the need for repetitive SetResponse calls in individual tests. +func (m *MockRunner) SetupStandardResponses() { + StandardMockSetup(m) +} + +// SetupJailResponses configures responses for a specific jail. +// This is useful for tests that focus on a single jail's behavior. +func (m *MockRunner) SetupJailResponses(jail string) { + statusResponse := fmt.Sprintf("Status for the jail: %s\n|- Filter\n| |- Currently failed:\t0\n| "+ + "|- Total failed:\t5\n| `- File list:\t/var/log/auth.log\n`- Actions\n "+ + "|- Currently banned:\t1\n |- Total banned:\t2\n `- Banned IP list:\t192.168.1.100", jail) + + m.SetResponse(fmt.Sprintf("fail2ban-client status %s", jail), []byte(statusResponse)) + m.SetResponse(fmt.Sprintf("sudo fail2ban-client status %s", jail), []byte(statusResponse)) + + // Common ban/unban operations for the jail (use success status, not already-processed) + m.SetResponse(fmt.Sprintf("fail2ban-client set %s banip 192.168.1.100", jail), []byte(shared.Fail2BanStatusSuccess)) + m.SetResponse( + fmt.Sprintf("sudo fail2ban-client set %s banip 192.168.1.100", jail), + []byte(shared.Fail2BanStatusSuccess), + ) + m.SetResponse( + fmt.Sprintf("fail2ban-client set %s unbanip 192.168.1.100", jail), + []byte(shared.Fail2BanStatusSuccess), + ) + m.SetResponse( + fmt.Sprintf("sudo fail2ban-client set %s unbanip 192.168.1.100", jail), + []byte(shared.Fail2BanStatusSuccess), + ) +} + // CombinedOutputWithContext returns a mocked response or error for a command with context support. func (m *MockRunner) CombinedOutputWithContext(ctx context.Context, name string, args ...string) ([]byte, error) { - // Check if context is canceled - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - // Delegate to the non-context version for simplicity in tests - return m.CombinedOutput(name, args...) + return withContextCheck(ctx, func() ([]byte, error) { + return m.CombinedOutput(name, args...) + }) } // CombinedOutputWithSudoContext returns a mocked response for sudo commands with context support. func (m *MockRunner) CombinedOutputWithSudoContext(ctx context.Context, name string, args ...string) ([]byte, error) { - // Check if context is canceled - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - // Delegate to the non-context version for simplicity in tests - return m.CombinedOutputWithSudo(name, args...) + return withContextCheck(ctx, func() ([]byte, error) { + return m.CombinedOutputWithSudo(name, args...) + }) } func (c *RealClient) fetchJailsWithContext(ctx context.Context) ([]string, error) { @@ -487,8 +528,12 @@ func (c *RealClient) StatusJailWithContext(ctx context.Context, jail string) (st return string(out), err } -// BanIPWithContext bans an IP address in the specified jail with context support. -func (c *RealClient) BanIPWithContext(ctx context.Context, ip, jail string) (int, error) { +// executeIPActionWithContext executes a ban/unban IP action with validation and response parsing. +// It returns (0, nil) for success, (1, nil) if already processed, or an error. +func (c *RealClient) executeIPActionWithContext( + ctx context.Context, + ip, jail, action, errorTemplate string, +) (int, error) { if err := CachedValidateIP(ctx, ip); err != nil { return 0, err } @@ -497,10 +542,9 @@ func (c *RealClient) BanIPWithContext(ctx context.Context, ip, jail string) (int } currentRunner := GetRunner() - - out, err := currentRunner.CombinedOutputWithSudoContext(ctx, c.Path, shared.ActionSet, jail, shared.ActionBanIP, ip) + out, err := currentRunner.CombinedOutputWithSudoContext(ctx, c.Path, shared.ActionSet, jail, action, ip) if err != nil { - return 0, fmt.Errorf(shared.ErrFailedToBanIP, ip, jail, err) + return 0, fmt.Errorf(errorTemplate, ip, jail, err) } code := strings.TrimSpace(string(out)) if code == shared.Fail2BanStatusSuccess { @@ -512,36 +556,14 @@ func (c *RealClient) BanIPWithContext(ctx context.Context, ip, jail string) (int return 0, fmt.Errorf(shared.ErrUnexpectedOutput, code) } +// BanIPWithContext bans an IP address in the specified jail with context support. +func (c *RealClient) BanIPWithContext(ctx context.Context, ip, jail string) (int, error) { + return c.executeIPActionWithContext(ctx, ip, jail, shared.ActionBanIP, shared.ErrFailedToBanIP) +} + // UnbanIPWithContext unbans an IP address from the specified jail with context support. func (c *RealClient) UnbanIPWithContext(ctx context.Context, ip, jail string) (int, error) { - if err := CachedValidateIP(ctx, ip); err != nil { - return 0, err - } - if err := CachedValidateJail(ctx, jail); err != nil { - return 0, err - } - - currentRunner := GetRunner() - - out, err := currentRunner.CombinedOutputWithSudoContext( - ctx, - c.Path, - shared.ActionSet, - jail, - shared.ActionUnbanIP, - ip, - ) - if err != nil { - return 0, fmt.Errorf(shared.ErrFailedToUnbanIP, ip, jail, err) - } - code := strings.TrimSpace(string(out)) - if code == shared.Fail2BanStatusSuccess { - return 0, nil - } - if code == shared.Fail2BanStatusAlreadyProcessed { - return 1, nil - } - return 0, fmt.Errorf(shared.ErrUnexpectedOutput, code) + return c.executeIPActionWithContext(ctx, ip, jail, shared.ActionUnbanIP, shared.ErrFailedToUnbanIP) } // BannedInWithContext returns a list of jails where the specified IP address is currently banned with context support. diff --git a/fail2ban/fail2ban_concurrency_test.go b/fail2ban/fail2ban_concurrency_test.go index 3c0b535..af2b070 100644 --- a/fail2ban/fail2ban_concurrency_test.go +++ b/fail2ban/fail2ban_concurrency_test.go @@ -10,8 +10,7 @@ import ( // TestRunnerConcurrentAccess tests that concurrent access to the runner // is safe and doesn't cause race conditions. func TestRunnerConcurrentAccess(t *testing.T) { - original := GetRunner() - defer SetRunner(original) + defer WithTestRunner(t, GetRunner())() const numGoroutines = 100 const numOperations = 50 @@ -53,12 +52,9 @@ func TestRunnerConcurrentAccess(t *testing.T) { // TestRunnerCombinedOutputConcurrency tests that concurrent calls to // RunnerCombinedOutput are safe. func TestRunnerCombinedOutputConcurrency(t *testing.T) { - original := GetRunner() - defer SetRunner(original) - mockRunner := NewMockRunner() + defer WithTestRunner(t, mockRunner)() mockRunner.SetResponse("echo test", []byte("test output")) - SetRunner(mockRunner) const numGoroutines = 50 var wg sync.WaitGroup @@ -120,12 +116,10 @@ func TestRunnerCombinedOutputWithSudoConcurrency(t *testing.T) { // TestMixedConcurrentOperations tests mixed concurrent operations including // setting runners and executing commands. func TestMixedConcurrentOperations(t *testing.T) { - original := GetRunner() - defer SetRunner(original) - // Set up a single shared MockRunner with all required responses // This avoids race conditions from multiple goroutines setting different runners sharedMockRunner := NewMockRunner() + defer WithTestRunner(t, sharedMockRunner)() // Set up responses for valid fail2ban commands to avoid validation errors sharedMockRunner.SetResponse("fail2ban-client status", []byte("Status: OK")) @@ -135,8 +129,6 @@ func TestMixedConcurrentOperations(t *testing.T) { sharedMockRunner.SetResponse("sudo fail2ban-client status", []byte("Status: OK")) sharedMockRunner.SetResponse("sudo fail2ban-client -V", []byte("Version: 1.0.0")) - SetRunner(sharedMockRunner) - const numGoroutines = 30 var wg sync.WaitGroup @@ -203,8 +195,7 @@ func TestMixedConcurrentOperations(t *testing.T) { // TestRunnerManagerLockOrdering verifies there are no deadlocks in the // runner manager's lock ordering. func TestRunnerManagerLockOrdering(t *testing.T) { - original := GetRunner() - defer SetRunner(original) + defer WithTestRunner(t, GetRunner())() // This test specifically looks for deadlocks by creating scenarios // where multiple goroutines could potentially deadlock if locks @@ -245,13 +236,10 @@ func TestRunnerManagerLockOrdering(t *testing.T) { // TestRunnerStateConsistency verifies that the runner state remains // consistent across concurrent operations. func TestRunnerStateConsistency(t *testing.T) { - original := GetRunner() - defer SetRunner(original) - // Set initial state initialRunner := NewMockRunner() initialRunner.SetResponse("initial", []byte("initial response")) - SetRunner(initialRunner) + defer WithTestRunner(t, initialRunner)() const numReaders = 50 const numWriters = 10 diff --git a/fail2ban/fail2ban_fail2ban_test.go b/fail2ban/fail2ban_fail2ban_test.go index 7e1a791..aa687c1 100644 --- a/fail2ban/fail2ban_fail2ban_test.go +++ b/fail2ban/fail2ban_fail2ban_test.go @@ -182,7 +182,10 @@ func TestBanIP(t *testing.T) { fmt.Errorf("command failed"), ) } else { - mock.SetResponse(fmt.Sprintf("sudo fail2ban-client set %s banip %s", tt.jail, tt.ip), []byte(tt.mockResponse)) + mock.SetResponse( + fmt.Sprintf("sudo fail2ban-client set %s banip %s", tt.jail, tt.ip), + []byte(tt.mockResponse), + ) } client, err := NewClient(shared.DefaultLogDir, shared.DefaultFilterDir) diff --git a/fail2ban/gzip_detection.go b/fail2ban/gzip_detection.go index e377958..9c519da 100644 --- a/fail2ban/gzip_detection.go +++ b/fail2ban/gzip_detection.go @@ -11,6 +11,32 @@ import ( "github.com/ivuorinen/f2b/shared" ) +// safeCloseFile closes a file and logs any error. +// This helper consolidates the duplicate close-and-log pattern. +func safeCloseFile(f *os.File, path string) { + if f == nil { + return + } + if closeErr := f.Close(); closeErr != nil { + getLogger().WithError(closeErr). + WithField(shared.LogFieldFile, path). + Warn("Failed to close file") + } +} + +// safeCloseReader closes a reader and logs any error. +// This helper consolidates the duplicate close-and-log pattern for io.ReadCloser. +func safeCloseReader(r io.ReadCloser, path string) { + if r == nil { + return + } + if closeErr := r.Close(); closeErr != nil { + getLogger().WithError(closeErr). + WithField(shared.LogFieldFile, path). + Warn("Failed to close reader") + } +} + // GzipDetector provides utilities for detecting and handling gzip-compressed files type GzipDetector struct{} @@ -38,13 +64,7 @@ func (gd *GzipDetector) hasGzipMagicBytes(path string) (bool, error) { if err != nil { return false, err } - defer func() { - if closeErr := f.Close(); closeErr != nil { - getLogger().WithError(closeErr). - WithField(shared.LogFieldFile, path). - Warn("Failed to close file in gzip magic byte check") - } - }() + defer safeCloseFile(f, path) var magic [2]byte n, err := f.Read(magic[:]) @@ -70,11 +90,7 @@ func (gd *GzipDetector) OpenGzipAwareReader(path string) (io.ReadCloser, error) isGzip, err := gd.IsGzipFile(path) if err != nil { - if closeErr := f.Close(); closeErr != nil { - getLogger().WithError(closeErr). - WithField(shared.LogFieldFile, path). - Warn("Failed to close file during error handling") - } + safeCloseFile(f, path) return nil, err } @@ -82,21 +98,13 @@ func (gd *GzipDetector) OpenGzipAwareReader(path string) (io.ReadCloser, error) // For gzip files, we need to position at the beginning and create gzip reader _, err = f.Seek(0, io.SeekStart) if err != nil { - if closeErr := f.Close(); closeErr != nil { - getLogger().WithError(closeErr). - WithField(shared.LogFieldFile, path). - Warn("Failed to close file during seek error handling") - } + safeCloseFile(f, path) return nil, err } gz, err := gzip.NewReader(f) if err != nil { - if closeErr := f.Close(); closeErr != nil { - getLogger().WithError(closeErr). - WithField(shared.LogFieldFile, path). - Warn("Failed to close file during gzip reader error handling") - } + safeCloseFile(f, path) return nil, err } @@ -128,11 +136,7 @@ func (gd *GzipDetector) CreateGzipAwareScannerWithBuffer(path string, maxLineSiz } cleanup := func() { - if err := reader.Close(); err != nil { - getLogger().WithError(err). - WithField(shared.LogFieldFile, path). - Warn("Failed to close reader during cleanup") - } + safeCloseReader(reader, path) } return scanner, cleanup, nil diff --git a/fail2ban/helpers.go b/fail2ban/helpers.go index 54feda8..5997f5d 100644 --- a/fail2ban/helpers.go +++ b/fail2ban/helpers.go @@ -792,20 +792,23 @@ func ValidateLogPath(ctx context.Context, path string, logDir string) (string, e return ValidatePathWithSecurity(path, config) } +// validateClientPath is a generic helper for client path validation. +// It reduces duplication between ValidateClientLogPath and ValidateClientFilterPath. +func validateClientPath(ctx context.Context, path string, configFn func() PathSecurityConfig) (string, error) { + _ = ctx // Context not currently used by ValidatePathWithSecurity + return ValidatePathWithSecurity(path, configFn()) +} + // ValidateClientLogPath validates log directory path for client initialization // Context parameter accepted for API consistency but not currently used func ValidateClientLogPath(ctx context.Context, logDir string) (string, error) { - _ = ctx // Context not currently used by ValidatePathWithSecurity - config := CreateLogPathConfig() - return ValidatePathWithSecurity(logDir, config) + return validateClientPath(ctx, logDir, CreateLogPathConfig) } // ValidateClientFilterPath validates filter directory path for client initialization // Context parameter accepted for API consistency but not currently used func ValidateClientFilterPath(ctx context.Context, filterDir string) (string, error) { - _ = ctx // Context not currently used by ValidatePathWithSecurity - config := CreateFilterPathConfig() - return ValidatePathWithSecurity(filterDir, config) + return validateClientPath(ctx, filterDir, CreateFilterPathConfig) } // ValidateFilterName validates a filter name for path traversal prevention. diff --git a/fail2ban/helpers_validation_test.go b/fail2ban/helpers_validation_test.go index 034a4d8..ed29df9 100644 --- a/fail2ban/helpers_validation_test.go +++ b/fail2ban/helpers_validation_test.go @@ -75,18 +75,14 @@ func TestValidateFilterName(t *testing.T) { // TestGetLogLinesWrapper tests the GetLogLines wrapper function func TestGetLogLinesWrapper(t *testing.T) { - // Save and restore original runner - originalRunner := GetRunner() - defer SetRunner(originalRunner) - mockRunner := NewMockRunner() + defer WithTestRunner(t, mockRunner)() mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0")) mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0")) mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong")) mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong")) mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) - SetRunner(mockRunner) // Create temporary log directory tmpDir := t.TempDir() @@ -107,9 +103,7 @@ func TestGetLogLinesWrapper(t *testing.T) { // TestBanIPWithContext tests the BanIPWithContext function func TestBanIPWithContext(t *testing.T) { - // Save and restore original runner - originalRunner := GetRunner() - defer SetRunner(originalRunner) + defer WithTestRunner(t, GetRunner())() tests := []struct { name string @@ -160,18 +154,14 @@ func TestBanIPWithContext(t *testing.T) { // TestGetLogLinesWithLimitAndContext tests the GetLogLinesWithLimitAndContext function func TestGetLogLinesWithLimitAndContext(t *testing.T) { - // Save and restore original runner - originalRunner := GetRunner() - defer SetRunner(originalRunner) - mockRunner := NewMockRunner() + defer WithTestRunner(t, mockRunner)() mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0")) mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0")) mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong")) mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong")) mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) - SetRunner(mockRunner) // Create temporary log directory tmpDir := t.TempDir() diff --git a/fail2ban/logrus_adapter.go b/fail2ban/logrus_adapter.go index c22fdea..3fb6a3d 100644 --- a/fail2ban/logrus_adapter.go +++ b/fail2ban/logrus_adapter.go @@ -2,14 +2,47 @@ package fail2ban import "github.com/sirupsen/logrus" -// logrusAdapter wraps logrus to implement our decoupled LoggerInterface -type logrusAdapter struct { +// loggerCore provides common logging methods that delegate to a logrus.Entry. +// This type is embedded in both logrusAdapter and logrusEntryAdapter to +// eliminate duplicate method implementations. +type loggerCore struct { entry *logrus.Entry } -// logrusEntryAdapter wraps logrus.Entry to implement LoggerEntry +// Debug logs a debug-level message. +func (c *loggerCore) Debug(args ...interface{}) { c.entry.Debug(args...) } + +// Info logs an info-level message. +func (c *loggerCore) Info(args ...interface{}) { c.entry.Info(args...) } + +// Warn logs a warning-level message. +func (c *loggerCore) Warn(args ...interface{}) { c.entry.Warn(args...) } + +// Error logs an error-level message. +func (c *loggerCore) Error(args ...interface{}) { c.entry.Error(args...) } + +// Debugf logs a formatted debug-level message. +func (c *loggerCore) Debugf(format string, args ...interface{}) { c.entry.Debugf(format, args...) } + +// Infof logs a formatted info-level message. +func (c *loggerCore) Infof(format string, args ...interface{}) { c.entry.Infof(format, args...) } + +// Warnf logs a formatted warning-level message. +func (c *loggerCore) Warnf(format string, args ...interface{}) { c.entry.Warnf(format, args...) } + +// Errorf logs a formatted error-level message. +func (c *loggerCore) Errorf(format string, args ...interface{}) { c.entry.Errorf(format, args...) } + +// logrusAdapter wraps logrus to implement our decoupled LoggerInterface. +// It embeds loggerCore to provide the 8 standard logging methods. +type logrusAdapter struct { + loggerCore // embeds Debug, Info, Warn, Error, Debugf, Infof, Warnf, Errorf +} + +// logrusEntryAdapter wraps logrus.Entry to implement LoggerEntry. +// It embeds loggerCore to provide the 8 standard logging methods. type logrusEntryAdapter struct { - entry *logrus.Entry + loggerCore // embeds Debug, Info, Warn, Error, Debugf, Infof, Warnf, Errorf } // Ensure logrusAdapter implements LoggerInterface @@ -23,117 +56,37 @@ func NewLogrusAdapter(logger *logrus.Logger) LoggerInterface { if logger == nil { logger = logrus.StandardLogger() } - return &logrusAdapter{entry: logrus.NewEntry(logger)} + return &logrusAdapter{loggerCore: loggerCore{entry: logrus.NewEntry(logger)}} } // WithField implements LoggerInterface func (l *logrusAdapter) WithField(key string, value interface{}) LoggerEntry { - return &logrusEntryAdapter{entry: l.entry.WithField(key, value)} + return &logrusEntryAdapter{loggerCore: loggerCore{entry: l.entry.WithField(key, value)}} } // WithFields implements LoggerInterface func (l *logrusAdapter) WithFields(fields Fields) LoggerEntry { - return &logrusEntryAdapter{entry: l.entry.WithFields(logrus.Fields(fields))} + return &logrusEntryAdapter{loggerCore: loggerCore{entry: l.entry.WithFields(logrus.Fields(fields))}} } // WithError implements LoggerInterface func (l *logrusAdapter) WithError(err error) LoggerEntry { - return &logrusEntryAdapter{entry: l.entry.WithError(err)} -} - -// Debug implements LoggerInterface -func (l *logrusAdapter) Debug(args ...interface{}) { - l.entry.Debug(args...) -} - -// Info implements LoggerInterface -func (l *logrusAdapter) Info(args ...interface{}) { - l.entry.Info(args...) -} - -// Warn implements LoggerInterface -func (l *logrusAdapter) Warn(args ...interface{}) { - l.entry.Warn(args...) -} - -// Error implements LoggerInterface -func (l *logrusAdapter) Error(args ...interface{}) { - l.entry.Error(args...) -} - -// Debugf implements LoggerInterface -func (l *logrusAdapter) Debugf(format string, args ...interface{}) { - l.entry.Debugf(format, args...) -} - -// Infof implements LoggerInterface -func (l *logrusAdapter) Infof(format string, args ...interface{}) { - l.entry.Infof(format, args...) -} - -// Warnf implements LoggerInterface -func (l *logrusAdapter) Warnf(format string, args ...interface{}) { - l.entry.Warnf(format, args...) -} - -// Errorf implements LoggerInterface -func (l *logrusAdapter) Errorf(format string, args ...interface{}) { - l.entry.Errorf(format, args...) + return &logrusEntryAdapter{loggerCore: loggerCore{entry: l.entry.WithError(err)}} } // LoggerEntry implementation for logrusEntryAdapter // WithField implements LoggerEntry func (e *logrusEntryAdapter) WithField(key string, value interface{}) LoggerEntry { - return &logrusEntryAdapter{entry: e.entry.WithField(key, value)} + return &logrusEntryAdapter{loggerCore: loggerCore{entry: e.entry.WithField(key, value)}} } // WithFields implements LoggerEntry func (e *logrusEntryAdapter) WithFields(fields Fields) LoggerEntry { - return &logrusEntryAdapter{entry: e.entry.WithFields(logrus.Fields(fields))} + return &logrusEntryAdapter{loggerCore: loggerCore{entry: e.entry.WithFields(logrus.Fields(fields))}} } // WithError implements LoggerEntry func (e *logrusEntryAdapter) WithError(err error) LoggerEntry { - return &logrusEntryAdapter{entry: e.entry.WithError(err)} -} - -// Debug implements LoggerEntry -func (e *logrusEntryAdapter) Debug(args ...interface{}) { - e.entry.Debug(args...) -} - -// Info implements LoggerEntry -func (e *logrusEntryAdapter) Info(args ...interface{}) { - e.entry.Info(args...) -} - -// Warn implements LoggerEntry -func (e *logrusEntryAdapter) Warn(args ...interface{}) { - e.entry.Warn(args...) -} - -// Error implements LoggerEntry -func (e *logrusEntryAdapter) Error(args ...interface{}) { - e.entry.Error(args...) -} - -// Debugf implements LoggerEntry -func (e *logrusEntryAdapter) Debugf(format string, args ...interface{}) { - e.entry.Debugf(format, args...) -} - -// Infof implements LoggerEntry -func (e *logrusEntryAdapter) Infof(format string, args ...interface{}) { - e.entry.Infof(format, args...) -} - -// Warnf implements LoggerEntry -func (e *logrusEntryAdapter) Warnf(format string, args ...interface{}) { - e.entry.Warnf(format, args...) -} - -// Errorf implements LoggerEntry -func (e *logrusEntryAdapter) Errorf(format string, args ...interface{}) { - e.entry.Errorf(format, args...) + return &logrusEntryAdapter{loggerCore: loggerCore{entry: e.entry.WithError(err)}} } diff --git a/fail2ban/logs.go b/fail2ban/logs.go index 18578cd..46f88f1 100644 --- a/fail2ban/logs.go +++ b/fail2ban/logs.go @@ -377,11 +377,7 @@ func readLogFile(path string) ([]byte, error) { if err != nil { return nil, err } - defer func() { - if cerr := reader.Close(); cerr != nil { - getLogger().WithError(cerr).Error("failed to close log file") - } - }() + defer safeCloseReader(reader, cleanPath) return io.ReadAll(reader) } diff --git a/fail2ban/mock.go b/fail2ban/mock.go index 2d61e0e..dbdbc18 100644 --- a/fail2ban/mock.go +++ b/fail2ban/mock.go @@ -9,6 +9,17 @@ import ( "time" ) +// setNestedMapValue sets a value in a nested map[string]map[string]T structure with mutex protection. +// It initializes the inner map if nil. +func setNestedMapValue[T any](mu *sync.Mutex, mp map[string]map[string]T, jail, ip string, value T) { + mu.Lock() + defer mu.Unlock() + if mp[jail] == nil { + mp[jail] = make(map[string]T) + } + mp[jail][ip] = value +} + // MockClient is a stateful, thread-safe mock implementation of the Client interface for testing. type MockClient struct { mu sync.Mutex @@ -286,42 +297,22 @@ func (m *MockClient) Reset() { // SetBanError configures an error to return for BanIP(ip, jail). func (m *MockClient) SetBanError(jail, ip string, err error) { - m.mu.Lock() - defer m.mu.Unlock() - if m.BanErrors[jail] == nil { - m.BanErrors[jail] = make(map[string]error) - } - m.BanErrors[jail][ip] = err + setNestedMapValue(&m.mu, m.BanErrors, jail, ip, err) } // SetBanResult configures a result code to return for BanIP(ip, jail). func (m *MockClient) SetBanResult(jail, ip string, result int) { - m.mu.Lock() - defer m.mu.Unlock() - if m.BanResults[jail] == nil { - m.BanResults[jail] = make(map[string]int) - } - m.BanResults[jail][ip] = result + setNestedMapValue(&m.mu, m.BanResults, jail, ip, result) } // SetUnbanError configures an error to return for UnbanIP(ip, jail). func (m *MockClient) SetUnbanError(jail, ip string, err error) { - m.mu.Lock() - defer m.mu.Unlock() - if m.UnbanErrors[jail] == nil { - m.UnbanErrors[jail] = make(map[string]error) - } - m.UnbanErrors[jail][ip] = err + setNestedMapValue(&m.mu, m.UnbanErrors, jail, ip, err) } // SetUnbanResult configures a result code to return for UnbanIP(ip, jail). func (m *MockClient) SetUnbanResult(jail, ip string, result int) { - m.mu.Lock() - defer m.mu.Unlock() - if m.UnbanResults[jail] == nil { - m.UnbanResults[jail] = make(map[string]int) - } - m.UnbanResults[jail][ip] = result + setNestedMapValue(&m.mu, m.UnbanResults, jail, ip, result) } // SetStatusJailData configures the status data for a specific jail. diff --git a/fail2ban/test_helpers.go b/fail2ban/test_helpers.go index cd37a87..009cb35 100644 --- a/fail2ban/test_helpers.go +++ b/fail2ban/test_helpers.go @@ -262,6 +262,24 @@ func assertContainsText(t *testing.T, lines []string, text string) { t.Errorf("Expected to find '%s' in results", text) } +// WithTestRunner sets a test runner and returns a cleanup function. +// Usage: defer fail2ban.WithTestRunner(t, mockRunner)() +func WithTestRunner(t TestingInterface, runner Runner) func() { + t.Helper() + original := GetRunner() + SetRunner(runner) + return func() { SetRunner(original) } +} + +// WithTestSudoChecker sets a test sudo checker and returns a cleanup function. +// Usage: defer fail2ban.WithTestSudoChecker(t, mockChecker)() +func WithTestSudoChecker(t TestingInterface, checker SudoChecker) func() { + t.Helper() + original := GetSudoChecker() + SetSudoChecker(checker) + return func() { SetSudoChecker(original) } +} + // StandardMockSetup configures comprehensive standard responses for MockRunner // This eliminates the need for repetitive SetResponse calls in individual tests func StandardMockSetup(mockRunner *MockRunner) { diff --git a/fail2ban/time_parser.go b/fail2ban/time_parser.go index 8832874..243b8f4 100644 --- a/fail2ban/time_parser.go +++ b/fail2ban/time_parser.go @@ -36,17 +36,7 @@ func NewTimeParsingCache(layout string) (*TimeParsingCache, error) { // ParseTime parses a time string with bounded caching for performance func (tpc *TimeParsingCache) ParseTime(timeStr string) (time.Time, error) { - // Check cache first - if cached, ok := tpc.parseCache.Load(timeStr); ok { - return cached, nil - } - - // Parse and cache - t, err := time.Parse(tpc.layout, timeStr) - if err == nil { - tpc.parseCache.Store(timeStr, t) - } - return t, err + return tpc.parseCache.ParseWithLayout(timeStr, tpc.layout) } // BuildTimeString efficiently builds a time string from date and time components diff --git a/fail2ban/types.go b/fail2ban/types.go index 4153654..e058bb8 100644 --- a/fail2ban/types.go +++ b/fail2ban/types.go @@ -33,19 +33,10 @@ type LoggerEntry interface { Errorf(format string, args ...interface{}) } -// LoggerInterface defines the top-level logging interface (decoupled from logrus) +// LoggerInterface defines the top-level logging interface (decoupled from logrus). +// It embeds LoggerEntry since both interfaces share the same method signatures. type LoggerInterface interface { - WithField(key string, value interface{}) LoggerEntry - WithFields(fields Fields) LoggerEntry - WithError(err error) LoggerEntry - Debug(args ...interface{}) - Info(args ...interface{}) - Warn(args ...interface{}) - Error(args ...interface{}) - Debugf(format string, args ...interface{}) - Infof(format string, args ...interface{}) - Warnf(format string, args ...interface{}) - Errorf(format string, args ...interface{}) + LoggerEntry } // LogCollectionConfig configures log line collection behavior