refactor: linting, simplification and fixes (#119)

* refactor: consolidate test helpers and reduce code duplication

- Fix prealloc lint issue in cmd_logswatch_test.go
- Add validateIPAndJails helper to consolidate IP/jail validation
- Add WithTestRunner/WithTestSudoChecker helpers for cleaner test setup
- Replace setupBasicMockResponses duplicates with StandardMockSetup
- Add SetupStandardResponses/SetupJailResponses to MockRunner
- Delegate cmd context helpers to fail2ban implementations
- Document context wrapper pattern in context_helpers.go
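The WithTestRunner helper described above follows the common save-replace-restore pattern. A minimal sketch of that pattern, with an illustrative Runner interface (the real one in the repo differs):

```go
// Sketch of the save-replace-restore pattern behind a helper like
// WithTestRunner; the Runner interface and both runners are illustrative.
package main

import "fmt"

type Runner interface {
	Run(cmd string) (string, error)
}

type realRunner struct{}

func (realRunner) Run(cmd string) (string, error) { return "real:" + cmd, nil }

type mockRunner struct{}

func (mockRunner) Run(cmd string) (string, error) { return "mock:" + cmd, nil }

var current Runner = realRunner{}

// WithTestRunner swaps in the mock and returns a cleanup func that
// restores the previous runner, suitable for `defer WithTestRunner(m)()`.
func WithTestRunner(m Runner) func() {
	prev := current
	current = m
	return func() { current = prev }
}

func main() {
	restore := WithTestRunner(mockRunner{})
	out, _ := current.Run("status")
	fmt.Println(out) // mock:status
	restore()
	out, _ = current.Run("status")
	fmt.Println(out) // real:status
}
```

Returning the restore function lets each test defer cleanup in one line instead of the manual GetRunner/SetRunner bookkeeping the diff removes.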

* refactor: consolidate duplicate code patterns across cmd and fail2ban packages

Add helper functions to reduce code duplication found by dupl:

- safeCloseFile/safeCloseReader: centralize file close error logging
- createTimeoutContext: consolidate timeout context creation pattern
- withContextCheck: wrap context cancellation checks
- recordOperationMetrics: unify metrics recording for commands/clients

Also includes Phase 1 consolidations:
- copyBuckets helper for metrics snapshots
- Table-driven context extraction in logging
- processWithValidation helper for IP processors

* refactor: consolidate LoggerInterface by embedding LoggerEntry

Both interfaces had identical method signatures. LoggerInterface now
embeds LoggerEntry to eliminate code duplication.
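A toy illustration of that embedding, with a deliberately tiny method set (the real LoggerEntry/LoggerInterface in the repo carry more methods):

```go
// LoggerInterface embeds LoggerEntry instead of redeclaring the same
// methods, so the two interfaces cannot drift apart. Names below mirror
// the commit; the method set is illustrative.
package main

import "fmt"

type LoggerEntry interface {
	Info(msg string)
}

type LoggerInterface interface {
	LoggerEntry // inherits Info without redeclaring it
	WithField(k, v string) LoggerEntry
}

type stdLogger struct{ fields map[string]string }

func (l *stdLogger) Info(msg string) { fmt.Println(msg, l.fields) }

func (l *stdLogger) WithField(k, v string) LoggerEntry {
	f := map[string]string{k: v}
	for key, val := range l.fields {
		f[key] = val
	}
	return &stdLogger{fields: f}
}

func main() {
	var log LoggerInterface = &stdLogger{}
	log.WithField("jail", "sshd").Info("banned")
}
```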

* refactor: consolidate test framework helpers and fix test patterns

- Add checkJSONFieldValue and failMissingJSONField helpers to reduce
  duplication in JSON assertion methods
- Add ParallelTimeout to default test config
- Fix test to use WithTestRunner inside test loop for proper mock scoping

* refactor: unify ban/unban operations with OperationType pattern

Introduce OperationType struct to consolidate duplicate ban/unban logic:
- Add ProcessOperation and ProcessOperationWithContext generic functions
- Add ProcessOperationParallel and ProcessOperationParallelWithContext
- Existing ProcessBan*/ProcessUnban* functions now delegate to generic versions
- Reduces ~120 lines of duplicate code between ban and unban operations
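The shape of the OperationType consolidation can be sketched with a toy Client; the real struct also carries a context-aware variant plus metrics and log-message keys:

```go
// Sketch of the OperationType pattern: operation-specific behavior lives
// in data, and one generic loop replaces the duplicated ban/unban loops.
package main

import "fmt"

type Client struct{}

func (Client) BanIP(ip, jail string) (int, error)   { return 1, nil }
func (Client) UnbanIP(ip, jail string) (int, error) { return 0, nil }

type OperationType struct {
	Message   string
	Operation func(c Client, ip, jail string) (int, error)
}

// Method expressions give us free-standing funcs with the right signature.
var BanOp = OperationType{"ban result", Client.BanIP}
var UnbanOp = OperationType{"unban result", Client.UnbanIP}

// ProcessOperation is the single generic loop; ProcessBanOperation and
// ProcessUnbanOperation become one-line delegations to it.
func ProcessOperation(c Client, ip string, jails []string, op OperationType) error {
	for _, jail := range jails {
		code, err := op.Operation(c, ip, jail)
		if err != nil {
			return err
		}
		fmt.Printf("%s ip=%s jail=%s code=%d\n", op.Message, ip, jail, code)
	}
	return nil
}

func main() {
	c := Client{}
	_ = ProcessOperation(c, "192.0.2.1", []string{"sshd"}, BanOp)
	_ = ProcessOperation(c, "192.0.2.1", []string{"sshd"}, UnbanOp)
}
```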

* refactor: consolidate time parsing cache pattern

Add ParseWithLayout method to BoundedTimeCache that consolidates the
cache-lookup-parse-store pattern. FastTimeCache and TimeParsingCache
now delegate to this method instead of duplicating the logic.

* refactor: consolidate command execution patterns in fail2ban

- Add validateCommandExecution helper for command/argument validation
- Add runWithTimerContext helper for timed runner operations
- Add executeIPActionWithContext to unify BanIP/UnbanIP implementations
- Reduces duplicate validation and execution boilerplate

* refactor: consolidate logrus adapter with embedded loggerCore

Introduce loggerCore type that provides the 8 standard logging methods
(Debug, Info, Warn, Error, Debugf, Infof, Warnf, Errorf). Both
logrusAdapter and logrusEntryAdapter now embed this type, eliminating
16 duplicate method implementations.
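Struct embedding makes this consolidation nearly free. A toy version with two of the eight methods (real signatures in the repo may differ):

```go
// loggerCore provides the shared methods once; both adapters embed it
// and pick up the whole method set, replacing duplicate implementations.
package main

import "fmt"

type loggerCore struct{ prefix string }

func (c loggerCore) Info(msg string) { fmt.Println(c.prefix, "INFO", msg) }

func (c loggerCore) Errorf(f string, a ...any) {
	fmt.Println(c.prefix, "ERROR", fmt.Sprintf(f, a...))
}

type logrusAdapter struct{ loggerCore }
type logrusEntryAdapter struct{ loggerCore }

func main() {
	a := logrusAdapter{loggerCore{"[logger]"}}
	e := logrusEntryAdapter{loggerCore{"[entry]"}}
	a.Info("up")        // promoted from loggerCore
	e.Errorf("code %d", 7)
}
```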

* refactor: consolidate path validation patterns

- Add validateConfigPathWithFallback helper in cmd/config_utils.go
  for the validate-or-fallback-with-logging pattern
- Add validateClientPath helper in fail2ban/helpers.go for client
  path validation delegation

* fix: add context cancellation checks to wrapper functions

- wrapWithContext0/1/2 now check ctx.Err() before invoking the wrapped function
- WithCommand now validates and trims empty command strings

* refactor: extract formatLatencyBuckets for deterministic metrics output

Add formatLatencyBuckets helper that writes latency bucket distribution
with sorted keys for deterministic output, eliminating duplicate
formatting code for command and client latency buckets.

* refactor: add generic setNestedMapValue helper for mock configuration

Add setNestedMapValue[T] generic helper that consolidates the repeated
pattern of mutex-protected nested map initialization and value setting
used by SetBanError, SetBanResult, SetUnbanError, and SetUnbanResult.
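A hypothetical shape for setNestedMapValue[T]; the real helper presumably lives on the MockRunner and uses its mutex, which this standalone sketch approximates with a package-level one:

```go
// One generic helper for the lock-init-set pattern shared by
// SetBanError/SetBanResult/SetUnbanError/SetUnbanResult.
package main

import (
	"fmt"
	"sync"
)

var mu sync.Mutex

// setNestedMapValue initializes the inner map on first use and stores
// the value under m[outer][inner], all under the mutex.
func setNestedMapValue[T any](m map[string]map[string]T, outer, inner string, v T) {
	mu.Lock()
	defer mu.Unlock()
	if m[outer] == nil {
		m[outer] = make(map[string]T)
	}
	m[outer][inner] = v
}

func main() {
	banResults := map[string]map[string]int{}
	banErrors := map[string]map[string]error{}
	// The same helper serves both value types via inference.
	setNestedMapValue(banResults, "sshd", "192.0.2.1", 1)
	setNestedMapValue(banErrors, "sshd", "192.0.2.1", nil)
	fmt.Println(banResults["sshd"]["192.0.2.1"]) // 1
}
```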

* fix: use cmd.Context() for signal propagation and correct mock status

- ExecuteIPCommand now uses cmd.Context() instead of context.Background()
  to inherit Cobra's signal cancellation
- MockRunner.SetupJailResponses uses shared.Fail2BanStatusSuccess ("0")
  instead of literal "1" for proper success path simulation

* fix: restore operation-specific log messages in ProcessOperationWithContext

Add back Logger.WithFields().Info(opType.Message) call that was lost
during refactoring. This restores the distinction between ban and unban
operation messages (shared.MsgBanResult vs shared.MsgUnbanResult).

* fix: return aggregated errors from parallel operations

Previously, errors from individual parallel operations were silently
swallowed: they were converted to status strings but never returned to callers.

Now processOperations collects all errors and returns them aggregated
via errors.Join, allowing callers to distinguish partial failures from
complete success while still receiving all results.

* fix: add input validation to processOperations before parallel execution

Validate IP and jail inputs at the start of processOperations() using
fail2ban.CachedValidateIP and CachedValidateJail. This prevents invalid
or malicious inputs (empty values, path traversal attempts, malformed
IPs) from reaching the operation functions. All validation errors are
aggregated and returned before any operations execute.
2026-01-25 19:07:45 +02:00
committed by GitHub
parent a668c4563e
commit 605f2b9580
33 changed files with 752 additions and 768 deletions

View File

@@ -301,7 +301,7 @@ func (m *MockLogsWatchClient) GetLogLines(jail, ip string) ([]string, error) {
logs = m.initialLogs
} else {
// Simulate new logs being added
logs = make([]string, len(m.initialLogs))
logs = make([]string, len(m.initialLogs), len(m.initialLogs)+1)
copy(logs, m.initialLogs)
logs = append(logs, fmt.Sprintf("new log line %d", m.callCount))
}

View File

@@ -88,7 +88,7 @@ func TestParallelOperationProcessor_EmptyJailsList(t *testing.T) {
}
func TestParallelOperationProcessor_ErrorHandling(t *testing.T) {
// Test error handling doesn't cause index issues
// Test error handling returns aggregated errors while still populating results
processor := NewParallelOperationProcessor(2)
// Mock client for testing
@@ -99,15 +99,16 @@ func TestParallelOperationProcessor_ErrorHandling(t *testing.T) {
results, err := processor.ProcessBanOperationParallel(mockClient, "192.168.1.100", jails)
if err != nil {
t.Fatalf("ProcessBanOperationParallel failed: %v", err)
// Errors should now be returned (aggregated)
if err == nil {
t.Error("Expected error for non-existent jails, got nil")
}
if len(results) != 2 {
t.Errorf("Expected 2 results, got %d", len(results))
}
// All results should have errors for non-existent jails
// All results should still be populated with error status
for i, result := range results {
if result.Jail == "" {
t.Errorf("Result %d has empty jail", i)

View File

@@ -149,9 +149,10 @@ func NewCommandTest(t *testing.T, commandName string) *CommandTestBuilder {
command: commandName,
args: make([]string, 0),
config: &Config{
Format: PlainFormat,
CommandTimeout: shared.DefaultCommandTimeout,
FileTimeout: shared.DefaultFileTimeout,
Format: PlainFormat,
CommandTimeout: shared.DefaultCommandTimeout,
FileTimeout: shared.DefaultFileTimeout,
ParallelTimeout: shared.DefaultParallelTimeout,
},
}
}
@@ -418,6 +419,20 @@ func (result *CommandTestResult) AssertExactOutput(expected string) *CommandTest
return result
}
// checkJSONFieldValue validates that a JSON field value matches the expected string.
func (result *CommandTestResult) checkJSONFieldValue(val interface{}, fieldName, expected string) {
result.t.Helper()
if fmt.Sprintf("%v", val) != expected {
result.t.Fatalf(shared.ErrTestJSONFieldMismatch, result.name, fieldName, expected, val)
}
}
// failMissingJSONField reports a missing JSON field with context.
func (result *CommandTestResult) failMissingJSONField(fieldName, context string) {
result.t.Helper()
result.t.Fatalf("%s: JSON field %q not found%s: %s", result.name, fieldName, context, result.Output)
}
// AssertJSONField validates a specific field in JSON output
func (result *CommandTestResult) AssertJSONField(fieldPath, expected string) *CommandTestResult {
result.t.Helper()
@@ -434,22 +449,18 @@ func (result *CommandTestResult) AssertJSONField(fieldPath, expected string) *Co
switch v := data.(type) {
case map[string]interface{}:
if val, ok := v[fieldName]; ok {
if fmt.Sprintf("%v", val) != expected {
result.t.Fatalf(shared.ErrTestJSONFieldMismatch, result.name, fieldName, expected, val)
}
result.checkJSONFieldValue(val, fieldName, expected)
} else {
result.t.Fatalf("%s: JSON field %q not found in output: %s", result.name, fieldName, result.Output)
result.failMissingJSONField(fieldName, " in output")
}
case []interface{}:
// Handle array case - look in first element
if len(v) > 0 {
if firstItem, ok := v[0].(map[string]interface{}); ok {
if val, ok := firstItem[fieldName]; ok {
if fmt.Sprintf("%v", val) != expected {
result.t.Fatalf(shared.ErrTestJSONFieldMismatch, result.name, fieldName, expected, val)
}
result.checkJSONFieldValue(val, fieldName, expected)
} else {
result.t.Fatalf("%s: JSON field %q not found in first array element: %s", result.name, fieldName, result.Output)
result.failMissingJSONField(fieldName, " in first array element")
}
} else {
result.t.Fatalf("%s: first array element is not an object in output: %s", result.name, result.Output)

View File

@@ -12,13 +12,9 @@ import (
// TestTestFilterCmdCreation tests TestFilterCmd command creation
func TestTestFilterCmdCreation(t *testing.T) {
// Save and restore original runner
originalRunner := fail2ban.GetRunner()
defer fail2ban.SetRunner(originalRunner)
mockRunner := fail2ban.NewMockRunner()
setupBasicMockResponses(mockRunner)
fail2ban.SetRunner(mockRunner)
defer fail2ban.WithTestRunner(t, mockRunner)()
fail2ban.StandardMockSetup(mockRunner)
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
require.NoError(t, err)
@@ -39,10 +35,6 @@ func TestTestFilterCmdCreation(t *testing.T) {
// TestTestFilterCmdExecution tests TestFilterCmd execution
func TestTestFilterCmdExecution(t *testing.T) {
// Save and restore original runner
originalRunner := fail2ban.GetRunner()
defer fail2ban.SetRunner(originalRunner)
tests := []struct {
name string
setupMock func(*fail2ban.MockRunner)
@@ -52,7 +44,7 @@ func TestTestFilterCmdExecution(t *testing.T) {
{
name: "successful filter test",
setupMock: func(m *fail2ban.MockRunner) {
setupBasicMockResponses(m)
fail2ban.StandardMockSetup(m)
m.SetResponse("fail2ban-client get sshd logpath", []byte("/var/log/auth.log"))
m.SetResponse("sudo fail2ban-client get sshd logpath", []byte("/var/log/auth.log"))
},
@@ -62,7 +54,7 @@ func TestTestFilterCmdExecution(t *testing.T) {
{
name: "no filter provided - lists available",
setupMock: func(m *fail2ban.MockRunner) {
setupBasicMockResponses(m)
fail2ban.StandardMockSetup(m)
// Mock ListFiltersWithContext response
},
args: []string{},
@@ -71,7 +63,7 @@ func TestTestFilterCmdExecution(t *testing.T) {
{
name: "invalid filter name",
setupMock: func(m *fail2ban.MockRunner) {
setupBasicMockResponses(m)
fail2ban.StandardMockSetup(m)
},
args: []string{"../../../etc/passwd"},
expectError: true,
@@ -81,8 +73,8 @@ func TestTestFilterCmdExecution(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockRunner := fail2ban.NewMockRunner()
defer fail2ban.WithTestRunner(t, mockRunner)()
tt.setupMock(mockRunner)
fail2ban.SetRunner(mockRunner)
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
require.NoError(t, err)

View File

@@ -164,6 +164,17 @@ func validateConfigPath(path, pathType string) (string, error) {
return absPath, nil
}
// validateConfigPathWithFallback validates a config path and returns the fallback if validation fails.
// This consolidates the common pattern of validate-or-fallback-with-logging used for config paths.
func validateConfigPathWithFallback(path, pathType, defaultPath, errorMsg string) string {
validated, err := validateConfigPath(path, pathType)
if err != nil {
Logger.WithError(err).WithField(shared.LogFieldPath, path).Error(errorMsg)
return defaultPath
}
return validated
}
// isReasonableSystemPath checks if a path is in a reasonable system location
func isReasonableSystemPath(path, pathType string) bool {
// Allow common system directories based on path type
@@ -195,28 +206,20 @@ func NewConfigFromEnv() Config {
if logDir == "" {
logDir = shared.DefaultLogDir
}
validatedLogDir, err := validateConfigPath(logDir, shared.PathTypeLog)
if err != nil {
Logger.WithError(err).WithField(shared.LogFieldPath, logDir).Error("Invalid log directory from environment")
validatedLogDir = shared.DefaultLogDir // Fallback to safe default
}
cfg.LogDir = validatedLogDir
cfg.LogDir = validateConfigPathWithFallback(
logDir, shared.PathTypeLog, shared.DefaultLogDir,
"Invalid log directory from environment",
)
// Get and validate filter directory
filterDir := os.Getenv("F2B_FILTER_DIR")
if filterDir == "" {
filterDir = shared.DefaultFilterDir
}
validatedFilterDir, err := validateConfigPath(filterDir, shared.PathTypeFilter)
if err != nil {
Logger.WithError(err).
WithField(shared.LogFieldPath, filterDir).
Error("Invalid filter directory from environment")
validatedFilterDir = shared.DefaultFilterDir // Fallback to safe default
}
cfg.FilterDir = validatedFilterDir
cfg.FilterDir = validateConfigPathWithFallback(
filterDir, shared.PathTypeFilter, shared.DefaultFilterDir,
"Invalid filter directory from environment",
)
// Configure timeouts from environment variables
cfg.CommandTimeout = parseTimeoutFromEnv("F2B_COMMAND_TIMEOUT", shared.DefaultCommandTimeout)

View File

@@ -17,6 +17,20 @@ import (
"github.com/ivuorinen/f2b/fail2ban"
)
// createTimeoutContext creates a context with the configured command timeout.
// This helper consolidates the duplicate timeout handling pattern.
// If base is nil, context.Background() is used.
func createTimeoutContext(base context.Context, config *Config) (context.Context, context.CancelFunc) {
if base == nil {
base = context.Background()
}
timeout := shared.DefaultCommandTimeout
if config != nil && config.CommandTimeout > 0 {
timeout = config.CommandTimeout
}
return context.WithTimeout(base, timeout)
}
// IsCI detects if we're running in a CI environment
func IsCI() bool {
return fail2ban.IsCI()
@@ -50,17 +64,8 @@ func NewContextualCommand(
// Get the contextual logger
logger := GetContextualLogger()
// Base on Cobra's context so signals/cancellations propagate
base := cmd.Context()
if base == nil {
base = context.Background()
}
// Create timeout context for the entire operation
timeout := shared.DefaultCommandTimeout
if config != nil && config.CommandTimeout > 0 {
timeout = config.CommandTimeout
}
ctx, cancel := context.WithTimeout(base, timeout)
// Create timeout context based on Cobra's context so signals/cancellations propagate
ctx, cancel := createTimeoutContext(cmd.Context(), config)
defer cancel()
// Extract command name from use string (first word)
@@ -388,22 +393,63 @@ type OperationResult struct {
Status string `json:"status"`
}
// ProcessBanOperation processes ban operations across multiple jails
func ProcessBanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) {
// OperationType defines a ban or unban operation with its associated metadata
type OperationType struct {
// MetricsType is the metrics key for this operation (e.g., shared.MetricsBan)
MetricsType string
// Message is the log message for this operation (e.g., shared.MsgBanResult)
Message string
// Operation is the function to execute without context
Operation func(client fail2ban.Client, ip, jail string) (int, error)
// OperationCtx is the function to execute with context
OperationCtx func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error)
}
// BanOperationType defines the ban operation
var BanOperationType = OperationType{
MetricsType: shared.MetricsBan,
Message: shared.MsgBanResult,
Operation: func(c fail2ban.Client, ip, jail string) (int, error) {
return c.BanIP(ip, jail)
},
OperationCtx: func(ctx context.Context, c fail2ban.Client, ip, jail string) (int, error) {
return c.BanIPWithContext(ctx, ip, jail)
},
}
// UnbanOperationType defines the unban operation
var UnbanOperationType = OperationType{
MetricsType: shared.MetricsUnban,
Message: shared.MsgUnbanResult,
Operation: func(c fail2ban.Client, ip, jail string) (int, error) {
return c.UnbanIP(ip, jail)
},
OperationCtx: func(ctx context.Context, c fail2ban.Client, ip, jail string) (int, error) {
return c.UnbanIPWithContext(ctx, ip, jail)
},
}
// ProcessOperation processes operations across multiple jails using the specified operation type
func ProcessOperation(
client fail2ban.Client,
ip string,
jails []string,
opType OperationType,
) ([]OperationResult, error) {
results := make([]OperationResult, 0, len(jails))
for _, jail := range jails {
code, err := client.BanIP(ip, jail)
code, err := opType.Operation(client, ip, jail)
if err != nil {
return nil, err
}
status := InterpretBanStatus(code, shared.MetricsBan)
status := InterpretBanStatus(code, opType.MetricsType)
Logger.WithFields(map[string]interface{}{
"ip": ip,
"jail": jail,
"status": status,
}).Info(shared.MsgBanResult)
}).Info(opType.Message)
results = append(results, OperationResult{
IP: ip,
@@ -415,6 +461,59 @@ func ProcessBanOperation(client fail2ban.Client, ip string, jails []string) ([]O
return results, nil
}
// ProcessOperationWithContext processes operations across multiple jails with timeout context
func ProcessOperationWithContext(
ctx context.Context,
client fail2ban.Client,
ip string,
jails []string,
opType OperationType,
) ([]OperationResult, error) {
logger := GetContextualLogger()
results := make([]OperationResult, 0, len(jails))
for _, jail := range jails {
// Add jail to context for this operation
jailCtx := WithJail(ctx, jail)
// Time the operation
start := time.Now()
code, err := opType.OperationCtx(jailCtx, client, ip, jail)
duration := time.Since(start)
if err != nil {
// Log the failed operation with timing
logger.LogBanOperation(jailCtx, opType.MetricsType, ip, jail, false, duration)
return nil, err
}
status := InterpretBanStatus(code, opType.MetricsType)
// Log the successful operation with timing
logger.LogBanOperation(jailCtx, opType.MetricsType, ip, jail, true, duration)
// Log the operation-specific message (ban vs unban)
Logger.WithFields(map[string]interface{}{
"ip": ip,
"jail": jail,
"status": status,
}).Info(opType.Message)
results = append(results, OperationResult{
IP: ip,
Jail: jail,
Status: status,
})
}
return results, nil
}
// ProcessBanOperation processes ban operations across multiple jails
func ProcessBanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) {
return ProcessOperation(client, ip, jails, BanOperationType)
}
// ProcessBanOperationWithContext processes ban operations across multiple jails with timeout context
func ProcessBanOperationWithContext(
ctx context.Context,
@@ -422,70 +521,12 @@ func ProcessBanOperationWithContext(
ip string,
jails []string,
) ([]OperationResult, error) {
logger := GetContextualLogger()
results := make([]OperationResult, 0, len(jails))
for _, jail := range jails {
// Add jail to context for this operation
jailCtx := WithJail(ctx, jail)
// Time the ban operation
start := time.Now()
code, err := client.BanIPWithContext(jailCtx, ip, jail)
duration := time.Since(start)
if err != nil {
// Log the failed operation with timing
logger.LogBanOperation(jailCtx, shared.MetricsBan, ip, jail, false, duration)
return nil, err
}
status := InterpretBanStatus(code, shared.MetricsBan)
// Log the successful operation with timing
logger.LogBanOperation(jailCtx, shared.MetricsBan, ip, jail, true, duration)
Logger.WithFields(map[string]interface{}{
"ip": ip,
"jail": jail,
"status": status,
}).Info(shared.MsgBanResult)
results = append(results, OperationResult{
IP: ip,
Jail: jail,
Status: status,
})
}
return results, nil
return ProcessOperationWithContext(ctx, client, ip, jails, BanOperationType)
}
// ProcessUnbanOperation processes unban operations across multiple jails
func ProcessUnbanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) {
results := make([]OperationResult, 0, len(jails))
for _, jail := range jails {
code, err := client.UnbanIP(ip, jail)
if err != nil {
return nil, err
}
status := InterpretBanStatus(code, shared.MetricsUnban)
Logger.WithFields(map[string]interface{}{
"ip": ip,
"jail": jail,
"status": status,
}).Info(shared.MsgUnbanResult)
results = append(results, OperationResult{
IP: ip,
Jail: jail,
Status: status,
})
}
return results, nil
return ProcessOperation(client, ip, jails, UnbanOperationType)
}
// ProcessUnbanOperationWithContext processes unban operations across multiple jails with timeout context
@@ -495,43 +536,7 @@ func ProcessUnbanOperationWithContext(
ip string,
jails []string,
) ([]OperationResult, error) {
logger := GetContextualLogger()
results := make([]OperationResult, 0, len(jails))
for _, jail := range jails {
// Add jail to context for this operation
jailCtx := WithJail(ctx, jail)
// Time the unban operation
start := time.Now()
code, err := client.UnbanIPWithContext(jailCtx, ip, jail)
duration := time.Since(start)
if err != nil {
// Log the failed operation with timing
logger.LogBanOperation(jailCtx, shared.MetricsUnban, ip, jail, false, duration)
return nil, err
}
status := InterpretBanStatus(code, shared.MetricsUnban)
// Log the successful operation with timing
logger.LogBanOperation(jailCtx, shared.MetricsUnban, ip, jail, true, duration)
Logger.WithFields(map[string]interface{}{
"ip": ip,
"jail": jail,
"status": status,
}).Info(shared.MsgUnbanResult)
results = append(results, OperationResult{
IP: ip,
Jail: jail,
Status: status,
})
}
return results, nil
return ProcessOperationWithContext(ctx, client, ip, jails, UnbanOperationType)
}
// Argument validation helpers

View File

@@ -12,9 +12,7 @@ import (
// TestProcessBanOperation tests the ProcessBanOperation function
func TestProcessBanOperation(t *testing.T) {
// Save and restore original runner
originalRunner := fail2ban.GetRunner()
defer fail2ban.SetRunner(originalRunner)
defer fail2ban.WithTestRunner(t, fail2ban.GetRunner())()
tests := []struct {
name string
@@ -27,7 +25,7 @@ func TestProcessBanOperation(t *testing.T) {
{
name: "successful ban single jail",
setupMock: func(m *fail2ban.MockRunner) {
setupBasicMockResponses(m)
fail2ban.StandardMockSetup(m)
m.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
m.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
},
@@ -39,7 +37,7 @@ func TestProcessBanOperation(t *testing.T) {
{
name: "successful ban multiple jails",
setupMock: func(m *fail2ban.MockRunner) {
setupBasicMockResponses(m)
fail2ban.StandardMockSetup(m)
m.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
m.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
m.SetResponse("fail2ban-client set apache banip 192.168.1.1", []byte("1"))
@@ -53,7 +51,7 @@ func TestProcessBanOperation(t *testing.T) {
{
name: "invalid IP address",
setupMock: func(m *fail2ban.MockRunner) {
setupBasicMockResponses(m)
fail2ban.StandardMockSetup(m)
},
ip: "invalid.ip",
jails: []string{"sshd"},
@@ -147,13 +145,3 @@ func TestParseTimeoutFromEnv(t *testing.T) {
})
}
}
// setupBasicMockResponses is a helper for setting up version check and ping responses
func setupBasicMockResponses(m *fail2ban.MockRunner) {
m.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
m.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
m.SetResponse("fail2ban-client ping", []byte("Server replied: pong"))
m.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong"))
m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 2\n`- Jail list: sshd, apache"))
m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 2\n`- Jail list: sshd, apache"))
}

View File

@@ -87,14 +87,9 @@ func ExecuteIPCommand(
// Get the contextual logger
logger := GetContextualLogger()
// Safe timeout handling with nil check
timeout := shared.DefaultCommandTimeout
if config != nil && config.CommandTimeout > 0 {
timeout = config.CommandTimeout
}
// Create timeout context for the entire operation
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// Use cmd.Context() to inherit Cobra's signal cancellation
ctx, cancel := createTimeoutContext(cmd.Context(), config)
defer cancel()
// Add command context

View File

@@ -9,6 +9,43 @@ import (
"github.com/ivuorinen/f2b/fail2ban"
)
// multiJailOperationFunc defines the signature for IP operation functions that process multiple jails
type multiJailOperationFunc func(
ctx context.Context,
client fail2ban.Client,
ip string,
jails []string,
) ([]OperationResult, error)
// validateIPAndJails validates an IP address and a list of jail names.
// Returns an error if any validation fails.
func validateIPAndJails(ip string, jails []string) error {
if err := fail2ban.ValidateIP(ip); err != nil {
return err
}
for _, jail := range jails {
if err := fail2ban.ValidateJail(jail); err != nil {
return err
}
}
return nil
}
// processWithValidation validates inputs and executes the provided operation function.
// This consolidates the common validation-then-execute pattern.
func processWithValidation(
ctx context.Context,
client fail2ban.Client,
ip string,
jails []string,
opFunc multiJailOperationFunc,
) ([]OperationResult, error) {
if err := validateIPAndJails(ip, jails); err != nil {
return nil, err
}
return opFunc(ctx, client, ip, jails)
}
// BanProcessor handles ban operations
type BanProcessor struct{}
@@ -19,19 +56,7 @@ func (p *BanProcessor) ProcessSingle(
ip string,
jails []string,
) ([]OperationResult, error) {
// Validate IP address before privilege escalation
if err := fail2ban.ValidateIP(ip); err != nil {
return nil, err
}
// Validate each jail name before privilege escalation
for _, jail := range jails {
if err := fail2ban.ValidateJail(jail); err != nil {
return nil, err
}
}
return ProcessBanOperationWithContext(ctx, client, ip, jails)
return processWithValidation(ctx, client, ip, jails, ProcessBanOperationWithContext)
}
// ProcessParallel processes ban operations for multiple jails in parallel
@@ -41,19 +66,7 @@ func (p *BanProcessor) ProcessParallel(
ip string,
jails []string,
) ([]OperationResult, error) {
// Validate IP address before privilege escalation
if err := fail2ban.ValidateIP(ip); err != nil {
return nil, err
}
// Validate each jail name before privilege escalation
for _, jail := range jails {
if err := fail2ban.ValidateJail(jail); err != nil {
return nil, err
}
}
return ProcessBanOperationParallelWithContext(ctx, client, ip, jails)
return processWithValidation(ctx, client, ip, jails, ProcessBanOperationParallelWithContext)
}
// UnbanProcessor handles unban operations
@@ -66,19 +79,7 @@ func (p *UnbanProcessor) ProcessSingle(
ip string,
jails []string,
) ([]OperationResult, error) {
// Validate IP address before privilege escalation
if err := fail2ban.ValidateIP(ip); err != nil {
return nil, err
}
// Validate each jail name before privilege escalation
for _, jail := range jails {
if err := fail2ban.ValidateJail(jail); err != nil {
return nil, err
}
}
return ProcessUnbanOperationWithContext(ctx, client, ip, jails)
return processWithValidation(ctx, client, ip, jails, ProcessUnbanOperationWithContext)
}
// ProcessParallel processes unban operations for multiple jails in parallel
@@ -88,17 +89,5 @@ func (p *UnbanProcessor) ProcessParallel(
ip string,
jails []string,
) ([]OperationResult, error) {
// Validate IP address before privilege escalation
if err := fail2ban.ValidateIP(ip); err != nil {
return nil, err
}
// Validate each jail name before privilege escalation
for _, jail := range jails {
if err := fail2ban.ValidateJail(jail); err != nil {
return nil, err
}
}
return ProcessUnbanOperationParallelWithContext(ctx, client, ip, jails)
return processWithValidation(ctx, client, ip, jails, ProcessUnbanOperationParallelWithContext)
}

View File

@@ -5,10 +5,12 @@ package cmd
import (
"context"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/ivuorinen/f2b/fail2ban"
"github.com/ivuorinen/f2b/shared"
)
@@ -56,57 +58,68 @@ func getVersion() string {
return version
}
// contextKeyEntry defines a context key and its log field name
type contextKeyEntry struct {
key any // The context key to look up
fieldName string // The log field name to use
}
// contextKeys lists all context keys to extract for logging
var contextKeys = []contextKeyEntry{
{shared.ContextKeyRequestID, string(shared.ContextKeyRequestID)},
{shared.ContextKeyOperation, string(shared.ContextKeyOperation)},
{shared.ContextKeyIP, string(shared.ContextKeyIP)},
{shared.ContextKeyJail, string(shared.ContextKeyJail)},
{shared.ContextKeyCommand, string(shared.ContextKeyCommand)},
}
// WithContext creates a logger entry with context values
func (cl *ContextualLogger) WithContext(ctx context.Context) *logrus.Entry {
entry := cl.WithFields(cl.defaultFields)
// Extract context values and add as fields (using consistent constants)
if requestID := ctx.Value(shared.ContextKeyRequestID); requestID != nil {
entry = entry.WithField(string(shared.ContextKeyRequestID), requestID)
}
if operation := ctx.Value(shared.ContextKeyOperation); operation != nil {
entry = entry.WithField(string(shared.ContextKeyOperation), operation)
}
if ip := ctx.Value(shared.ContextKeyIP); ip != nil {
entry = entry.WithField(string(shared.ContextKeyIP), ip)
}
if jail := ctx.Value(shared.ContextKeyJail); jail != nil {
entry = entry.WithField(string(shared.ContextKeyJail), jail)
}
if command := ctx.Value(shared.ContextKeyCommand); command != nil {
entry = entry.WithField(string(shared.ContextKeyCommand), command)
// Extract context values and add as fields using table-driven approach
for _, ck := range contextKeys {
if val := ctx.Value(ck.key); val != nil {
entry = entry.WithField(ck.fieldName, val)
}
}
return entry
}
// WithOperation adds operation context and returns a new context
// WithOperation adds operation context and returns a new context.
// Delegates to fail2ban.WithOperation for consistent validation.
func WithOperation(ctx context.Context, operation string) context.Context {
return context.WithValue(ctx, shared.ContextKeyOperation, operation)
return fail2ban.WithOperation(ctx, operation)
}
// WithIP adds IP context and returns a new context
// WithIP adds IP context and returns a new context.
// Delegates to fail2ban.WithIP for consistent IP validation.
func WithIP(ctx context.Context, ip string) context.Context {
return context.WithValue(ctx, shared.ContextKeyIP, ip)
return fail2ban.WithIP(ctx, ip)
}
// WithJail adds jail context and returns a new context
// WithJail adds jail context and returns a new context.
// Delegates to fail2ban.WithJail for consistent jail name validation.
func WithJail(ctx context.Context, jail string) context.Context {
return context.WithValue(ctx, shared.ContextKeyJail, jail)
return fail2ban.WithJail(ctx, jail)
}
// WithCommand adds command context and returns a new context
// WithCommand adds command context and returns a new context.
// This is cmd-specific as fail2ban doesn't need command tracking.
// Empty commands are not stored in context.
func WithCommand(ctx context.Context, command string) context.Context {
command = strings.TrimSpace(command)
if command == "" {
return ctx
}
return context.WithValue(ctx, shared.ContextKeyCommand, command)
}
// WithRequestID adds request ID context and returns a new context
// WithRequestID adds request ID context and returns a new context.
// Delegates to fail2ban.WithRequestID for consistent validation.
func WithRequestID(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, shared.ContextKeyRequestID, requestID)
return fail2ban.WithRequestID(ctx, requestID)
}
// LogOperation logs the start and end of an operation with timing and metrics

View File

@@ -68,17 +68,34 @@ func NewMetrics() *Metrics {
}
}
// recordOperationMetrics records metrics for any operation type.
// This helper consolidates the duplicate metrics recording pattern.
func (m *Metrics) recordOperationMetrics(
execCounter, durationCounter, failureCounter *int64,
buckets map[string]*LatencyBucket,
operation string,
duration time.Duration,
success bool,
) {
atomic.AddInt64(execCounter, 1)
atomic.AddInt64(durationCounter, duration.Milliseconds())
if !success {
atomic.AddInt64(failureCounter, 1)
}
m.recordLatencyBucket(buckets, operation, duration)
}
// RecordCommandExecution records metrics for command execution
func (m *Metrics) RecordCommandExecution(command string, duration time.Duration, success bool) {
atomic.AddInt64(&m.CommandExecutions, 1)
atomic.AddInt64(&m.CommandTotalDuration, duration.Milliseconds())
if !success {
atomic.AddInt64(&m.CommandFailures, 1)
}
// Record latency bucket
m.recordLatencyBucket(m.commandLatencyBuckets, command, duration)
m.recordOperationMetrics(
&m.CommandExecutions,
&m.CommandTotalDuration,
&m.CommandFailures,
m.commandLatencyBuckets,
command,
duration,
success,
)
}
// RecordBanOperation records metrics for ban operations
@@ -99,15 +116,15 @@ func (m *Metrics) RecordBanOperation(operation string, _ time.Duration, success
// RecordClientOperation records metrics for client operations
func (m *Metrics) RecordClientOperation(operation string, duration time.Duration, success bool) {
atomic.AddInt64(&m.ClientOperations, 1)
atomic.AddInt64(&m.ClientTotalDuration, duration.Milliseconds())
if !success {
atomic.AddInt64(&m.ClientFailures, 1)
}
// Record latency bucket
m.recordLatencyBucket(m.clientLatencyBuckets, operation, duration)
m.recordOperationMetrics(
&m.ClientOperations,
&m.ClientTotalDuration,
&m.ClientFailures,
m.clientLatencyBuckets,
operation,
duration,
success,
)
}
// RecordValidationCacheHit records validation cache hits
@@ -143,6 +160,25 @@ func (m *Metrics) UpdateGoroutineCount(count int64) {
atomic.StoreInt64(&m.GoroutineCount, count)
}
// copyBuckets creates a snapshot copy of latency buckets
// This helper consolidates the duplicate bucket copying logic
func copyBuckets(buckets map[string]*LatencyBucket) map[string]LatencyBucketSnapshot {
result := make(map[string]LatencyBucketSnapshot, len(buckets))
for op, bucket := range buckets {
result[op] = LatencyBucketSnapshot{
Under1ms: atomic.LoadInt64(&bucket.Under1ms),
Under10ms: atomic.LoadInt64(&bucket.Under10ms),
Under100ms: atomic.LoadInt64(&bucket.Under100ms),
Under1s: atomic.LoadInt64(&bucket.Under1s),
Under10s: atomic.LoadInt64(&bucket.Under10s),
Over10s: atomic.LoadInt64(&bucket.Over10s),
Total: atomic.LoadInt64(&bucket.Total),
TotalTime: atomic.LoadInt64(&bucket.TotalTime),
}
}
return result
}
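`copyBuckets` turns live, atomically-updated buckets into plain value snapshots that readers can use without further synchronization. A cut-down sketch of the pattern with a two-field bucket:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// LatencyBucket holds live counters mutated via atomics (two fields for brevity).
type LatencyBucket struct {
	Total     int64
	TotalTime int64
}

// Snapshot is a plain value copy, safe to read without synchronization.
type Snapshot struct {
	Total     int64
	TotalTime int64
}

// copyBuckets loads every counter atomically so the snapshot stays coherent
// even while writers continue updating the live buckets.
func copyBuckets(buckets map[string]*LatencyBucket) map[string]Snapshot {
	result := make(map[string]Snapshot, len(buckets))
	for op, b := range buckets {
		result[op] = Snapshot{
			Total:     atomic.LoadInt64(&b.Total),
			TotalTime: atomic.LoadInt64(&b.TotalTime),
		}
	}
	return result
}

func main() {
	live := map[string]*LatencyBucket{"ban": {Total: 3, TotalTime: 42}}
	fmt.Println(copyBuckets(live)["ban"]) // {3 42}
}
```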
// recordLatencyBucket records latency in appropriate bucket
func (m *Metrics) recordLatencyBucket(buckets map[string]*LatencyBucket, operation string, duration time.Duration) {
m.mu.Lock()
@@ -177,37 +213,8 @@ func (m *Metrics) recordLatencyBucket(buckets map[string]*LatencyBucket, operati
// GetSnapshot returns a snapshot of current metrics
func (m *Metrics) GetSnapshot() MetricsSnapshot {
m.mu.RLock()
// Copy command latency buckets
commandBuckets := make(map[string]LatencyBucketSnapshot)
for op, bucket := range m.commandLatencyBuckets {
commandBuckets[op] = LatencyBucketSnapshot{
Under1ms: atomic.LoadInt64(&bucket.Under1ms),
Under10ms: atomic.LoadInt64(&bucket.Under10ms),
Under100ms: atomic.LoadInt64(&bucket.Under100ms),
Under1s: atomic.LoadInt64(&bucket.Under1s),
Under10s: atomic.LoadInt64(&bucket.Under10s),
Over10s: atomic.LoadInt64(&bucket.Over10s),
Total: atomic.LoadInt64(&bucket.Total),
TotalTime: atomic.LoadInt64(&bucket.TotalTime),
}
}
// Copy client latency buckets
clientBuckets := make(map[string]LatencyBucketSnapshot)
for op, bucket := range m.clientLatencyBuckets {
clientBuckets[op] = LatencyBucketSnapshot{
Under1ms: atomic.LoadInt64(&bucket.Under1ms),
Under10ms: atomic.LoadInt64(&bucket.Under10ms),
Under100ms: atomic.LoadInt64(&bucket.Under100ms),
Under1s: atomic.LoadInt64(&bucket.Under1s),
Under10s: atomic.LoadInt64(&bucket.Under10s),
Over10s: atomic.LoadInt64(&bucket.Over10s),
Total: atomic.LoadInt64(&bucket.Total),
TotalTime: atomic.LoadInt64(&bucket.TotalTime),
}
}
commandBuckets := copyBuckets(m.commandLatencyBuckets)
clientBuckets := copyBuckets(m.clientLatencyBuckets)
m.mu.RUnlock()
return MetricsSnapshot{

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"io"
"sort"
"strings"
"github.com/spf13/cobra"
@@ -97,35 +98,40 @@ func printMetricsPlain(output io.Writer, snapshot MetricsSnapshot) error {
// Command latency distribution
if len(snapshot.CommandLatencyBuckets) > 0 {
sb.WriteString("Command Latency Distribution:\n")
for cmd, bucket := range snapshot.CommandLatencyBuckets {
sb.WriteString(fmt.Sprintf(shared.MetricsFmtOperationHeader, cmd))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder1ms, bucket.Under1ms))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder10ms, bucket.Under10ms))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder100ms, bucket.Under100ms))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder1s, bucket.Under1s))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder10s, bucket.Under10s))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyOver10s, bucket.Over10s))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtAverageLatency, bucket.GetAverageLatency()))
}
formatLatencyBuckets(&sb, snapshot.CommandLatencyBuckets)
sb.WriteString("\n")
}
// Client latency distribution
if len(snapshot.ClientLatencyBuckets) > 0 {
sb.WriteString("Client Operation Latency Distribution:\n")
for op, bucket := range snapshot.ClientLatencyBuckets {
sb.WriteString(fmt.Sprintf(shared.MetricsFmtOperationHeader, op))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder1ms, bucket.Under1ms))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder10ms, bucket.Under10ms))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder100ms, bucket.Under100ms))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder1s, bucket.Under1s))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder10s, bucket.Under10s))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyOver10s, bucket.Over10s))
sb.WriteString(fmt.Sprintf(shared.MetricsFmtAverageLatency, bucket.GetAverageLatency()))
}
formatLatencyBuckets(&sb, snapshot.ClientLatencyBuckets)
}
// Write the entire string at once
_, err := output.Write([]byte(sb.String()))
return err
}
// formatLatencyBuckets writes latency bucket distribution to the builder.
// Keys are sorted for deterministic output.
func formatLatencyBuckets(sb *strings.Builder, buckets map[string]LatencyBucketSnapshot) {
// Sort keys for deterministic output
keys := make([]string, 0, len(buckets))
for name := range buckets {
keys = append(keys, name)
}
sort.Strings(keys)
for _, name := range keys {
bucket := buckets[name]
fmt.Fprintf(sb, shared.MetricsFmtOperationHeader, name)
fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder1ms, bucket.Under1ms)
fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder10ms, bucket.Under10ms)
fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder100ms, bucket.Under100ms)
fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder1s, bucket.Under1s)
fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder10s, bucket.Under10s)
fmt.Fprintf(sb, shared.MetricsFmtLatencyOver10s, bucket.Over10s)
fmt.Fprintf(sb, shared.MetricsFmtAverageLatency, bucket.GetAverageLatency())
}
}
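Go randomizes map iteration order, so formatting directly over the map would shuffle output between runs; sorting the keys first, as above, makes the output stable and testable. The core of that pattern:

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// formatSorted renders map entries in sorted-key order so output is
// deterministic across runs (plain Go map iteration order is randomized).
func formatSorted(buckets map[string]int64) string {
	keys := make([]string, 0, len(buckets))
	for name := range buckets {
		keys = append(keys, name)
	}
	sort.Strings(keys)

	var sb strings.Builder
	for _, name := range keys {
		fmt.Fprintf(&sb, "%s: %d\n", name, buckets[name])
	}
	return sb.String()
}

func main() {
	fmt.Print(formatSorted(map[string]int64{"unban": 2, "ban": 5}))
	// ban: 5
	// unban: 2
}
```

Note that `fmt.Fprintf(&sb, …)` writes straight into the builder, which is slightly tidier than building each line with `Sprintf` and appending via `WriteString`.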

View File

@@ -2,11 +2,11 @@ package cmd
import (
"context"
"errors"
"runtime"
"sync"
"github.com/ivuorinen/f2b/fail2ban"
"github.com/ivuorinen/f2b/shared"
)
// ParallelOperationProcessor handles parallel ban/unban operations across multiple jails
@@ -24,15 +24,16 @@ func NewParallelOperationProcessor(workerCount int) *ParallelOperationProcessor
}
}
// ProcessBanOperationParallel processes ban operations across multiple jails in parallel
func (pop *ParallelOperationProcessor) ProcessBanOperationParallel(
// ProcessOperationParallel processes operations across multiple jails in parallel
func (pop *ParallelOperationProcessor) ProcessOperationParallel(
client fail2ban.Client,
ip string,
jails []string,
opType OperationType,
) ([]OperationResult, error) {
if len(jails) <= 1 {
// For single jail, use sequential processing to avoid overhead
return ProcessBanOperation(client, ip, jails)
return ProcessOperation(client, ip, jails, opType)
}
return pop.processOperations(
@@ -40,13 +41,43 @@ func (pop *ParallelOperationProcessor) ProcessBanOperationParallel(
client,
ip,
jails,
func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error) {
return client.BanIPWithContext(ctx, ip, jail)
},
shared.MetricsBan,
opType.OperationCtx,
opType.MetricsType,
)
}
// ProcessOperationParallelWithContext processes operations across multiple jails in parallel with context
func (pop *ParallelOperationProcessor) ProcessOperationParallelWithContext(
ctx context.Context,
client fail2ban.Client,
ip string,
jails []string,
opType OperationType,
) ([]OperationResult, error) {
if len(jails) <= 1 {
// For single jail, use sequential processing to avoid overhead
return ProcessOperationWithContext(ctx, client, ip, jails, opType)
}
return pop.processOperations(
ctx,
client,
ip,
jails,
opType.OperationCtx,
opType.MetricsType,
)
}
// ProcessBanOperationParallel processes ban operations across multiple jails in parallel
func (pop *ParallelOperationProcessor) ProcessBanOperationParallel(
client fail2ban.Client,
ip string,
jails []string,
) ([]OperationResult, error) {
return pop.ProcessOperationParallel(client, ip, jails, BanOperationType)
}
// ProcessBanOperationParallelWithContext processes ban operations across
// multiple jails in parallel with timeout context
func (pop *ParallelOperationProcessor) ProcessBanOperationParallelWithContext(
@@ -55,21 +86,7 @@ func (pop *ParallelOperationProcessor) ProcessBanOperationParallelWithContext(
ip string,
jails []string,
) ([]OperationResult, error) {
if len(jails) <= 1 {
// For single jail, use sequential processing to avoid overhead
return ProcessBanOperationWithContext(ctx, client, ip, jails)
}
return pop.processOperations(
ctx,
client,
ip,
jails,
func(opCtx context.Context, client fail2ban.Client, ip, jail string) (int, error) {
return client.BanIPWithContext(opCtx, ip, jail)
},
shared.MetricsBan,
)
return pop.ProcessOperationParallelWithContext(ctx, client, ip, jails, BanOperationType)
}
// ProcessUnbanOperationParallel processes unban operations across multiple jails in parallel
@@ -78,21 +95,7 @@ func (pop *ParallelOperationProcessor) ProcessUnbanOperationParallel(
ip string,
jails []string,
) ([]OperationResult, error) {
if len(jails) <= 1 {
// For single jail, use sequential processing to avoid overhead
return ProcessUnbanOperation(client, ip, jails)
}
return pop.processOperations(
context.Background(),
client,
ip,
jails,
func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error) {
return client.UnbanIPWithContext(ctx, ip, jail)
},
shared.MetricsUnban,
)
return pop.ProcessOperationParallel(client, ip, jails, UnbanOperationType)
}
// ProcessUnbanOperationParallelWithContext processes unban operations across
@@ -103,26 +106,35 @@ func (pop *ParallelOperationProcessor) ProcessUnbanOperationParallelWithContext(
ip string,
jails []string,
) ([]OperationResult, error) {
if len(jails) <= 1 {
// For single jail, use sequential processing to avoid overhead
return ProcessUnbanOperationWithContext(ctx, client, ip, jails)
}
return pop.processOperations(
ctx,
client,
ip,
jails,
func(opCtx context.Context, client fail2ban.Client, ip, jail string) (int, error) {
return client.UnbanIPWithContext(opCtx, ip, jail)
},
shared.MetricsUnban,
)
return pop.ProcessOperationParallelWithContext(ctx, client, ip, jails, UnbanOperationType)
}
// operationFunc represents a ban or unban operation with context
type operationFunc func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error)
// validateOperationInputs validates IP and jail inputs before parallel processing.
// Returns an aggregated error if any inputs are invalid.
func validateOperationInputs(ctx context.Context, ip string, jails []string) error {
var errs []error
// Validate IP address
if err := fail2ban.CachedValidateIP(ctx, ip); err != nil {
errs = append(errs, err)
}
// Validate each jail name
for _, jail := range jails {
if err := fail2ban.CachedValidateJail(ctx, jail); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
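`errors.Join` (Go 1.20+) lets the validator report every bad input at once instead of stopping at the first failure. A sketch of the same accumulate-then-join shape, with stand-in validators in place of `CachedValidateIP`/`CachedValidateJail`:

```go
package main

import (
	"errors"
	"fmt"
	"net"
)

// validateInputs collects every validation failure and joins them, so the
// caller sees all problems in one error. The jail rule here is a stand-in.
func validateInputs(ip string, jails []string) error {
	var errs []error
	if net.ParseIP(ip) == nil {
		errs = append(errs, fmt.Errorf("invalid IP: %q", ip))
	}
	for _, jail := range jails {
		if jail == "" {
			errs = append(errs, errors.New("empty jail name"))
		}
	}
	if len(errs) > 0 {
		return errors.Join(errs...)
	}
	return nil
}

func main() {
	err := validateInputs("not-an-ip", []string{"sshd", ""})
	fmt.Println(err) // both failures, newline-separated
}
```

Failing fast on validation before spawning workers, as the diff does, also avoids paying goroutine and channel overhead for requests that could never succeed.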
// processOperations handles the parallel processing of operations
func (pop *ParallelOperationProcessor) processOperations(
ctx context.Context,
@@ -132,6 +144,11 @@ func (pop *ParallelOperationProcessor) processOperations(
operation operationFunc,
operationType string,
) ([]OperationResult, error) {
// Validate inputs before processing
if err := validateOperationInputs(ctx, ip, jails); err != nil {
return nil, err
}
results := make([]OperationResult, len(jails))
resultCh := make(chan operationResult, len(jails))
@@ -167,13 +184,20 @@ func (pop *ParallelOperationProcessor) processOperations(
close(resultCh)
}()
// Collect results
// Collect results and errors
var errs []error
for result := range resultCh {
if result.index >= 0 && result.index < len(results) {
results[result.index] = result.result
}
if result.err != nil {
errs = append(errs, result.err)
}
}
if len(errs) > 0 {
return results, errors.Join(errs...)
}
return results, nil
}
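The collection loop above preserves per-jail results by index while also aggregating worker errors instead of silently dropping them. A self-contained sketch of that fan-out/fan-in shape:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

type result struct {
	index int
	value string
	err   error
}

// processAll fans work out to goroutines, collects results by index so output
// order matches input order, and joins all worker errors.
func processAll(jails []string, op func(jail string) (string, error)) ([]string, error) {
	results := make([]string, len(jails))
	resultCh := make(chan result, len(jails)) // buffered: workers never block on send

	var wg sync.WaitGroup
	for i, jail := range jails {
		wg.Add(1)
		go func(i int, jail string) {
			defer wg.Done()
			v, err := op(jail)
			resultCh <- result{index: i, value: v, err: err}
		}(i, jail)
	}
	go func() { wg.Wait(); close(resultCh) }()

	var errs []error
	for r := range resultCh {
		results[r.index] = r.value
		if r.err != nil {
			errs = append(errs, r.err)
		}
	}
	return results, errors.Join(errs...) // nil when errs is empty
}

func main() {
	out, err := processAll([]string{"sshd", "apache"}, func(j string) (string, error) {
		if j == "apache" {
			return "failed", fmt.Errorf("%s: not running", j)
		}
		return "banned", nil
	})
	fmt.Println(out, err != nil) // [banned failed] true
}
```

Carrying the error inside the per-result struct, as the diff's `operationResult.err` field does, keeps the partial results usable: callers get both the slice and the joined error.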
@@ -187,6 +211,7 @@ type jailWork struct {
type operationResult struct {
result OperationResult
index int
err error
}
// worker processes jail operations
@@ -215,16 +240,15 @@ func (pop *ParallelOperationProcessor) worker(
"status": status,
}).Info("Operation result")
result := operationResult{
resultCh <- operationResult{
result: OperationResult{
IP: ip,
Jail: work.jail,
Status: status,
},
index: work.index,
err: err,
}
resultCh <- result
}
}

View File

@@ -12,17 +12,13 @@ import (
// TestUnbanProcessorProcessParallel tests the ProcessParallel method
func TestUnbanProcessorProcessParallel(t *testing.T) {
// Save and restore original runner
originalRunner := fail2ban.GetRunner()
defer fail2ban.SetRunner(originalRunner)
mockRunner := fail2ban.NewMockRunner()
setupBasicMockResponses(mockRunner)
defer fail2ban.WithTestRunner(t, mockRunner)()
fail2ban.StandardMockSetup(mockRunner)
mockRunner.SetResponse("fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
mockRunner.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
mockRunner.SetResponse("fail2ban-client set apache unbanip 192.168.1.1", []byte("1"))
mockRunner.SetResponse("sudo fail2ban-client set apache unbanip 192.168.1.1", []byte("1"))
fail2ban.SetRunner(mockRunner)
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
require.NoError(t, err)
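`WithTestRunner` replaces the manual `GetRunner`/`SetRunner`-plus-defer boilerplate: it swaps in the mock and returns a restore function the caller defers in one line. A sketch of how such a helper can be built (the package-level runner, types, and the omission of the `*testing.T` parameter are stand-ins, not the project's actual API):

```go
package main

import "fmt"

// Runner is a stand-in for the package's command runner interface.
type Runner interface{ Run(cmd string) string }

type realRunner struct{}

func (realRunner) Run(cmd string) string { return "real:" + cmd }

type mockRunner struct{}

func (mockRunner) Run(cmd string) string { return "mock:" + cmd }

var current Runner = realRunner{}

// WithTestRunner swaps in a mock and returns a restore func, so tests write
// `defer WithTestRunner(mock)()` and restoration is guaranteed on exit.
func WithTestRunner(mock Runner) (restore func()) {
	saved := current
	current = mock
	return func() { current = saved }
}

func main() {
	func() {
		defer WithTestRunner(mockRunner{})()
		fmt.Println(current.Run("status")) // mock:status
	}()
	fmt.Println(current.Run("status")) // real:status
}
```

Because save and restore live in one helper, a test can no longer save the runner but forget the defer, which was exactly the duplication the diff removes from each test.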

View File

@@ -12,17 +12,13 @@ import (
// TestProcessBanOperationParallel tests the ProcessBanOperationParallel wrapper function
func TestProcessBanOperationParallel(t *testing.T) {
// Save and restore original runner
originalRunner := fail2ban.GetRunner()
defer fail2ban.SetRunner(originalRunner)
mockRunner := fail2ban.NewMockRunner()
setupBasicMockResponses(mockRunner)
defer fail2ban.WithTestRunner(t, mockRunner)()
fail2ban.StandardMockSetup(mockRunner)
mockRunner.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
mockRunner.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
mockRunner.SetResponse("fail2ban-client set apache banip 192.168.1.1", []byte("1"))
mockRunner.SetResponse("sudo fail2ban-client set apache banip 192.168.1.1", []byte("1"))
fail2ban.SetRunner(mockRunner)
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
require.NoError(t, err)
@@ -34,15 +30,11 @@ func TestProcessBanOperationParallel(t *testing.T) {
// TestProcessUnbanOperationParallel tests the ProcessUnbanOperationParallel wrapper function
func TestProcessUnbanOperationParallel(t *testing.T) {
// Save and restore original runner
originalRunner := fail2ban.GetRunner()
defer fail2ban.SetRunner(originalRunner)
mockRunner := fail2ban.NewMockRunner()
setupBasicMockResponses(mockRunner)
defer fail2ban.WithTestRunner(t, mockRunner)()
fail2ban.StandardMockSetup(mockRunner)
mockRunner.SetResponse("fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
mockRunner.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
fail2ban.SetRunner(mockRunner)
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
require.NoError(t, err)
@@ -54,15 +46,11 @@ func TestProcessUnbanOperationParallel(t *testing.T) {
// TestProcessBanOperationParallelWithContext tests the wrapper with context
func TestProcessBanOperationParallelWithContext(t *testing.T) {
// Save and restore original runner
originalRunner := fail2ban.GetRunner()
defer fail2ban.SetRunner(originalRunner)
mockRunner := fail2ban.NewMockRunner()
setupBasicMockResponses(mockRunner)
defer fail2ban.WithTestRunner(t, mockRunner)()
fail2ban.StandardMockSetup(mockRunner)
mockRunner.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
mockRunner.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
fail2ban.SetRunner(mockRunner)
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
require.NoError(t, err)
@@ -75,15 +63,11 @@ func TestProcessBanOperationParallelWithContext(t *testing.T) {
// TestProcessUnbanOperationParallelWithContext tests the wrapper with context
func TestProcessUnbanOperationParallelWithContext(t *testing.T) {
// Save and restore original runner
originalRunner := fail2ban.GetRunner()
defer fail2ban.SetRunner(originalRunner)
mockRunner := fail2ban.NewMockRunner()
setupBasicMockResponses(mockRunner)
defer fail2ban.WithTestRunner(t, mockRunner)()
fail2ban.StandardMockSetup(mockRunner)
mockRunner.SetResponse("fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
mockRunner.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
fail2ban.SetRunner(mockRunner)
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
require.NoError(t, err)

View File

@@ -15,7 +15,11 @@ func ServiceCmd(config *Config) *cobra.Command {
nil,
func(_ *cobra.Command, args []string) error {
// Validate service action argument
if err := RequireArguments(args, 1, "action required: start|stop|restart|status|reload|enable|disable"); err != nil {
if err := RequireArguments(
args,
1,
"action required: start|stop|restart|status|reload|enable|disable",
); err != nil {
return HandleValidationError(err)
}
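`RequireArguments` centralizes positional-argument checks for cobra commands; the reformatted call above just spreads its three arguments across lines to satisfy line-length linting. A plausible minimal implementation (an assumption for illustration, not the project's actual code) might look like:

```go
package main

import "fmt"

// requireArguments returns an error carrying the usage hint when fewer than
// min positional args were supplied. (Hypothetical implementation.)
func requireArguments(args []string, min int, usage string) error {
	if len(args) < min {
		return fmt.Errorf("expected at least %d argument(s): %s", min, usage)
	}
	return nil
}

func main() {
	if err := requireArguments(nil, 1, "action required: start|stop|restart"); err != nil {
		fmt.Println(err)
	}
}
```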

View File

@@ -110,18 +110,14 @@ func TestValidateConfigPath(t *testing.T) {
// TestLogsWatchCmdCreation tests LogsWatchCmd creation
func TestLogsWatchCmdCreation(t *testing.T) {
// Save and restore original runner
originalRunner := fail2ban.GetRunner()
defer fail2ban.SetRunner(originalRunner)
mockRunner := fail2ban.NewMockRunner()
defer fail2ban.WithTestRunner(t, mockRunner)()
mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong"))
mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong"))
mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
fail2ban.SetRunner(mockRunner)
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
require.NoError(t, err)
@@ -142,18 +138,14 @@ func TestLogsWatchCmdCreation(t *testing.T) {
// TestGetLogLinesWithLimitAndContext_Function tests the function
func TestGetLogLinesWithLimitAndContext_Function(t *testing.T) {
// Save and restore original runner
originalRunner := fail2ban.GetRunner()
defer fail2ban.SetRunner(originalRunner)
mockRunner := fail2ban.NewMockRunner()
defer fail2ban.WithTestRunner(t, mockRunner)()
mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong"))
mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong"))
mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
fail2ban.SetRunner(mockRunner)
tmpDir := t.TempDir()
oldLogDir := fail2ban.GetLogDir()