mirror of
https://github.com/ivuorinen/f2b.git
synced 2026-01-26 03:13:58 +00:00
refactor: linting, simplification and fixes (#119)
* refactor: consolidate test helpers and reduce code duplication
- Fix prealloc lint issue in cmd_logswatch_test.go
- Add validateIPAndJails helper to consolidate IP/jail validation
- Add WithTestRunner/WithTestSudoChecker helpers for cleaner test setup
- Replace setupBasicMockResponses duplicates with StandardMockSetup
- Add SetupStandardResponses/SetupJailResponses to MockRunner
- Delegate cmd context helpers to fail2ban implementations
- Document context wrapper pattern in context_helpers.go
* refactor: consolidate duplicate code patterns across cmd and fail2ban packages
Add helper functions to reduce code duplication found by dupl:
- safeCloseFile/safeCloseReader: centralize file close error logging
- createTimeoutContext: consolidate timeout context creation pattern
- withContextCheck: wrap context cancellation checks
- recordOperationMetrics: unify metrics recording for commands/clients
Also includes Phase 1 consolidations:
- copyBuckets helper for metrics snapshots
- Table-driven context extraction in logging
- processWithValidation helper for IP processors
* refactor: consolidate LoggerInterface by embedding LoggerEntry
Both interfaces had identical method signatures. LoggerInterface now
embeds LoggerEntry to eliminate code duplication.
* refactor: consolidate test framework helpers and fix test patterns
- Add checkJSONFieldValue and failMissingJSONField helpers to reduce
duplication in JSON assertion methods
- Add ParallelTimeout to default test config
- Fix test to use WithTestRunner inside test loop for proper mock scoping
* refactor: unify ban/unban operations with OperationType pattern
Introduce OperationType struct to consolidate duplicate ban/unban logic:
- Add ProcessOperation and ProcessOperationWithContext generic functions
- Add ProcessOperationParallel and ProcessOperationParallelWithContext
- Existing ProcessBan*/ProcessUnban* functions now delegate to generic versions
- Reduces ~120 lines of duplicate code between ban and unban operations
* refactor: consolidate time parsing cache pattern
Add ParseWithLayout method to BoundedTimeCache that consolidates the
cache-lookup-parse-store pattern. FastTimeCache and TimeParsingCache
now delegate to this method instead of duplicating the logic.
* refactor: consolidate command execution patterns in fail2ban
- Add validateCommandExecution helper for command/argument validation
- Add runWithTimerContext helper for timed runner operations
- Add executeIPActionWithContext to unify BanIP/UnbanIP implementations
- Reduces duplicate validation and execution boilerplate
* refactor: consolidate logrus adapter with embedded loggerCore
Introduce loggerCore type that provides the 8 standard logging methods
(Debug, Info, Warn, Error, Debugf, Infof, Warnf, Errorf). Both
logrusAdapter and logrusEntryAdapter now embed this type, eliminating
16 duplicate method implementations.
* refactor: consolidate path validation patterns
- Add validateConfigPathWithFallback helper in cmd/config_utils.go
for the validate-or-fallback-with-logging pattern
- Add validateClientPath helper in fail2ban/helpers.go for client
path validation delegation
* fix: add context cancellation checks to wrapper functions
- wrapWithContext0/1/2 now check ctx.Err() before invoking wrapped function
- WithCommand now validates and trims empty command strings
* refactor: extract formatLatencyBuckets for deterministic metrics output
Add formatLatencyBuckets helper that writes latency bucket distribution
with sorted keys for deterministic output, eliminating duplicate
formatting code for command and client latency buckets.
* refactor: add generic setNestedMapValue helper for mock configuration
Add setNestedMapValue[T] generic helper that consolidates the repeated
pattern of mutex-protected nested map initialization and value setting
used by SetBanError, SetBanResult, SetUnbanError, and SetUnbanResult.
* fix: use cmd.Context() for signal propagation and correct mock status
- ExecuteIPCommand now uses cmd.Context() instead of context.Background()
to inherit Cobra's signal cancellation
- MockRunner.SetupJailResponses uses shared.Fail2BanStatusSuccess ("0")
instead of literal "1" for proper success path simulation
* fix: restore operation-specific log messages in ProcessOperationWithContext
Add back Logger.WithFields().Info(opType.Message) call that was lost
during refactoring. This restores the distinction between ban and unban
operation messages (shared.MsgBanResult vs shared.MsgUnbanResult).
* fix: return aggregated errors from parallel operations
Previously, errors from individual parallel operations were silently
swallowed - converted to status strings but never returned to callers.
Now processOperations collects all errors and returns them aggregated
via errors.Join, allowing callers to distinguish partial failures from
complete success while still receiving all results.
* fix: add input validation to processOperations before parallel execution
Validate IP and jail inputs at the start of processOperations() using
fail2ban.CachedValidateIP and CachedValidateJail. This prevents invalid
or malicious inputs (empty values, path traversal attempts, malformed
IPs) from reaching the operation functions. All validation errors are
aggregated and returned before any operations execute.
This commit is contained in:
@@ -301,7 +301,7 @@ func (m *MockLogsWatchClient) GetLogLines(jail, ip string) ([]string, error) {
|
|||||||
logs = m.initialLogs
|
logs = m.initialLogs
|
||||||
} else {
|
} else {
|
||||||
// Simulate new logs being added
|
// Simulate new logs being added
|
||||||
logs = make([]string, len(m.initialLogs))
|
logs = make([]string, len(m.initialLogs), len(m.initialLogs)+1)
|
||||||
copy(logs, m.initialLogs)
|
copy(logs, m.initialLogs)
|
||||||
logs = append(logs, fmt.Sprintf("new log line %d", m.callCount))
|
logs = append(logs, fmt.Sprintf("new log line %d", m.callCount))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ func TestParallelOperationProcessor_EmptyJailsList(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestParallelOperationProcessor_ErrorHandling(t *testing.T) {
|
func TestParallelOperationProcessor_ErrorHandling(t *testing.T) {
|
||||||
// Test error handling doesn't cause index issues
|
// Test error handling returns aggregated errors while still populating results
|
||||||
processor := NewParallelOperationProcessor(2)
|
processor := NewParallelOperationProcessor(2)
|
||||||
|
|
||||||
// Mock client for testing
|
// Mock client for testing
|
||||||
@@ -99,15 +99,16 @@ func TestParallelOperationProcessor_ErrorHandling(t *testing.T) {
|
|||||||
|
|
||||||
results, err := processor.ProcessBanOperationParallel(mockClient, "192.168.1.100", jails)
|
results, err := processor.ProcessBanOperationParallel(mockClient, "192.168.1.100", jails)
|
||||||
|
|
||||||
if err != nil {
|
// Errors should now be returned (aggregated)
|
||||||
t.Fatalf("ProcessBanOperationParallel failed: %v", err)
|
if err == nil {
|
||||||
|
t.Error("Expected error for non-existent jails, got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(results) != 2 {
|
if len(results) != 2 {
|
||||||
t.Errorf("Expected 2 results, got %d", len(results))
|
t.Errorf("Expected 2 results, got %d", len(results))
|
||||||
}
|
}
|
||||||
|
|
||||||
// All results should have errors for non-existent jails
|
// All results should still be populated with error status
|
||||||
for i, result := range results {
|
for i, result := range results {
|
||||||
if result.Jail == "" {
|
if result.Jail == "" {
|
||||||
t.Errorf("Result %d has empty jail", i)
|
t.Errorf("Result %d has empty jail", i)
|
||||||
|
|||||||
@@ -152,6 +152,7 @@ func NewCommandTest(t *testing.T, commandName string) *CommandTestBuilder {
|
|||||||
Format: PlainFormat,
|
Format: PlainFormat,
|
||||||
CommandTimeout: shared.DefaultCommandTimeout,
|
CommandTimeout: shared.DefaultCommandTimeout,
|
||||||
FileTimeout: shared.DefaultFileTimeout,
|
FileTimeout: shared.DefaultFileTimeout,
|
||||||
|
ParallelTimeout: shared.DefaultParallelTimeout,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -418,6 +419,20 @@ func (result *CommandTestResult) AssertExactOutput(expected string) *CommandTest
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkJSONFieldValue validates that a JSON field value matches the expected string.
|
||||||
|
func (result *CommandTestResult) checkJSONFieldValue(val interface{}, fieldName, expected string) {
|
||||||
|
result.t.Helper()
|
||||||
|
if fmt.Sprintf("%v", val) != expected {
|
||||||
|
result.t.Fatalf(shared.ErrTestJSONFieldMismatch, result.name, fieldName, expected, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// failMissingJSONField reports a missing JSON field with context.
|
||||||
|
func (result *CommandTestResult) failMissingJSONField(fieldName, context string) {
|
||||||
|
result.t.Helper()
|
||||||
|
result.t.Fatalf("%s: JSON field %q not found%s: %s", result.name, fieldName, context, result.Output)
|
||||||
|
}
|
||||||
|
|
||||||
// AssertJSONField validates a specific field in JSON output
|
// AssertJSONField validates a specific field in JSON output
|
||||||
func (result *CommandTestResult) AssertJSONField(fieldPath, expected string) *CommandTestResult {
|
func (result *CommandTestResult) AssertJSONField(fieldPath, expected string) *CommandTestResult {
|
||||||
result.t.Helper()
|
result.t.Helper()
|
||||||
@@ -434,22 +449,18 @@ func (result *CommandTestResult) AssertJSONField(fieldPath, expected string) *Co
|
|||||||
switch v := data.(type) {
|
switch v := data.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
if val, ok := v[fieldName]; ok {
|
if val, ok := v[fieldName]; ok {
|
||||||
if fmt.Sprintf("%v", val) != expected {
|
result.checkJSONFieldValue(val, fieldName, expected)
|
||||||
result.t.Fatalf(shared.ErrTestJSONFieldMismatch, result.name, fieldName, expected, val)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
result.t.Fatalf("%s: JSON field %q not found in output: %s", result.name, fieldName, result.Output)
|
result.failMissingJSONField(fieldName, " in output")
|
||||||
}
|
}
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
// Handle array case - look in first element
|
// Handle array case - look in first element
|
||||||
if len(v) > 0 {
|
if len(v) > 0 {
|
||||||
if firstItem, ok := v[0].(map[string]interface{}); ok {
|
if firstItem, ok := v[0].(map[string]interface{}); ok {
|
||||||
if val, ok := firstItem[fieldName]; ok {
|
if val, ok := firstItem[fieldName]; ok {
|
||||||
if fmt.Sprintf("%v", val) != expected {
|
result.checkJSONFieldValue(val, fieldName, expected)
|
||||||
result.t.Fatalf(shared.ErrTestJSONFieldMismatch, result.name, fieldName, expected, val)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
result.t.Fatalf("%s: JSON field %q not found in first array element: %s", result.name, fieldName, result.Output)
|
result.failMissingJSONField(fieldName, " in first array element")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
result.t.Fatalf("%s: first array element is not an object in output: %s", result.name, result.Output)
|
result.t.Fatalf("%s: first array element is not an object in output: %s", result.name, result.Output)
|
||||||
|
|||||||
@@ -12,13 +12,9 @@ import (
|
|||||||
|
|
||||||
// TestTestFilterCmdCreation tests TestFilterCmd command creation
|
// TestTestFilterCmdCreation tests TestFilterCmd command creation
|
||||||
func TestTestFilterCmdCreation(t *testing.T) {
|
func TestTestFilterCmdCreation(t *testing.T) {
|
||||||
// Save and restore original runner
|
|
||||||
originalRunner := fail2ban.GetRunner()
|
|
||||||
defer fail2ban.SetRunner(originalRunner)
|
|
||||||
|
|
||||||
mockRunner := fail2ban.NewMockRunner()
|
mockRunner := fail2ban.NewMockRunner()
|
||||||
setupBasicMockResponses(mockRunner)
|
defer fail2ban.WithTestRunner(t, mockRunner)()
|
||||||
fail2ban.SetRunner(mockRunner)
|
fail2ban.StandardMockSetup(mockRunner)
|
||||||
|
|
||||||
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -39,10 +35,6 @@ func TestTestFilterCmdCreation(t *testing.T) {
|
|||||||
|
|
||||||
// TestTestFilterCmdExecution tests TestFilterCmd execution
|
// TestTestFilterCmdExecution tests TestFilterCmd execution
|
||||||
func TestTestFilterCmdExecution(t *testing.T) {
|
func TestTestFilterCmdExecution(t *testing.T) {
|
||||||
// Save and restore original runner
|
|
||||||
originalRunner := fail2ban.GetRunner()
|
|
||||||
defer fail2ban.SetRunner(originalRunner)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
setupMock func(*fail2ban.MockRunner)
|
setupMock func(*fail2ban.MockRunner)
|
||||||
@@ -52,7 +44,7 @@ func TestTestFilterCmdExecution(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "successful filter test",
|
name: "successful filter test",
|
||||||
setupMock: func(m *fail2ban.MockRunner) {
|
setupMock: func(m *fail2ban.MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
fail2ban.StandardMockSetup(m)
|
||||||
m.SetResponse("fail2ban-client get sshd logpath", []byte("/var/log/auth.log"))
|
m.SetResponse("fail2ban-client get sshd logpath", []byte("/var/log/auth.log"))
|
||||||
m.SetResponse("sudo fail2ban-client get sshd logpath", []byte("/var/log/auth.log"))
|
m.SetResponse("sudo fail2ban-client get sshd logpath", []byte("/var/log/auth.log"))
|
||||||
},
|
},
|
||||||
@@ -62,7 +54,7 @@ func TestTestFilterCmdExecution(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "no filter provided - lists available",
|
name: "no filter provided - lists available",
|
||||||
setupMock: func(m *fail2ban.MockRunner) {
|
setupMock: func(m *fail2ban.MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
fail2ban.StandardMockSetup(m)
|
||||||
// Mock ListFiltersWithContext response
|
// Mock ListFiltersWithContext response
|
||||||
},
|
},
|
||||||
args: []string{},
|
args: []string{},
|
||||||
@@ -71,7 +63,7 @@ func TestTestFilterCmdExecution(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "invalid filter name",
|
name: "invalid filter name",
|
||||||
setupMock: func(m *fail2ban.MockRunner) {
|
setupMock: func(m *fail2ban.MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
fail2ban.StandardMockSetup(m)
|
||||||
},
|
},
|
||||||
args: []string{"../../../etc/passwd"},
|
args: []string{"../../../etc/passwd"},
|
||||||
expectError: true,
|
expectError: true,
|
||||||
@@ -81,8 +73,8 @@ func TestTestFilterCmdExecution(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
mockRunner := fail2ban.NewMockRunner()
|
mockRunner := fail2ban.NewMockRunner()
|
||||||
|
defer fail2ban.WithTestRunner(t, mockRunner)()
|
||||||
tt.setupMock(mockRunner)
|
tt.setupMock(mockRunner)
|
||||||
fail2ban.SetRunner(mockRunner)
|
|
||||||
|
|
||||||
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -164,6 +164,17 @@ func validateConfigPath(path, pathType string) (string, error) {
|
|||||||
return absPath, nil
|
return absPath, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateConfigPathWithFallback validates a config path and returns the fallback if validation fails.
|
||||||
|
// This consolidates the common pattern of validate-or-fallback-with-logging used for config paths.
|
||||||
|
func validateConfigPathWithFallback(path, pathType, defaultPath, errorMsg string) string {
|
||||||
|
validated, err := validateConfigPath(path, pathType)
|
||||||
|
if err != nil {
|
||||||
|
Logger.WithError(err).WithField(shared.LogFieldPath, path).Error(errorMsg)
|
||||||
|
return defaultPath
|
||||||
|
}
|
||||||
|
return validated
|
||||||
|
}
|
||||||
|
|
||||||
// isReasonableSystemPath checks if a path is in a reasonable system location
|
// isReasonableSystemPath checks if a path is in a reasonable system location
|
||||||
func isReasonableSystemPath(path, pathType string) bool {
|
func isReasonableSystemPath(path, pathType string) bool {
|
||||||
// Allow common system directories based on path type
|
// Allow common system directories based on path type
|
||||||
@@ -195,28 +206,20 @@ func NewConfigFromEnv() Config {
|
|||||||
if logDir == "" {
|
if logDir == "" {
|
||||||
logDir = shared.DefaultLogDir
|
logDir = shared.DefaultLogDir
|
||||||
}
|
}
|
||||||
|
cfg.LogDir = validateConfigPathWithFallback(
|
||||||
validatedLogDir, err := validateConfigPath(logDir, shared.PathTypeLog)
|
logDir, shared.PathTypeLog, shared.DefaultLogDir,
|
||||||
if err != nil {
|
"Invalid log directory from environment",
|
||||||
Logger.WithError(err).WithField(shared.LogFieldPath, logDir).Error("Invalid log directory from environment")
|
)
|
||||||
validatedLogDir = shared.DefaultLogDir // Fallback to safe default
|
|
||||||
}
|
|
||||||
cfg.LogDir = validatedLogDir
|
|
||||||
|
|
||||||
// Get and validate filter directory
|
// Get and validate filter directory
|
||||||
filterDir := os.Getenv("F2B_FILTER_DIR")
|
filterDir := os.Getenv("F2B_FILTER_DIR")
|
||||||
if filterDir == "" {
|
if filterDir == "" {
|
||||||
filterDir = shared.DefaultFilterDir
|
filterDir = shared.DefaultFilterDir
|
||||||
}
|
}
|
||||||
|
cfg.FilterDir = validateConfigPathWithFallback(
|
||||||
validatedFilterDir, err := validateConfigPath(filterDir, shared.PathTypeFilter)
|
filterDir, shared.PathTypeFilter, shared.DefaultFilterDir,
|
||||||
if err != nil {
|
"Invalid filter directory from environment",
|
||||||
Logger.WithError(err).
|
)
|
||||||
WithField(shared.LogFieldPath, filterDir).
|
|
||||||
Error("Invalid filter directory from environment")
|
|
||||||
validatedFilterDir = shared.DefaultFilterDir // Fallback to safe default
|
|
||||||
}
|
|
||||||
cfg.FilterDir = validatedFilterDir
|
|
||||||
|
|
||||||
// Configure timeouts from environment variables
|
// Configure timeouts from environment variables
|
||||||
cfg.CommandTimeout = parseTimeoutFromEnv("F2B_COMMAND_TIMEOUT", shared.DefaultCommandTimeout)
|
cfg.CommandTimeout = parseTimeoutFromEnv("F2B_COMMAND_TIMEOUT", shared.DefaultCommandTimeout)
|
||||||
|
|||||||
231
cmd/helpers.go
231
cmd/helpers.go
@@ -17,6 +17,20 @@ import (
|
|||||||
"github.com/ivuorinen/f2b/fail2ban"
|
"github.com/ivuorinen/f2b/fail2ban"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// createTimeoutContext creates a context with the configured command timeout.
|
||||||
|
// This helper consolidates the duplicate timeout handling pattern.
|
||||||
|
// If base is nil, context.Background() is used.
|
||||||
|
func createTimeoutContext(base context.Context, config *Config) (context.Context, context.CancelFunc) {
|
||||||
|
if base == nil {
|
||||||
|
base = context.Background()
|
||||||
|
}
|
||||||
|
timeout := shared.DefaultCommandTimeout
|
||||||
|
if config != nil && config.CommandTimeout > 0 {
|
||||||
|
timeout = config.CommandTimeout
|
||||||
|
}
|
||||||
|
return context.WithTimeout(base, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
// IsCI detects if we're running in a CI environment
|
// IsCI detects if we're running in a CI environment
|
||||||
func IsCI() bool {
|
func IsCI() bool {
|
||||||
return fail2ban.IsCI()
|
return fail2ban.IsCI()
|
||||||
@@ -50,17 +64,8 @@ func NewContextualCommand(
|
|||||||
// Get the contextual logger
|
// Get the contextual logger
|
||||||
logger := GetContextualLogger()
|
logger := GetContextualLogger()
|
||||||
|
|
||||||
// Base on Cobra's context so signals/cancellations propagate
|
// Create timeout context based on Cobra's context so signals/cancellations propagate
|
||||||
base := cmd.Context()
|
ctx, cancel := createTimeoutContext(cmd.Context(), config)
|
||||||
if base == nil {
|
|
||||||
base = context.Background()
|
|
||||||
}
|
|
||||||
// Create timeout context for the entire operation
|
|
||||||
timeout := shared.DefaultCommandTimeout
|
|
||||||
if config != nil && config.CommandTimeout > 0 {
|
|
||||||
timeout = config.CommandTimeout
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithTimeout(base, timeout)
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Extract command name from use string (first word)
|
// Extract command name from use string (first word)
|
||||||
@@ -388,22 +393,63 @@ type OperationResult struct {
|
|||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessBanOperation processes ban operations across multiple jails
|
// OperationType defines a ban or unban operation with its associated metadata
|
||||||
func ProcessBanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) {
|
type OperationType struct {
|
||||||
|
// MetricsType is the metrics key for this operation (e.g., shared.MetricsBan)
|
||||||
|
MetricsType string
|
||||||
|
// Message is the log message for this operation (e.g., shared.MsgBanResult)
|
||||||
|
Message string
|
||||||
|
// Operation is the function to execute without context
|
||||||
|
Operation func(client fail2ban.Client, ip, jail string) (int, error)
|
||||||
|
// OperationCtx is the function to execute with context
|
||||||
|
OperationCtx func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BanOperationType defines the ban operation
|
||||||
|
var BanOperationType = OperationType{
|
||||||
|
MetricsType: shared.MetricsBan,
|
||||||
|
Message: shared.MsgBanResult,
|
||||||
|
Operation: func(c fail2ban.Client, ip, jail string) (int, error) {
|
||||||
|
return c.BanIP(ip, jail)
|
||||||
|
},
|
||||||
|
OperationCtx: func(ctx context.Context, c fail2ban.Client, ip, jail string) (int, error) {
|
||||||
|
return c.BanIPWithContext(ctx, ip, jail)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnbanOperationType defines the unban operation
|
||||||
|
var UnbanOperationType = OperationType{
|
||||||
|
MetricsType: shared.MetricsUnban,
|
||||||
|
Message: shared.MsgUnbanResult,
|
||||||
|
Operation: func(c fail2ban.Client, ip, jail string) (int, error) {
|
||||||
|
return c.UnbanIP(ip, jail)
|
||||||
|
},
|
||||||
|
OperationCtx: func(ctx context.Context, c fail2ban.Client, ip, jail string) (int, error) {
|
||||||
|
return c.UnbanIPWithContext(ctx, ip, jail)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessOperation processes operations across multiple jails using the specified operation type
|
||||||
|
func ProcessOperation(
|
||||||
|
client fail2ban.Client,
|
||||||
|
ip string,
|
||||||
|
jails []string,
|
||||||
|
opType OperationType,
|
||||||
|
) ([]OperationResult, error) {
|
||||||
results := make([]OperationResult, 0, len(jails))
|
results := make([]OperationResult, 0, len(jails))
|
||||||
|
|
||||||
for _, jail := range jails {
|
for _, jail := range jails {
|
||||||
code, err := client.BanIP(ip, jail)
|
code, err := opType.Operation(client, ip, jail)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
status := InterpretBanStatus(code, shared.MetricsBan)
|
status := InterpretBanStatus(code, opType.MetricsType)
|
||||||
Logger.WithFields(map[string]interface{}{
|
Logger.WithFields(map[string]interface{}{
|
||||||
"ip": ip,
|
"ip": ip,
|
||||||
"jail": jail,
|
"jail": jail,
|
||||||
"status": status,
|
"status": status,
|
||||||
}).Info(shared.MsgBanResult)
|
}).Info(opType.Message)
|
||||||
|
|
||||||
results = append(results, OperationResult{
|
results = append(results, OperationResult{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
@@ -415,6 +461,59 @@ func ProcessBanOperation(client fail2ban.Client, ip string, jails []string) ([]O
|
|||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProcessOperationWithContext processes operations across multiple jails with timeout context
|
||||||
|
func ProcessOperationWithContext(
|
||||||
|
ctx context.Context,
|
||||||
|
client fail2ban.Client,
|
||||||
|
ip string,
|
||||||
|
jails []string,
|
||||||
|
opType OperationType,
|
||||||
|
) ([]OperationResult, error) {
|
||||||
|
logger := GetContextualLogger()
|
||||||
|
results := make([]OperationResult, 0, len(jails))
|
||||||
|
|
||||||
|
for _, jail := range jails {
|
||||||
|
// Add jail to context for this operation
|
||||||
|
jailCtx := WithJail(ctx, jail)
|
||||||
|
|
||||||
|
// Time the operation
|
||||||
|
start := time.Now()
|
||||||
|
code, err := opType.OperationCtx(jailCtx, client, ip, jail)
|
||||||
|
duration := time.Since(start)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// Log the failed operation with timing
|
||||||
|
logger.LogBanOperation(jailCtx, opType.MetricsType, ip, jail, false, duration)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
status := InterpretBanStatus(code, opType.MetricsType)
|
||||||
|
|
||||||
|
// Log the successful operation with timing
|
||||||
|
logger.LogBanOperation(jailCtx, opType.MetricsType, ip, jail, true, duration)
|
||||||
|
|
||||||
|
// Log the operation-specific message (ban vs unban)
|
||||||
|
Logger.WithFields(map[string]interface{}{
|
||||||
|
"ip": ip,
|
||||||
|
"jail": jail,
|
||||||
|
"status": status,
|
||||||
|
}).Info(opType.Message)
|
||||||
|
|
||||||
|
results = append(results, OperationResult{
|
||||||
|
IP: ip,
|
||||||
|
Jail: jail,
|
||||||
|
Status: status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessBanOperation processes ban operations across multiple jails
|
||||||
|
func ProcessBanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) {
|
||||||
|
return ProcessOperation(client, ip, jails, BanOperationType)
|
||||||
|
}
|
||||||
|
|
||||||
// ProcessBanOperationWithContext processes ban operations across multiple jails with timeout context
|
// ProcessBanOperationWithContext processes ban operations across multiple jails with timeout context
|
||||||
func ProcessBanOperationWithContext(
|
func ProcessBanOperationWithContext(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -422,70 +521,12 @@ func ProcessBanOperationWithContext(
|
|||||||
ip string,
|
ip string,
|
||||||
jails []string,
|
jails []string,
|
||||||
) ([]OperationResult, error) {
|
) ([]OperationResult, error) {
|
||||||
logger := GetContextualLogger()
|
return ProcessOperationWithContext(ctx, client, ip, jails, BanOperationType)
|
||||||
results := make([]OperationResult, 0, len(jails))
|
|
||||||
|
|
||||||
for _, jail := range jails {
|
|
||||||
// Add jail to context for this operation
|
|
||||||
jailCtx := WithJail(ctx, jail)
|
|
||||||
|
|
||||||
// Time the ban operation
|
|
||||||
start := time.Now()
|
|
||||||
code, err := client.BanIPWithContext(jailCtx, ip, jail)
|
|
||||||
duration := time.Since(start)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
// Log the failed operation with timing
|
|
||||||
logger.LogBanOperation(jailCtx, shared.MetricsBan, ip, jail, false, duration)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
status := InterpretBanStatus(code, shared.MetricsBan)
|
|
||||||
|
|
||||||
// Log the successful operation with timing
|
|
||||||
logger.LogBanOperation(jailCtx, shared.MetricsBan, ip, jail, true, duration)
|
|
||||||
|
|
||||||
Logger.WithFields(map[string]interface{}{
|
|
||||||
"ip": ip,
|
|
||||||
"jail": jail,
|
|
||||||
"status": status,
|
|
||||||
}).Info(shared.MsgBanResult)
|
|
||||||
|
|
||||||
results = append(results, OperationResult{
|
|
||||||
IP: ip,
|
|
||||||
Jail: jail,
|
|
||||||
Status: status,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessUnbanOperation processes unban operations across multiple jails
|
// ProcessUnbanOperation processes unban operations across multiple jails
|
||||||
func ProcessUnbanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) {
|
func ProcessUnbanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) {
|
||||||
results := make([]OperationResult, 0, len(jails))
|
return ProcessOperation(client, ip, jails, UnbanOperationType)
|
||||||
|
|
||||||
for _, jail := range jails {
|
|
||||||
code, err := client.UnbanIP(ip, jail)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
status := InterpretBanStatus(code, shared.MetricsUnban)
|
|
||||||
Logger.WithFields(map[string]interface{}{
|
|
||||||
"ip": ip,
|
|
||||||
"jail": jail,
|
|
||||||
"status": status,
|
|
||||||
}).Info(shared.MsgUnbanResult)
|
|
||||||
|
|
||||||
results = append(results, OperationResult{
|
|
||||||
IP: ip,
|
|
||||||
Jail: jail,
|
|
||||||
Status: status,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessUnbanOperationWithContext processes unban operations across multiple jails with timeout context
|
// ProcessUnbanOperationWithContext processes unban operations across multiple jails with timeout context
|
||||||
@@ -495,43 +536,7 @@ func ProcessUnbanOperationWithContext(
|
|||||||
ip string,
|
ip string,
|
||||||
jails []string,
|
jails []string,
|
||||||
) ([]OperationResult, error) {
|
) ([]OperationResult, error) {
|
||||||
logger := GetContextualLogger()
|
return ProcessOperationWithContext(ctx, client, ip, jails, UnbanOperationType)
|
||||||
results := make([]OperationResult, 0, len(jails))
|
|
||||||
|
|
||||||
for _, jail := range jails {
|
|
||||||
// Add jail to context for this operation
|
|
||||||
jailCtx := WithJail(ctx, jail)
|
|
||||||
|
|
||||||
// Time the unban operation
|
|
||||||
start := time.Now()
|
|
||||||
code, err := client.UnbanIPWithContext(jailCtx, ip, jail)
|
|
||||||
duration := time.Since(start)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
// Log the failed operation with timing
|
|
||||||
logger.LogBanOperation(jailCtx, shared.MetricsUnban, ip, jail, false, duration)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
status := InterpretBanStatus(code, shared.MetricsUnban)
|
|
||||||
|
|
||||||
// Log the successful operation with timing
|
|
||||||
logger.LogBanOperation(jailCtx, shared.MetricsUnban, ip, jail, true, duration)
|
|
||||||
|
|
||||||
Logger.WithFields(map[string]interface{}{
|
|
||||||
"ip": ip,
|
|
||||||
"jail": jail,
|
|
||||||
"status": status,
|
|
||||||
}).Info(shared.MsgUnbanResult)
|
|
||||||
|
|
||||||
results = append(results, OperationResult{
|
|
||||||
IP: ip,
|
|
||||||
Jail: jail,
|
|
||||||
Status: status,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Argument validation helpers
|
// Argument validation helpers
|
||||||
|
|||||||
@@ -12,9 +12,7 @@ import (
|
|||||||
|
|
||||||
// TestProcessBanOperation tests the ProcessBanOperation function
|
// TestProcessBanOperation tests the ProcessBanOperation function
|
||||||
func TestProcessBanOperation(t *testing.T) {
|
func TestProcessBanOperation(t *testing.T) {
|
||||||
// Save and restore original runner
|
defer fail2ban.WithTestRunner(t, fail2ban.GetRunner())()
|
||||||
originalRunner := fail2ban.GetRunner()
|
|
||||||
defer fail2ban.SetRunner(originalRunner)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -27,7 +25,7 @@ func TestProcessBanOperation(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "successful ban single jail",
|
name: "successful ban single jail",
|
||||||
setupMock: func(m *fail2ban.MockRunner) {
|
setupMock: func(m *fail2ban.MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
fail2ban.StandardMockSetup(m)
|
||||||
m.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
m.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
||||||
m.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
m.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
||||||
},
|
},
|
||||||
@@ -39,7 +37,7 @@ func TestProcessBanOperation(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "successful ban multiple jails",
|
name: "successful ban multiple jails",
|
||||||
setupMock: func(m *fail2ban.MockRunner) {
|
setupMock: func(m *fail2ban.MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
fail2ban.StandardMockSetup(m)
|
||||||
m.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
m.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
||||||
m.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
m.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
||||||
m.SetResponse("fail2ban-client set apache banip 192.168.1.1", []byte("1"))
|
m.SetResponse("fail2ban-client set apache banip 192.168.1.1", []byte("1"))
|
||||||
@@ -53,7 +51,7 @@ func TestProcessBanOperation(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "invalid IP address",
|
name: "invalid IP address",
|
||||||
setupMock: func(m *fail2ban.MockRunner) {
|
setupMock: func(m *fail2ban.MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
fail2ban.StandardMockSetup(m)
|
||||||
},
|
},
|
||||||
ip: "invalid.ip",
|
ip: "invalid.ip",
|
||||||
jails: []string{"sshd"},
|
jails: []string{"sshd"},
|
||||||
@@ -147,13 +145,3 @@ func TestParseTimeoutFromEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupBasicMockResponses is a helper for setting up version check and ping responses
|
|
||||||
func setupBasicMockResponses(m *fail2ban.MockRunner) {
|
|
||||||
m.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
|
||||||
m.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
|
||||||
m.SetResponse("fail2ban-client ping", []byte("Server replied: pong"))
|
|
||||||
m.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong"))
|
|
||||||
m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 2\n`- Jail list: sshd, apache"))
|
|
||||||
m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 2\n`- Jail list: sshd, apache"))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -87,14 +87,9 @@ func ExecuteIPCommand(
|
|||||||
// Get the contextual logger
|
// Get the contextual logger
|
||||||
logger := GetContextualLogger()
|
logger := GetContextualLogger()
|
||||||
|
|
||||||
// Safe timeout handling with nil check
|
|
||||||
timeout := shared.DefaultCommandTimeout
|
|
||||||
if config != nil && config.CommandTimeout > 0 {
|
|
||||||
timeout = config.CommandTimeout
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create timeout context for the entire operation
|
// Create timeout context for the entire operation
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
// Use cmd.Context() to inherit Cobra's signal cancellation
|
||||||
|
ctx, cancel := createTimeoutContext(cmd.Context(), config)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Add command context
|
// Add command context
|
||||||
|
|||||||
@@ -9,6 +9,43 @@ import (
|
|||||||
"github.com/ivuorinen/f2b/fail2ban"
|
"github.com/ivuorinen/f2b/fail2ban"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// multiJailOperationFunc defines the signature for IP operation functions that process multiple jails
|
||||||
|
type multiJailOperationFunc func(
|
||||||
|
ctx context.Context,
|
||||||
|
client fail2ban.Client,
|
||||||
|
ip string,
|
||||||
|
jails []string,
|
||||||
|
) ([]OperationResult, error)
|
||||||
|
|
||||||
|
// validateIPAndJails validates an IP address and a list of jail names.
|
||||||
|
// Returns an error if any validation fails.
|
||||||
|
func validateIPAndJails(ip string, jails []string) error {
|
||||||
|
if err := fail2ban.ValidateIP(ip); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, jail := range jails {
|
||||||
|
if err := fail2ban.ValidateJail(jail); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processWithValidation validates inputs and executes the provided operation function.
|
||||||
|
// This consolidates the common validation-then-execute pattern.
|
||||||
|
func processWithValidation(
|
||||||
|
ctx context.Context,
|
||||||
|
client fail2ban.Client,
|
||||||
|
ip string,
|
||||||
|
jails []string,
|
||||||
|
opFunc multiJailOperationFunc,
|
||||||
|
) ([]OperationResult, error) {
|
||||||
|
if err := validateIPAndJails(ip, jails); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return opFunc(ctx, client, ip, jails)
|
||||||
|
}
|
||||||
|
|
||||||
// BanProcessor handles ban operations
|
// BanProcessor handles ban operations
|
||||||
type BanProcessor struct{}
|
type BanProcessor struct{}
|
||||||
|
|
||||||
@@ -19,19 +56,7 @@ func (p *BanProcessor) ProcessSingle(
|
|||||||
ip string,
|
ip string,
|
||||||
jails []string,
|
jails []string,
|
||||||
) ([]OperationResult, error) {
|
) ([]OperationResult, error) {
|
||||||
// Validate IP address before privilege escalation
|
return processWithValidation(ctx, client, ip, jails, ProcessBanOperationWithContext)
|
||||||
if err := fail2ban.ValidateIP(ip); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate each jail name before privilege escalation
|
|
||||||
for _, jail := range jails {
|
|
||||||
if err := fail2ban.ValidateJail(jail); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ProcessBanOperationWithContext(ctx, client, ip, jails)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessParallel processes ban operations for multiple jails in parallel
|
// ProcessParallel processes ban operations for multiple jails in parallel
|
||||||
@@ -41,19 +66,7 @@ func (p *BanProcessor) ProcessParallel(
|
|||||||
ip string,
|
ip string,
|
||||||
jails []string,
|
jails []string,
|
||||||
) ([]OperationResult, error) {
|
) ([]OperationResult, error) {
|
||||||
// Validate IP address before privilege escalation
|
return processWithValidation(ctx, client, ip, jails, ProcessBanOperationParallelWithContext)
|
||||||
if err := fail2ban.ValidateIP(ip); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate each jail name before privilege escalation
|
|
||||||
for _, jail := range jails {
|
|
||||||
if err := fail2ban.ValidateJail(jail); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ProcessBanOperationParallelWithContext(ctx, client, ip, jails)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnbanProcessor handles unban operations
|
// UnbanProcessor handles unban operations
|
||||||
@@ -66,19 +79,7 @@ func (p *UnbanProcessor) ProcessSingle(
|
|||||||
ip string,
|
ip string,
|
||||||
jails []string,
|
jails []string,
|
||||||
) ([]OperationResult, error) {
|
) ([]OperationResult, error) {
|
||||||
// Validate IP address before privilege escalation
|
return processWithValidation(ctx, client, ip, jails, ProcessUnbanOperationWithContext)
|
||||||
if err := fail2ban.ValidateIP(ip); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate each jail name before privilege escalation
|
|
||||||
for _, jail := range jails {
|
|
||||||
if err := fail2ban.ValidateJail(jail); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ProcessUnbanOperationWithContext(ctx, client, ip, jails)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessParallel processes unban operations for multiple jails in parallel
|
// ProcessParallel processes unban operations for multiple jails in parallel
|
||||||
@@ -88,17 +89,5 @@ func (p *UnbanProcessor) ProcessParallel(
|
|||||||
ip string,
|
ip string,
|
||||||
jails []string,
|
jails []string,
|
||||||
) ([]OperationResult, error) {
|
) ([]OperationResult, error) {
|
||||||
// Validate IP address before privilege escalation
|
return processWithValidation(ctx, client, ip, jails, ProcessUnbanOperationParallelWithContext)
|
||||||
if err := fail2ban.ValidateIP(ip); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate each jail name before privilege escalation
|
|
||||||
for _, jail := range jails {
|
|
||||||
if err := fail2ban.ValidateJail(jail); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ProcessUnbanOperationParallelWithContext(ctx, client, ip, jails)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/ivuorinen/f2b/fail2ban"
|
||||||
"github.com/ivuorinen/f2b/shared"
|
"github.com/ivuorinen/f2b/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -56,57 +58,68 @@ func getVersion() string {
|
|||||||
return version
|
return version
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// contextKeyEntry defines a context key and its log field name
|
||||||
|
type contextKeyEntry struct {
|
||||||
|
key any // The context key to look up
|
||||||
|
fieldName string // The log field name to use
|
||||||
|
}
|
||||||
|
|
||||||
|
// contextKeys lists all context keys to extract for logging
|
||||||
|
var contextKeys = []contextKeyEntry{
|
||||||
|
{shared.ContextKeyRequestID, string(shared.ContextKeyRequestID)},
|
||||||
|
{shared.ContextKeyOperation, string(shared.ContextKeyOperation)},
|
||||||
|
{shared.ContextKeyIP, string(shared.ContextKeyIP)},
|
||||||
|
{shared.ContextKeyJail, string(shared.ContextKeyJail)},
|
||||||
|
{shared.ContextKeyCommand, string(shared.ContextKeyCommand)},
|
||||||
|
}
|
||||||
|
|
||||||
// WithContext creates a logger entry with context values
|
// WithContext creates a logger entry with context values
|
||||||
func (cl *ContextualLogger) WithContext(ctx context.Context) *logrus.Entry {
|
func (cl *ContextualLogger) WithContext(ctx context.Context) *logrus.Entry {
|
||||||
entry := cl.WithFields(cl.defaultFields)
|
entry := cl.WithFields(cl.defaultFields)
|
||||||
|
|
||||||
// Extract context values and add as fields (using consistent constants)
|
// Extract context values and add as fields using table-driven approach
|
||||||
if requestID := ctx.Value(shared.ContextKeyRequestID); requestID != nil {
|
for _, ck := range contextKeys {
|
||||||
entry = entry.WithField(string(shared.ContextKeyRequestID), requestID)
|
if val := ctx.Value(ck.key); val != nil {
|
||||||
|
entry = entry.WithField(ck.fieldName, val)
|
||||||
}
|
}
|
||||||
|
|
||||||
if operation := ctx.Value(shared.ContextKeyOperation); operation != nil {
|
|
||||||
entry = entry.WithField(string(shared.ContextKeyOperation), operation)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip := ctx.Value(shared.ContextKeyIP); ip != nil {
|
|
||||||
entry = entry.WithField(string(shared.ContextKeyIP), ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
if jail := ctx.Value(shared.ContextKeyJail); jail != nil {
|
|
||||||
entry = entry.WithField(string(shared.ContextKeyJail), jail)
|
|
||||||
}
|
|
||||||
|
|
||||||
if command := ctx.Value(shared.ContextKeyCommand); command != nil {
|
|
||||||
entry = entry.WithField(string(shared.ContextKeyCommand), command)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return entry
|
return entry
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithOperation adds operation context and returns a new context
|
// WithOperation adds operation context and returns a new context.
|
||||||
|
// Delegates to fail2ban.WithOperation for consistent validation.
|
||||||
func WithOperation(ctx context.Context, operation string) context.Context {
|
func WithOperation(ctx context.Context, operation string) context.Context {
|
||||||
return context.WithValue(ctx, shared.ContextKeyOperation, operation)
|
return fail2ban.WithOperation(ctx, operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithIP adds IP context and returns a new context
|
// WithIP adds IP context and returns a new context.
|
||||||
|
// Delegates to fail2ban.WithIP for consistent IP validation.
|
||||||
func WithIP(ctx context.Context, ip string) context.Context {
|
func WithIP(ctx context.Context, ip string) context.Context {
|
||||||
return context.WithValue(ctx, shared.ContextKeyIP, ip)
|
return fail2ban.WithIP(ctx, ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithJail adds jail context and returns a new context
|
// WithJail adds jail context and returns a new context.
|
||||||
|
// Delegates to fail2ban.WithJail for consistent jail name validation.
|
||||||
func WithJail(ctx context.Context, jail string) context.Context {
|
func WithJail(ctx context.Context, jail string) context.Context {
|
||||||
return context.WithValue(ctx, shared.ContextKeyJail, jail)
|
return fail2ban.WithJail(ctx, jail)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithCommand adds command context and returns a new context
|
// WithCommand adds command context and returns a new context.
|
||||||
|
// This is cmd-specific as fail2ban doesn't need command tracking.
|
||||||
|
// Empty commands are not stored in context.
|
||||||
func WithCommand(ctx context.Context, command string) context.Context {
|
func WithCommand(ctx context.Context, command string) context.Context {
|
||||||
|
command = strings.TrimSpace(command)
|
||||||
|
if command == "" {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
return context.WithValue(ctx, shared.ContextKeyCommand, command)
|
return context.WithValue(ctx, shared.ContextKeyCommand, command)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithRequestID adds request ID context and returns a new context
|
// WithRequestID adds request ID context and returns a new context.
|
||||||
|
// Delegates to fail2ban.WithRequestID for consistent validation.
|
||||||
func WithRequestID(ctx context.Context, requestID string) context.Context {
|
func WithRequestID(ctx context.Context, requestID string) context.Context {
|
||||||
return context.WithValue(ctx, shared.ContextKeyRequestID, requestID)
|
return fail2ban.WithRequestID(ctx, requestID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogOperation logs the start and end of an operation with timing and metrics
|
// LogOperation logs the start and end of an operation with timing and metrics
|
||||||
|
|||||||
105
cmd/metrics.go
105
cmd/metrics.go
@@ -68,17 +68,34 @@ func NewMetrics() *Metrics {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// recordOperationMetrics records metrics for any operation type.
|
||||||
|
// This helper consolidates the duplicate metrics recording pattern.
|
||||||
|
func (m *Metrics) recordOperationMetrics(
|
||||||
|
execCounter, durationCounter, failureCounter *int64,
|
||||||
|
buckets map[string]*LatencyBucket,
|
||||||
|
operation string,
|
||||||
|
duration time.Duration,
|
||||||
|
success bool,
|
||||||
|
) {
|
||||||
|
atomic.AddInt64(execCounter, 1)
|
||||||
|
atomic.AddInt64(durationCounter, duration.Milliseconds())
|
||||||
|
if !success {
|
||||||
|
atomic.AddInt64(failureCounter, 1)
|
||||||
|
}
|
||||||
|
m.recordLatencyBucket(buckets, operation, duration)
|
||||||
|
}
|
||||||
|
|
||||||
// RecordCommandExecution records metrics for command execution
|
// RecordCommandExecution records metrics for command execution
|
||||||
func (m *Metrics) RecordCommandExecution(command string, duration time.Duration, success bool) {
|
func (m *Metrics) RecordCommandExecution(command string, duration time.Duration, success bool) {
|
||||||
atomic.AddInt64(&m.CommandExecutions, 1)
|
m.recordOperationMetrics(
|
||||||
atomic.AddInt64(&m.CommandTotalDuration, duration.Milliseconds())
|
&m.CommandExecutions,
|
||||||
|
&m.CommandTotalDuration,
|
||||||
if !success {
|
&m.CommandFailures,
|
||||||
atomic.AddInt64(&m.CommandFailures, 1)
|
m.commandLatencyBuckets,
|
||||||
}
|
command,
|
||||||
|
duration,
|
||||||
// Record latency bucket
|
success,
|
||||||
m.recordLatencyBucket(m.commandLatencyBuckets, command, duration)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordBanOperation records metrics for ban operations
|
// RecordBanOperation records metrics for ban operations
|
||||||
@@ -99,15 +116,15 @@ func (m *Metrics) RecordBanOperation(operation string, _ time.Duration, success
|
|||||||
|
|
||||||
// RecordClientOperation records metrics for client operations
|
// RecordClientOperation records metrics for client operations
|
||||||
func (m *Metrics) RecordClientOperation(operation string, duration time.Duration, success bool) {
|
func (m *Metrics) RecordClientOperation(operation string, duration time.Duration, success bool) {
|
||||||
atomic.AddInt64(&m.ClientOperations, 1)
|
m.recordOperationMetrics(
|
||||||
atomic.AddInt64(&m.ClientTotalDuration, duration.Milliseconds())
|
&m.ClientOperations,
|
||||||
|
&m.ClientTotalDuration,
|
||||||
if !success {
|
&m.ClientFailures,
|
||||||
atomic.AddInt64(&m.ClientFailures, 1)
|
m.clientLatencyBuckets,
|
||||||
}
|
operation,
|
||||||
|
duration,
|
||||||
// Record latency bucket
|
success,
|
||||||
m.recordLatencyBucket(m.clientLatencyBuckets, operation, duration)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordValidationCacheHit records validation cache hits
|
// RecordValidationCacheHit records validation cache hits
|
||||||
@@ -143,6 +160,25 @@ func (m *Metrics) UpdateGoroutineCount(count int64) {
|
|||||||
atomic.StoreInt64(&m.GoroutineCount, count)
|
atomic.StoreInt64(&m.GoroutineCount, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copyBuckets creates a snapshot copy of latency buckets
|
||||||
|
// This helper consolidates the duplicate bucket copying logic
|
||||||
|
func copyBuckets(buckets map[string]*LatencyBucket) map[string]LatencyBucketSnapshot {
|
||||||
|
result := make(map[string]LatencyBucketSnapshot, len(buckets))
|
||||||
|
for op, bucket := range buckets {
|
||||||
|
result[op] = LatencyBucketSnapshot{
|
||||||
|
Under1ms: atomic.LoadInt64(&bucket.Under1ms),
|
||||||
|
Under10ms: atomic.LoadInt64(&bucket.Under10ms),
|
||||||
|
Under100ms: atomic.LoadInt64(&bucket.Under100ms),
|
||||||
|
Under1s: atomic.LoadInt64(&bucket.Under1s),
|
||||||
|
Under10s: atomic.LoadInt64(&bucket.Under10s),
|
||||||
|
Over10s: atomic.LoadInt64(&bucket.Over10s),
|
||||||
|
Total: atomic.LoadInt64(&bucket.Total),
|
||||||
|
TotalTime: atomic.LoadInt64(&bucket.TotalTime),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// recordLatencyBucket records latency in appropriate bucket
|
// recordLatencyBucket records latency in appropriate bucket
|
||||||
func (m *Metrics) recordLatencyBucket(buckets map[string]*LatencyBucket, operation string, duration time.Duration) {
|
func (m *Metrics) recordLatencyBucket(buckets map[string]*LatencyBucket, operation string, duration time.Duration) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@@ -177,37 +213,8 @@ func (m *Metrics) recordLatencyBucket(buckets map[string]*LatencyBucket, operati
|
|||||||
// GetSnapshot returns a snapshot of current metrics
|
// GetSnapshot returns a snapshot of current metrics
|
||||||
func (m *Metrics) GetSnapshot() MetricsSnapshot {
|
func (m *Metrics) GetSnapshot() MetricsSnapshot {
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
|
commandBuckets := copyBuckets(m.commandLatencyBuckets)
|
||||||
// Copy command latency buckets
|
clientBuckets := copyBuckets(m.clientLatencyBuckets)
|
||||||
commandBuckets := make(map[string]LatencyBucketSnapshot)
|
|
||||||
for op, bucket := range m.commandLatencyBuckets {
|
|
||||||
commandBuckets[op] = LatencyBucketSnapshot{
|
|
||||||
Under1ms: atomic.LoadInt64(&bucket.Under1ms),
|
|
||||||
Under10ms: atomic.LoadInt64(&bucket.Under10ms),
|
|
||||||
Under100ms: atomic.LoadInt64(&bucket.Under100ms),
|
|
||||||
Under1s: atomic.LoadInt64(&bucket.Under1s),
|
|
||||||
Under10s: atomic.LoadInt64(&bucket.Under10s),
|
|
||||||
Over10s: atomic.LoadInt64(&bucket.Over10s),
|
|
||||||
Total: atomic.LoadInt64(&bucket.Total),
|
|
||||||
TotalTime: atomic.LoadInt64(&bucket.TotalTime),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy client latency buckets
|
|
||||||
clientBuckets := make(map[string]LatencyBucketSnapshot)
|
|
||||||
for op, bucket := range m.clientLatencyBuckets {
|
|
||||||
clientBuckets[op] = LatencyBucketSnapshot{
|
|
||||||
Under1ms: atomic.LoadInt64(&bucket.Under1ms),
|
|
||||||
Under10ms: atomic.LoadInt64(&bucket.Under10ms),
|
|
||||||
Under100ms: atomic.LoadInt64(&bucket.Under100ms),
|
|
||||||
Under1s: atomic.LoadInt64(&bucket.Under1s),
|
|
||||||
Under10s: atomic.LoadInt64(&bucket.Under10s),
|
|
||||||
Over10s: atomic.LoadInt64(&bucket.Over10s),
|
|
||||||
Total: atomic.LoadInt64(&bucket.Total),
|
|
||||||
TotalTime: atomic.LoadInt64(&bucket.TotalTime),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mu.RUnlock()
|
m.mu.RUnlock()
|
||||||
|
|
||||||
return MetricsSnapshot{
|
return MetricsSnapshot{
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -97,35 +98,40 @@ func printMetricsPlain(output io.Writer, snapshot MetricsSnapshot) error {
|
|||||||
// Command latency distribution
|
// Command latency distribution
|
||||||
if len(snapshot.CommandLatencyBuckets) > 0 {
|
if len(snapshot.CommandLatencyBuckets) > 0 {
|
||||||
sb.WriteString("Command Latency Distribution:\n")
|
sb.WriteString("Command Latency Distribution:\n")
|
||||||
for cmd, bucket := range snapshot.CommandLatencyBuckets {
|
formatLatencyBuckets(&sb, snapshot.CommandLatencyBuckets)
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtOperationHeader, cmd))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder1ms, bucket.Under1ms))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder10ms, bucket.Under10ms))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder100ms, bucket.Under100ms))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder1s, bucket.Under1s))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder10s, bucket.Under10s))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyOver10s, bucket.Over10s))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtAverageLatency, bucket.GetAverageLatency()))
|
|
||||||
}
|
|
||||||
sb.WriteString("\n")
|
sb.WriteString("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client latency distribution
|
// Client latency distribution
|
||||||
if len(snapshot.ClientLatencyBuckets) > 0 {
|
if len(snapshot.ClientLatencyBuckets) > 0 {
|
||||||
sb.WriteString("Client Operation Latency Distribution:\n")
|
sb.WriteString("Client Operation Latency Distribution:\n")
|
||||||
for op, bucket := range snapshot.ClientLatencyBuckets {
|
formatLatencyBuckets(&sb, snapshot.ClientLatencyBuckets)
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtOperationHeader, op))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder1ms, bucket.Under1ms))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder10ms, bucket.Under10ms))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder100ms, bucket.Under100ms))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder1s, bucket.Under1s))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyUnder10s, bucket.Under10s))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtLatencyOver10s, bucket.Over10s))
|
|
||||||
sb.WriteString(fmt.Sprintf(shared.MetricsFmtAverageLatency, bucket.GetAverageLatency()))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write the entire string at once
|
// Write the entire string at once
|
||||||
_, err := output.Write([]byte(sb.String()))
|
_, err := output.Write([]byte(sb.String()))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// formatLatencyBuckets writes latency bucket distribution to the builder.
|
||||||
|
// Keys are sorted for deterministic output.
|
||||||
|
func formatLatencyBuckets(sb *strings.Builder, buckets map[string]LatencyBucketSnapshot) {
|
||||||
|
// Sort keys for deterministic output
|
||||||
|
keys := make([]string, 0, len(buckets))
|
||||||
|
for name := range buckets {
|
||||||
|
keys = append(keys, name)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
for _, name := range keys {
|
||||||
|
bucket := buckets[name]
|
||||||
|
fmt.Fprintf(sb, shared.MetricsFmtOperationHeader, name)
|
||||||
|
fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder1ms, bucket.Under1ms)
|
||||||
|
fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder10ms, bucket.Under10ms)
|
||||||
|
fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder100ms, bucket.Under100ms)
|
||||||
|
fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder1s, bucket.Under1s)
|
||||||
|
fmt.Fprintf(sb, shared.MetricsFmtLatencyUnder10s, bucket.Under10s)
|
||||||
|
fmt.Fprintf(sb, shared.MetricsFmtLatencyOver10s, bucket.Over10s)
|
||||||
|
fmt.Fprintf(sb, shared.MetricsFmtAverageLatency, bucket.GetAverageLatency())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,11 +2,11 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/ivuorinen/f2b/fail2ban"
|
"github.com/ivuorinen/f2b/fail2ban"
|
||||||
"github.com/ivuorinen/f2b/shared"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParallelOperationProcessor handles parallel ban/unban operations across multiple jails
|
// ParallelOperationProcessor handles parallel ban/unban operations across multiple jails
|
||||||
@@ -24,15 +24,16 @@ func NewParallelOperationProcessor(workerCount int) *ParallelOperationProcessor
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessBanOperationParallel processes ban operations across multiple jails in parallel
|
// ProcessOperationParallel processes operations across multiple jails in parallel
|
||||||
func (pop *ParallelOperationProcessor) ProcessBanOperationParallel(
|
func (pop *ParallelOperationProcessor) ProcessOperationParallel(
|
||||||
client fail2ban.Client,
|
client fail2ban.Client,
|
||||||
ip string,
|
ip string,
|
||||||
jails []string,
|
jails []string,
|
||||||
|
opType OperationType,
|
||||||
) ([]OperationResult, error) {
|
) ([]OperationResult, error) {
|
||||||
if len(jails) <= 1 {
|
if len(jails) <= 1 {
|
||||||
// For single jail, use sequential processing to avoid overhead
|
// For single jail, use sequential processing to avoid overhead
|
||||||
return ProcessBanOperation(client, ip, jails)
|
return ProcessOperation(client, ip, jails, opType)
|
||||||
}
|
}
|
||||||
|
|
||||||
return pop.processOperations(
|
return pop.processOperations(
|
||||||
@@ -40,13 +41,43 @@ func (pop *ParallelOperationProcessor) ProcessBanOperationParallel(
|
|||||||
client,
|
client,
|
||||||
ip,
|
ip,
|
||||||
jails,
|
jails,
|
||||||
func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error) {
|
opType.OperationCtx,
|
||||||
return client.BanIPWithContext(ctx, ip, jail)
|
opType.MetricsType,
|
||||||
},
|
|
||||||
shared.MetricsBan,
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProcessOperationParallelWithContext processes operations across multiple jails in parallel with context
|
||||||
|
func (pop *ParallelOperationProcessor) ProcessOperationParallelWithContext(
|
||||||
|
ctx context.Context,
|
||||||
|
client fail2ban.Client,
|
||||||
|
ip string,
|
||||||
|
jails []string,
|
||||||
|
opType OperationType,
|
||||||
|
) ([]OperationResult, error) {
|
||||||
|
if len(jails) <= 1 {
|
||||||
|
// For single jail, use sequential processing to avoid overhead
|
||||||
|
return ProcessOperationWithContext(ctx, client, ip, jails, opType)
|
||||||
|
}
|
||||||
|
|
||||||
|
return pop.processOperations(
|
||||||
|
ctx,
|
||||||
|
client,
|
||||||
|
ip,
|
||||||
|
jails,
|
||||||
|
opType.OperationCtx,
|
||||||
|
opType.MetricsType,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessBanOperationParallel processes ban operations across multiple jails in parallel
|
||||||
|
func (pop *ParallelOperationProcessor) ProcessBanOperationParallel(
|
||||||
|
client fail2ban.Client,
|
||||||
|
ip string,
|
||||||
|
jails []string,
|
||||||
|
) ([]OperationResult, error) {
|
||||||
|
return pop.ProcessOperationParallel(client, ip, jails, BanOperationType)
|
||||||
|
}
|
||||||
|
|
||||||
// ProcessBanOperationParallelWithContext processes ban operations across
|
// ProcessBanOperationParallelWithContext processes ban operations across
|
||||||
// multiple jails in parallel with timeout context
|
// multiple jails in parallel with timeout context
|
||||||
func (pop *ParallelOperationProcessor) ProcessBanOperationParallelWithContext(
|
func (pop *ParallelOperationProcessor) ProcessBanOperationParallelWithContext(
|
||||||
@@ -55,21 +86,7 @@ func (pop *ParallelOperationProcessor) ProcessBanOperationParallelWithContext(
|
|||||||
ip string,
|
ip string,
|
||||||
jails []string,
|
jails []string,
|
||||||
) ([]OperationResult, error) {
|
) ([]OperationResult, error) {
|
||||||
if len(jails) <= 1 {
|
return pop.ProcessOperationParallelWithContext(ctx, client, ip, jails, BanOperationType)
|
||||||
// For single jail, use sequential processing to avoid overhead
|
|
||||||
return ProcessBanOperationWithContext(ctx, client, ip, jails)
|
|
||||||
}
|
|
||||||
|
|
||||||
return pop.processOperations(
|
|
||||||
ctx,
|
|
||||||
client,
|
|
||||||
ip,
|
|
||||||
jails,
|
|
||||||
func(opCtx context.Context, client fail2ban.Client, ip, jail string) (int, error) {
|
|
||||||
return client.BanIPWithContext(opCtx, ip, jail)
|
|
||||||
},
|
|
||||||
shared.MetricsBan,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessUnbanOperationParallel processes unban operations across multiple jails in parallel
|
// ProcessUnbanOperationParallel processes unban operations across multiple jails in parallel
|
||||||
@@ -78,21 +95,7 @@ func (pop *ParallelOperationProcessor) ProcessUnbanOperationParallel(
|
|||||||
ip string,
|
ip string,
|
||||||
jails []string,
|
jails []string,
|
||||||
) ([]OperationResult, error) {
|
) ([]OperationResult, error) {
|
||||||
if len(jails) <= 1 {
|
return pop.ProcessOperationParallel(client, ip, jails, UnbanOperationType)
|
||||||
// For single jail, use sequential processing to avoid overhead
|
|
||||||
return ProcessUnbanOperation(client, ip, jails)
|
|
||||||
}
|
|
||||||
|
|
||||||
return pop.processOperations(
|
|
||||||
context.Background(),
|
|
||||||
client,
|
|
||||||
ip,
|
|
||||||
jails,
|
|
||||||
func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error) {
|
|
||||||
return client.UnbanIPWithContext(ctx, ip, jail)
|
|
||||||
},
|
|
||||||
shared.MetricsUnban,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessUnbanOperationParallelWithContext processes unban operations across
|
// ProcessUnbanOperationParallelWithContext processes unban operations across
|
||||||
@@ -103,26 +106,35 @@ func (pop *ParallelOperationProcessor) ProcessUnbanOperationParallelWithContext(
|
|||||||
ip string,
|
ip string,
|
||||||
jails []string,
|
jails []string,
|
||||||
) ([]OperationResult, error) {
|
) ([]OperationResult, error) {
|
||||||
if len(jails) <= 1 {
|
return pop.ProcessOperationParallelWithContext(ctx, client, ip, jails, UnbanOperationType)
|
||||||
// For single jail, use sequential processing to avoid overhead
|
|
||||||
return ProcessUnbanOperationWithContext(ctx, client, ip, jails)
|
|
||||||
}
|
|
||||||
|
|
||||||
return pop.processOperations(
|
|
||||||
ctx,
|
|
||||||
client,
|
|
||||||
ip,
|
|
||||||
jails,
|
|
||||||
func(opCtx context.Context, client fail2ban.Client, ip, jail string) (int, error) {
|
|
||||||
return client.UnbanIPWithContext(opCtx, ip, jail)
|
|
||||||
},
|
|
||||||
shared.MetricsUnban,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// operationFunc represents a ban or unban operation with context
|
// operationFunc represents a ban or unban operation with context
|
||||||
type operationFunc func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error)
|
type operationFunc func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error)
|
||||||
|
|
||||||
|
// validateOperationInputs validates IP and jail inputs before parallel processing.
|
||||||
|
// Returns an aggregated error if any inputs are invalid.
|
||||||
|
func validateOperationInputs(ctx context.Context, ip string, jails []string) error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
// Validate IP address
|
||||||
|
if err := fail2ban.CachedValidateIP(ctx, ip); err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate each jail name
|
||||||
|
for _, jail := range jails {
|
||||||
|
if err := fail2ban.CachedValidateJail(ctx, jail); err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// processOperations handles the parallel processing of operations
|
// processOperations handles the parallel processing of operations
|
||||||
func (pop *ParallelOperationProcessor) processOperations(
|
func (pop *ParallelOperationProcessor) processOperations(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -132,6 +144,11 @@ func (pop *ParallelOperationProcessor) processOperations(
|
|||||||
operation operationFunc,
|
operation operationFunc,
|
||||||
operationType string,
|
operationType string,
|
||||||
) ([]OperationResult, error) {
|
) ([]OperationResult, error) {
|
||||||
|
// Validate inputs before processing
|
||||||
|
if err := validateOperationInputs(ctx, ip, jails); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
results := make([]OperationResult, len(jails))
|
results := make([]OperationResult, len(jails))
|
||||||
resultCh := make(chan operationResult, len(jails))
|
resultCh := make(chan operationResult, len(jails))
|
||||||
|
|
||||||
@@ -167,13 +184,20 @@ func (pop *ParallelOperationProcessor) processOperations(
|
|||||||
close(resultCh)
|
close(resultCh)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Collect results
|
// Collect results and errors
|
||||||
|
var errs []error
|
||||||
for result := range resultCh {
|
for result := range resultCh {
|
||||||
if result.index >= 0 && result.index < len(results) {
|
if result.index >= 0 && result.index < len(results) {
|
||||||
results[result.index] = result.result
|
results[result.index] = result.result
|
||||||
}
|
}
|
||||||
|
if result.err != nil {
|
||||||
|
errs = append(errs, result.err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return results, errors.Join(errs...)
|
||||||
|
}
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,6 +211,7 @@ type jailWork struct {
|
|||||||
type operationResult struct {
|
type operationResult struct {
|
||||||
result OperationResult
|
result OperationResult
|
||||||
index int
|
index int
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
// worker processes jail operations
|
// worker processes jail operations
|
||||||
@@ -215,16 +240,15 @@ func (pop *ParallelOperationProcessor) worker(
|
|||||||
"status": status,
|
"status": status,
|
||||||
}).Info("Operation result")
|
}).Info("Operation result")
|
||||||
|
|
||||||
result := operationResult{
|
resultCh <- operationResult{
|
||||||
result: OperationResult{
|
result: OperationResult{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
Jail: work.jail,
|
Jail: work.jail,
|
||||||
Status: status,
|
Status: status,
|
||||||
},
|
},
|
||||||
index: work.index,
|
index: work.index,
|
||||||
|
err: err,
|
||||||
}
|
}
|
||||||
|
|
||||||
resultCh <- result
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,17 +12,13 @@ import (
|
|||||||
|
|
||||||
// TestUnbanProcessorProcessParallel tests the ProcessParallel method
|
// TestUnbanProcessorProcessParallel tests the ProcessParallel method
|
||||||
func TestUnbanProcessorProcessParallel(t *testing.T) {
|
func TestUnbanProcessorProcessParallel(t *testing.T) {
|
||||||
// Save and restore original runner
|
|
||||||
originalRunner := fail2ban.GetRunner()
|
|
||||||
defer fail2ban.SetRunner(originalRunner)
|
|
||||||
|
|
||||||
mockRunner := fail2ban.NewMockRunner()
|
mockRunner := fail2ban.NewMockRunner()
|
||||||
setupBasicMockResponses(mockRunner)
|
defer fail2ban.WithTestRunner(t, mockRunner)()
|
||||||
|
fail2ban.StandardMockSetup(mockRunner)
|
||||||
mockRunner.SetResponse("fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
|
||||||
mockRunner.SetResponse("fail2ban-client set apache unbanip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("fail2ban-client set apache unbanip 192.168.1.1", []byte("1"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client set apache unbanip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("sudo fail2ban-client set apache unbanip 192.168.1.1", []byte("1"))
|
||||||
fail2ban.SetRunner(mockRunner)
|
|
||||||
|
|
||||||
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -12,17 +12,13 @@ import (
|
|||||||
|
|
||||||
// TestProcessBanOperationParallel tests the ProcessBanOperationParallel wrapper function
|
// TestProcessBanOperationParallel tests the ProcessBanOperationParallel wrapper function
|
||||||
func TestProcessBanOperationParallel(t *testing.T) {
|
func TestProcessBanOperationParallel(t *testing.T) {
|
||||||
// Save and restore original runner
|
|
||||||
originalRunner := fail2ban.GetRunner()
|
|
||||||
defer fail2ban.SetRunner(originalRunner)
|
|
||||||
|
|
||||||
mockRunner := fail2ban.NewMockRunner()
|
mockRunner := fail2ban.NewMockRunner()
|
||||||
setupBasicMockResponses(mockRunner)
|
defer fail2ban.WithTestRunner(t, mockRunner)()
|
||||||
|
fail2ban.StandardMockSetup(mockRunner)
|
||||||
mockRunner.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
||||||
mockRunner.SetResponse("fail2ban-client set apache banip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("fail2ban-client set apache banip 192.168.1.1", []byte("1"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client set apache banip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("sudo fail2ban-client set apache banip 192.168.1.1", []byte("1"))
|
||||||
fail2ban.SetRunner(mockRunner)
|
|
||||||
|
|
||||||
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -34,15 +30,11 @@ func TestProcessBanOperationParallel(t *testing.T) {
|
|||||||
|
|
||||||
// TestProcessUnbanOperationParallel tests the ProcessUnbanOperationParallel wrapper function
|
// TestProcessUnbanOperationParallel tests the ProcessUnbanOperationParallel wrapper function
|
||||||
func TestProcessUnbanOperationParallel(t *testing.T) {
|
func TestProcessUnbanOperationParallel(t *testing.T) {
|
||||||
// Save and restore original runner
|
|
||||||
originalRunner := fail2ban.GetRunner()
|
|
||||||
defer fail2ban.SetRunner(originalRunner)
|
|
||||||
|
|
||||||
mockRunner := fail2ban.NewMockRunner()
|
mockRunner := fail2ban.NewMockRunner()
|
||||||
setupBasicMockResponses(mockRunner)
|
defer fail2ban.WithTestRunner(t, mockRunner)()
|
||||||
|
fail2ban.StandardMockSetup(mockRunner)
|
||||||
mockRunner.SetResponse("fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
|
||||||
fail2ban.SetRunner(mockRunner)
|
|
||||||
|
|
||||||
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -54,15 +46,11 @@ func TestProcessUnbanOperationParallel(t *testing.T) {
|
|||||||
|
|
||||||
// TestProcessBanOperationParallelWithContext tests the wrapper with context
|
// TestProcessBanOperationParallelWithContext tests the wrapper with context
|
||||||
func TestProcessBanOperationParallelWithContext(t *testing.T) {
|
func TestProcessBanOperationParallelWithContext(t *testing.T) {
|
||||||
// Save and restore original runner
|
|
||||||
originalRunner := fail2ban.GetRunner()
|
|
||||||
defer fail2ban.SetRunner(originalRunner)
|
|
||||||
|
|
||||||
mockRunner := fail2ban.NewMockRunner()
|
mockRunner := fail2ban.NewMockRunner()
|
||||||
setupBasicMockResponses(mockRunner)
|
defer fail2ban.WithTestRunner(t, mockRunner)()
|
||||||
|
fail2ban.StandardMockSetup(mockRunner)
|
||||||
mockRunner.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("sudo fail2ban-client set sshd banip 192.168.1.1", []byte("1"))
|
||||||
fail2ban.SetRunner(mockRunner)
|
|
||||||
|
|
||||||
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -75,15 +63,11 @@ func TestProcessBanOperationParallelWithContext(t *testing.T) {
|
|||||||
|
|
||||||
// TestProcessUnbanOperationParallelWithContext tests the wrapper with context
|
// TestProcessUnbanOperationParallelWithContext tests the wrapper with context
|
||||||
func TestProcessUnbanOperationParallelWithContext(t *testing.T) {
|
func TestProcessUnbanOperationParallelWithContext(t *testing.T) {
|
||||||
// Save and restore original runner
|
|
||||||
originalRunner := fail2ban.GetRunner()
|
|
||||||
defer fail2ban.SetRunner(originalRunner)
|
|
||||||
|
|
||||||
mockRunner := fail2ban.NewMockRunner()
|
mockRunner := fail2ban.NewMockRunner()
|
||||||
setupBasicMockResponses(mockRunner)
|
defer fail2ban.WithTestRunner(t, mockRunner)()
|
||||||
|
fail2ban.StandardMockSetup(mockRunner)
|
||||||
mockRunner.SetResponse("fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
|
mockRunner.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.1", []byte("1"))
|
||||||
fail2ban.SetRunner(mockRunner)
|
|
||||||
|
|
||||||
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -15,7 +15,11 @@ func ServiceCmd(config *Config) *cobra.Command {
|
|||||||
nil,
|
nil,
|
||||||
func(_ *cobra.Command, args []string) error {
|
func(_ *cobra.Command, args []string) error {
|
||||||
// Validate service action argument
|
// Validate service action argument
|
||||||
if err := RequireArguments(args, 1, "action required: start|stop|restart|status|reload|enable|disable"); err != nil {
|
if err := RequireArguments(
|
||||||
|
args,
|
||||||
|
1,
|
||||||
|
"action required: start|stop|restart|status|reload|enable|disable",
|
||||||
|
); err != nil {
|
||||||
return HandleValidationError(err)
|
return HandleValidationError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -110,18 +110,14 @@ func TestValidateConfigPath(t *testing.T) {
|
|||||||
|
|
||||||
// TestLogsWatchCmdCreation tests LogsWatchCmd creation
|
// TestLogsWatchCmdCreation tests LogsWatchCmd creation
|
||||||
func TestLogsWatchCmdCreation(t *testing.T) {
|
func TestLogsWatchCmdCreation(t *testing.T) {
|
||||||
// Save and restore original runner
|
|
||||||
originalRunner := fail2ban.GetRunner()
|
|
||||||
defer fail2ban.SetRunner(originalRunner)
|
|
||||||
|
|
||||||
mockRunner := fail2ban.NewMockRunner()
|
mockRunner := fail2ban.NewMockRunner()
|
||||||
|
defer fail2ban.WithTestRunner(t, mockRunner)()
|
||||||
mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
||||||
mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong"))
|
mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong"))
|
mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong"))
|
||||||
mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
||||||
fail2ban.SetRunner(mockRunner)
|
|
||||||
|
|
||||||
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
client, err := fail2ban.NewClient("/var/log/fail2ban", "/etc/fail2ban/filter.d")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -142,18 +138,14 @@ func TestLogsWatchCmdCreation(t *testing.T) {
|
|||||||
|
|
||||||
// TestGetLogLinesWithLimitAndContext_Function tests the function
|
// TestGetLogLinesWithLimitAndContext_Function tests the function
|
||||||
func TestGetLogLinesWithLimitAndContext_Function(t *testing.T) {
|
func TestGetLogLinesWithLimitAndContext_Function(t *testing.T) {
|
||||||
// Save and restore original runner
|
|
||||||
originalRunner := fail2ban.GetRunner()
|
|
||||||
defer fail2ban.SetRunner(originalRunner)
|
|
||||||
|
|
||||||
mockRunner := fail2ban.NewMockRunner()
|
mockRunner := fail2ban.NewMockRunner()
|
||||||
|
defer fail2ban.WithTestRunner(t, mockRunner)()
|
||||||
mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
||||||
mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong"))
|
mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong"))
|
mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong"))
|
||||||
mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
||||||
fail2ban.SetRunner(mockRunner)
|
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
oldLogDir := fail2ban.GetLogDir()
|
oldLogDir := fail2ban.GetLogDir()
|
||||||
|
|||||||
@@ -95,6 +95,23 @@ func (btc *BoundedTimeCache) Size() int {
|
|||||||
return len(btc.cache)
|
return len(btc.cache)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseWithLayout parses a time string using the specified layout with caching.
|
||||||
|
// This method consolidates the cache-lookup-parse-store pattern used across
|
||||||
|
// different time parsing caches in the codebase.
|
||||||
|
func (btc *BoundedTimeCache) ParseWithLayout(timeStr, layout string) (time.Time, error) {
|
||||||
|
// Fast path: check cache
|
||||||
|
if cached, ok := btc.Load(timeStr); ok {
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse and cache - only cache successful parses
|
||||||
|
t, err := time.Parse(layout, timeStr)
|
||||||
|
if err == nil {
|
||||||
|
btc.Store(timeStr, t)
|
||||||
|
}
|
||||||
|
return t, err
|
||||||
|
}
|
||||||
|
|
||||||
// BanRecordParser provides high-performance parsing of ban records
|
// BanRecordParser provides high-performance parsing of ban records
|
||||||
type BanRecordParser struct {
|
type BanRecordParser struct {
|
||||||
// Pools for zero-allocation parsing (goroutine-safe)
|
// Pools for zero-allocation parsing (goroutine-safe)
|
||||||
@@ -167,17 +184,7 @@ func NewFastTimeCache(layout string) (*FastTimeCache, error) {
|
|||||||
|
|
||||||
// ParseTimeOptimized parses time with minimal allocations
|
// ParseTimeOptimized parses time with minimal allocations
|
||||||
func (ftc *FastTimeCache) ParseTimeOptimized(timeStr string) (time.Time, error) {
|
func (ftc *FastTimeCache) ParseTimeOptimized(timeStr string) (time.Time, error) {
|
||||||
// Fast path: check cache
|
return ftc.parseCache.ParseWithLayout(timeStr, ftc.layout)
|
||||||
if cached, ok := ftc.parseCache.Load(timeStr); ok {
|
|
||||||
return cached, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse and cache - only cache successful parses
|
|
||||||
t, err := time.Parse(ftc.layout, timeStr)
|
|
||||||
if err == nil {
|
|
||||||
ftc.parseCache.Store(timeStr, t)
|
|
||||||
}
|
|
||||||
return t, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildTimeStringOptimized builds time string with zero allocations using byte buffer
|
// BuildTimeStringOptimized builds time string with zero allocations using byte buffer
|
||||||
|
|||||||
@@ -12,17 +12,6 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// setupBasicMockResponses sets up the basic responses needed for client initialization
|
|
||||||
func setupBasicMockResponses(m *MockRunner) {
|
|
||||||
m.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
|
||||||
m.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
|
||||||
m.SetResponse("fail2ban-client ping", []byte("Server replied: pong"))
|
|
||||||
m.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong"))
|
|
||||||
// NewClient calls fetchJailsWithContext which runs status
|
|
||||||
m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 2\n`- Jail list: sshd, apache"))
|
|
||||||
m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 2\n`- Jail list: sshd, apache"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestListJailsWithContext tests jail listing with context
|
// TestListJailsWithContext tests jail listing with context
|
||||||
func TestListJailsWithContext(t *testing.T) {
|
func TestListJailsWithContext(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -35,11 +24,11 @@ func TestListJailsWithContext(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "successful jail listing",
|
name: "successful jail listing",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
},
|
},
|
||||||
timeout: 5 * time.Second,
|
timeout: 5 * time.Second,
|
||||||
expectError: false,
|
expectError: false,
|
||||||
expectJails: []string{"sshd", "apache"}, // From setupBasicMockResponses
|
expectJails: []string{"sshd", "apache"}, // From StandardMockSetup
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,7 +72,7 @@ func TestStatusAllWithContext(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "successful status all",
|
name: "successful status all",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
||||||
m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
||||||
},
|
},
|
||||||
@@ -94,7 +83,7 @@ func TestStatusAllWithContext(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "context timeout",
|
name: "context timeout",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
||||||
m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
||||||
},
|
},
|
||||||
@@ -145,7 +134,7 @@ func TestStatusJailWithContext(t *testing.T) {
|
|||||||
name: "successful status jail",
|
name: "successful status jail",
|
||||||
jail: "sshd",
|
jail: "sshd",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
m.SetResponse(
|
m.SetResponse(
|
||||||
"fail2ban-client status sshd",
|
"fail2ban-client status sshd",
|
||||||
[]byte("Status for the jail: sshd\n|- Filter\n`- Currently banned: 0"),
|
[]byte("Status for the jail: sshd\n|- Filter\n`- Currently banned: 0"),
|
||||||
@@ -163,7 +152,7 @@ func TestStatusJailWithContext(t *testing.T) {
|
|||||||
name: "invalid jail name",
|
name: "invalid jail name",
|
||||||
jail: "invalid@jail",
|
jail: "invalid@jail",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
// Validation will fail before command execution
|
// Validation will fail before command execution
|
||||||
},
|
},
|
||||||
timeout: 5 * time.Second,
|
timeout: 5 * time.Second,
|
||||||
@@ -173,7 +162,7 @@ func TestStatusJailWithContext(t *testing.T) {
|
|||||||
name: "context timeout",
|
name: "context timeout",
|
||||||
jail: "sshd",
|
jail: "sshd",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
m.SetResponse(
|
m.SetResponse(
|
||||||
"fail2ban-client status sshd",
|
"fail2ban-client status sshd",
|
||||||
[]byte("Status for the jail: sshd\n|- Filter\n`- Currently banned: 0"),
|
[]byte("Status for the jail: sshd\n|- Filter\n`- Currently banned: 0"),
|
||||||
@@ -234,7 +223,7 @@ func TestUnbanIPWithContext(t *testing.T) {
|
|||||||
ip: "192.168.1.100",
|
ip: "192.168.1.100",
|
||||||
jail: "sshd",
|
jail: "sshd",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
m.SetResponse("fail2ban-client set sshd unbanip 192.168.1.100", []byte("0"))
|
m.SetResponse("fail2ban-client set sshd unbanip 192.168.1.100", []byte("0"))
|
||||||
m.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.100", []byte("0"))
|
m.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.100", []byte("0"))
|
||||||
},
|
},
|
||||||
@@ -247,7 +236,7 @@ func TestUnbanIPWithContext(t *testing.T) {
|
|||||||
ip: "192.168.1.100",
|
ip: "192.168.1.100",
|
||||||
jail: "sshd",
|
jail: "sshd",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
m.SetResponse("fail2ban-client set sshd unbanip 192.168.1.100", []byte("1"))
|
m.SetResponse("fail2ban-client set sshd unbanip 192.168.1.100", []byte("1"))
|
||||||
m.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.100", []byte("1"))
|
m.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.100", []byte("1"))
|
||||||
},
|
},
|
||||||
@@ -260,7 +249,7 @@ func TestUnbanIPWithContext(t *testing.T) {
|
|||||||
ip: "invalid-ip",
|
ip: "invalid-ip",
|
||||||
jail: "sshd",
|
jail: "sshd",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
// Validation will fail before command execution
|
// Validation will fail before command execution
|
||||||
},
|
},
|
||||||
timeout: 5 * time.Second,
|
timeout: 5 * time.Second,
|
||||||
@@ -271,7 +260,7 @@ func TestUnbanIPWithContext(t *testing.T) {
|
|||||||
ip: "192.168.1.100",
|
ip: "192.168.1.100",
|
||||||
jail: "invalid@jail",
|
jail: "invalid@jail",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
// Validation will fail before command execution
|
// Validation will fail before command execution
|
||||||
},
|
},
|
||||||
timeout: 5 * time.Second,
|
timeout: 5 * time.Second,
|
||||||
@@ -282,7 +271,7 @@ func TestUnbanIPWithContext(t *testing.T) {
|
|||||||
ip: "192.168.1.100",
|
ip: "192.168.1.100",
|
||||||
jail: "sshd",
|
jail: "sshd",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
m.SetResponse("fail2ban-client set sshd unbanip 192.168.1.100", []byte("0"))
|
m.SetResponse("fail2ban-client set sshd unbanip 192.168.1.100", []byte("0"))
|
||||||
m.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.100", []byte("0"))
|
m.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.100", []byte("0"))
|
||||||
},
|
},
|
||||||
@@ -332,7 +321,7 @@ func TestListFiltersWithContext(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "successful filter listing",
|
name: "successful filter listing",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
// Mock responses not needed - uses file system
|
// Mock responses not needed - uses file system
|
||||||
},
|
},
|
||||||
setupEnv: func() {
|
setupEnv: func() {
|
||||||
@@ -345,7 +334,7 @@ func TestListFiltersWithContext(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "context timeout",
|
name: "context timeout",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
// Not applicable for file system operation
|
// Not applicable for file system operation
|
||||||
},
|
},
|
||||||
setupEnv: func() {
|
setupEnv: func() {
|
||||||
@@ -412,7 +401,7 @@ logpath = /var/log/auth.log
|
|||||||
name: "successful filter test",
|
name: "successful filter test",
|
||||||
filter: "sshd",
|
filter: "sshd",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
m.SetResponse(
|
m.SetResponse(
|
||||||
"fail2ban-regex /var/log/auth.log "+filepath.Join(tmpDir, "sshd.conf"),
|
"fail2ban-regex /var/log/auth.log "+filepath.Join(tmpDir, "sshd.conf"),
|
||||||
[]byte("Success: 0 matches"),
|
[]byte("Success: 0 matches"),
|
||||||
@@ -429,7 +418,7 @@ logpath = /var/log/auth.log
|
|||||||
name: "invalid filter name",
|
name: "invalid filter name",
|
||||||
filter: "invalid@filter",
|
filter: "invalid@filter",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
// Validation will fail before command execution
|
// Validation will fail before command execution
|
||||||
},
|
},
|
||||||
timeout: 5 * time.Second,
|
timeout: 5 * time.Second,
|
||||||
@@ -439,7 +428,7 @@ logpath = /var/log/auth.log
|
|||||||
name: "context timeout",
|
name: "context timeout",
|
||||||
filter: "sshd",
|
filter: "sshd",
|
||||||
setupMock: func(m *MockRunner) {
|
setupMock: func(m *MockRunner) {
|
||||||
setupBasicMockResponses(m)
|
StandardMockSetup(m)
|
||||||
m.SetResponse(
|
m.SetResponse(
|
||||||
"fail2ban-regex /var/log/auth.log "+filepath.Join(tmpDir, "sshd.conf"),
|
"fail2ban-regex /var/log/auth.log "+filepath.Join(tmpDir, "sshd.conf"),
|
||||||
[]byte("Success: 0 matches"),
|
[]byte("Success: 0 matches"),
|
||||||
@@ -485,7 +474,7 @@ logpath = /var/log/auth.log
|
|||||||
// TestWithContextCancellation tests that all WithContext functions respect cancellation
|
// TestWithContextCancellation tests that all WithContext functions respect cancellation
|
||||||
func TestWithContextCancellation(t *testing.T) {
|
func TestWithContextCancellation(t *testing.T) {
|
||||||
mock := NewMockRunner()
|
mock := NewMockRunner()
|
||||||
setupBasicMockResponses(mock)
|
StandardMockSetup(mock)
|
||||||
SetRunner(mock)
|
SetRunner(mock)
|
||||||
|
|
||||||
client, err := NewClient("/var/log", "/etc/fail2ban/filter.d")
|
client, err := NewClient("/var/log", "/etc/fail2ban/filter.d")
|
||||||
@@ -520,7 +509,7 @@ func TestWithContextCancellation(t *testing.T) {
|
|||||||
// TestWithContextDeadline tests that all WithContext functions respect deadlines
|
// TestWithContextDeadline tests that all WithContext functions respect deadlines
|
||||||
func TestWithContextDeadline(t *testing.T) {
|
func TestWithContextDeadline(t *testing.T) {
|
||||||
mock := NewMockRunner()
|
mock := NewMockRunner()
|
||||||
setupBasicMockResponses(mock)
|
StandardMockSetup(mock)
|
||||||
SetRunner(mock)
|
SetRunner(mock)
|
||||||
|
|
||||||
client, err := NewClient("/var/log", "/etc/fail2ban/filter.d")
|
client, err := NewClient("/var/log", "/etc/fail2ban/filter.d")
|
||||||
@@ -576,7 +565,7 @@ func TestWithContextDeadline(t *testing.T) {
|
|||||||
// TestWithContextValidation tests that validation happens before context usage
|
// TestWithContextValidation tests that validation happens before context usage
|
||||||
func TestWithContextValidation(t *testing.T) {
|
func TestWithContextValidation(t *testing.T) {
|
||||||
mock := NewMockRunner()
|
mock := NewMockRunner()
|
||||||
setupBasicMockResponses(mock)
|
StandardMockSetup(mock)
|
||||||
SetRunner(mock)
|
SetRunner(mock)
|
||||||
|
|
||||||
client, err := NewClient("/var/log", "/etc/fail2ban/filter.d")
|
client, err := NewClient("/var/log", "/etc/fail2ban/filter.d")
|
||||||
|
|||||||
@@ -2,30 +2,60 @@ package fail2ban
|
|||||||
|
|
||||||
import "context"
|
import "context"
|
||||||
|
|
||||||
// ContextWrappers provides a helper to automatically generate WithContext method wrappers.
|
// Context Wrapper Pattern
|
||||||
// This eliminates the need for duplicate WithContext implementations across different Client types.
|
//
|
||||||
// Usage: embed this in your Client struct and call DefineContextWrappers to get automatic context support.
|
// This package provides generic helper functions for creating WithContext method wrappers.
|
||||||
type ContextWrappers struct{}
|
// These helpers eliminate boilerplate for methods that simply need context propagation
|
||||||
|
// without additional logic.
|
||||||
|
//
|
||||||
|
// For simple methods (no validation, direct delegation to non-context version):
|
||||||
|
//
|
||||||
|
// func (c *Client) ListJailsWithContext(ctx context.Context) ([]string, error) {
|
||||||
|
// return wrapWithContext0(c.ListJails)(ctx)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// For complex methods (validation, custom logic, different execution paths):
|
||||||
|
// Implement the WithContext version directly with the required logic.
|
||||||
|
//
|
||||||
|
// Note: Code generation was evaluated but determined unnecessary because:
|
||||||
|
// - Only 2 methods use the simple wrapper pattern
|
||||||
|
// - Most WithContext methods require custom validation/logic
|
||||||
|
// - The generic helpers below already solve the simple cases cleanly
|
||||||
|
|
||||||
// Helper functions to reduce boilerplate in WithContext implementations
|
// Helper functions to reduce boilerplate in WithContext implementations
|
||||||
|
|
||||||
// wrapWithContext0 wraps a function with no parameters to accept a context parameter.
|
// wrapWithContext0 wraps a function with no parameters to accept a context parameter.
|
||||||
|
// It checks for context cancellation before invoking the underlying function.
|
||||||
func wrapWithContext0[T any](fn func() (T, error)) func(context.Context) (T, error) {
|
func wrapWithContext0[T any](fn func() (T, error)) func(context.Context) (T, error) {
|
||||||
return func(_ context.Context) (T, error) {
|
return func(ctx context.Context) (T, error) {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
var zero T
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
return fn()
|
return fn()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// wrapWithContext1 wraps a function with one parameter to accept a context parameter.
|
// wrapWithContext1 wraps a function with one parameter to accept a context parameter.
|
||||||
|
// It checks for context cancellation before invoking the underlying function.
|
||||||
func wrapWithContext1[T any, A any](fn func(A) (T, error)) func(context.Context, A) (T, error) {
|
func wrapWithContext1[T any, A any](fn func(A) (T, error)) func(context.Context, A) (T, error) {
|
||||||
return func(_ context.Context, a A) (T, error) {
|
return func(ctx context.Context, a A) (T, error) {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
var zero T
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
return fn(a)
|
return fn(a)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// wrapWithContext2 wraps a function with two parameters to accept a context parameter.
|
// wrapWithContext2 wraps a function with two parameters to accept a context parameter.
|
||||||
|
// It checks for context cancellation before invoking the underlying function.
|
||||||
func wrapWithContext2[T any, A any, B any](fn func(A, B) (T, error)) func(context.Context, A, B) (T, error) {
|
func wrapWithContext2[T any, A any, B any](fn func(A, B) (T, error)) func(context.Context, A, B) (T, error) {
|
||||||
return func(_ context.Context, a A, b B) (T, error) {
|
return func(ctx context.Context, a A, b B) (T, error) {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
var zero T
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
return fn(a, b)
|
return fn(a, b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,8 +54,7 @@ func TestRunnerFunctions(t *testing.T) {
|
|||||||
// Set up mock runner for testing
|
// Set up mock runner for testing
|
||||||
mockRunner := NewMockRunner()
|
mockRunner := NewMockRunner()
|
||||||
mockRunner.SetResponse("test-cmd arg1", []byte("test output"))
|
mockRunner.SetResponse("test-cmd arg1", []byte("test output"))
|
||||||
SetRunner(mockRunner)
|
defer WithTestRunner(t, mockRunner)()
|
||||||
defer SetRunner(&OSRunner{}) // Restore real runner
|
|
||||||
|
|
||||||
// Test RunnerCombinedOutput
|
// Test RunnerCombinedOutput
|
||||||
output, err := RunnerCombinedOutput("test-cmd", "arg1")
|
output, err := RunnerCombinedOutput("test-cmd", "arg1")
|
||||||
|
|||||||
@@ -52,6 +52,18 @@ func SetFilterDir(dir string) {
|
|||||||
// OSRunner runs commands locally.
|
// OSRunner runs commands locally.
|
||||||
type OSRunner struct{}
|
type OSRunner struct{}
|
||||||
|
|
||||||
|
// validateCommandExecution validates command name and arguments before execution.
|
||||||
|
// This helper consolidates the duplicate validation pattern used in command execution methods.
|
||||||
|
func validateCommandExecution(ctx context.Context, name string, args []string) error {
|
||||||
|
if err := CachedValidateCommand(ctx, name); err != nil {
|
||||||
|
return fmt.Errorf(shared.ErrCommandValidationFailed, err)
|
||||||
|
}
|
||||||
|
if err := ValidateArgumentsWithContext(ctx, args); err != nil {
|
||||||
|
return fmt.Errorf(shared.ErrArgumentValidationFailed, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// CombinedOutput executes a command without sudo.
|
// CombinedOutput executes a command without sudo.
|
||||||
func (r *OSRunner) CombinedOutput(name string, args ...string) ([]byte, error) {
|
func (r *OSRunner) CombinedOutput(name string, args ...string) ([]byte, error) {
|
||||||
return r.CombinedOutputWithContext(context.Background(), name, args...)
|
return r.CombinedOutputWithContext(context.Background(), name, args...)
|
||||||
@@ -59,13 +71,8 @@ func (r *OSRunner) CombinedOutput(name string, args ...string) ([]byte, error) {
|
|||||||
|
|
||||||
// CombinedOutputWithContext executes a command without sudo with context support.
|
// CombinedOutputWithContext executes a command without sudo with context support.
|
||||||
func (r *OSRunner) CombinedOutputWithContext(ctx context.Context, name string, args ...string) ([]byte, error) {
|
func (r *OSRunner) CombinedOutputWithContext(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||||
// Validate command for security
|
if err := validateCommandExecution(ctx, name, args); err != nil {
|
||||||
if err := CachedValidateCommand(ctx, name); err != nil {
|
return nil, err
|
||||||
return nil, fmt.Errorf(shared.ErrCommandValidationFailed, err)
|
|
||||||
}
|
|
||||||
// Validate arguments for security
|
|
||||||
if err := ValidateArgumentsWithContext(ctx, args); err != nil {
|
|
||||||
return nil, fmt.Errorf(shared.ErrArgumentValidationFailed, err)
|
|
||||||
}
|
}
|
||||||
return exec.CommandContext(ctx, name, args...).CombinedOutput()
|
return exec.CommandContext(ctx, name, args...).CombinedOutput()
|
||||||
}
|
}
|
||||||
@@ -77,13 +84,8 @@ func (r *OSRunner) CombinedOutputWithSudo(name string, args ...string) ([]byte,
|
|||||||
|
|
||||||
// CombinedOutputWithSudoContext executes a command with sudo if needed, with context support.
|
// CombinedOutputWithSudoContext executes a command with sudo if needed, with context support.
|
||||||
func (r *OSRunner) CombinedOutputWithSudoContext(ctx context.Context, name string, args ...string) ([]byte, error) {
|
func (r *OSRunner) CombinedOutputWithSudoContext(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||||
// Validate command for security
|
if err := validateCommandExecution(ctx, name, args); err != nil {
|
||||||
if err := CachedValidateCommand(ctx, name); err != nil {
|
return nil, err
|
||||||
return nil, fmt.Errorf(shared.ErrCommandValidationFailed, err)
|
|
||||||
}
|
|
||||||
// Validate arguments for security
|
|
||||||
if err := ValidateArgumentsWithContext(ctx, args); err != nil {
|
|
||||||
return nil, fmt.Errorf(shared.ErrArgumentValidationFailed, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
checker := GetSudoChecker()
|
checker := GetSudoChecker()
|
||||||
@@ -158,30 +160,38 @@ func RunnerCombinedOutputWithSudo(name string, args ...string) ([]byte, error) {
|
|||||||
return output, err
|
return output, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runWithTimerContext is a helper that consolidates the common pattern of
|
||||||
|
// creating a timer, getting the runner, executing a command, and finishing the timer.
|
||||||
|
// This reduces code duplication between RunnerCombinedOutputWithContext and RunnerCombinedOutputWithSudoContext.
|
||||||
|
func runWithTimerContext(
|
||||||
|
ctx context.Context,
|
||||||
|
opName, name string,
|
||||||
|
args []string,
|
||||||
|
runFn func(Runner, context.Context, string, ...string) ([]byte, error),
|
||||||
|
) ([]byte, error) {
|
||||||
|
timer := NewTimedOperation(opName, name, args...)
|
||||||
|
runner := GetRunner()
|
||||||
|
output, err := runFn(runner, ctx, name, args...)
|
||||||
|
timer.FinishWithContext(ctx, err)
|
||||||
|
return output, err
|
||||||
|
}
|
||||||
|
|
||||||
// RunnerCombinedOutputWithContext invokes the runner for a command with context support.
|
// RunnerCombinedOutputWithContext invokes the runner for a command with context support.
|
||||||
// RunnerCombinedOutputWithContext executes a command with context using the global runner.
|
// RunnerCombinedOutputWithContext executes a command with context using the global runner.
|
||||||
func RunnerCombinedOutputWithContext(ctx context.Context, name string, args ...string) ([]byte, error) {
|
func RunnerCombinedOutputWithContext(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||||
timer := NewTimedOperation("RunnerCombinedOutputWithContext", name, args...)
|
return runWithTimerContext(ctx, "RunnerCombinedOutputWithContext", name, args,
|
||||||
|
func(r Runner, c context.Context, n string, a ...string) ([]byte, error) {
|
||||||
runner := GetRunner()
|
return r.CombinedOutputWithContext(c, n, a...)
|
||||||
|
})
|
||||||
output, err := runner.CombinedOutputWithContext(ctx, name, args...)
|
|
||||||
timer.FinishWithContext(ctx, err)
|
|
||||||
|
|
||||||
return output, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunnerCombinedOutputWithSudoContext invokes the runner for a command with sudo and context support.
|
// RunnerCombinedOutputWithSudoContext invokes the runner for a command with sudo and context support.
|
||||||
// RunnerCombinedOutputWithSudoContext executes a command with sudo privileges and context using the global runner.
|
// RunnerCombinedOutputWithSudoContext executes a command with sudo privileges and context using the global runner.
|
||||||
func RunnerCombinedOutputWithSudoContext(ctx context.Context, name string, args ...string) ([]byte, error) {
|
func RunnerCombinedOutputWithSudoContext(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||||
timer := NewTimedOperation("RunnerCombinedOutputWithSudoContext", name, args...)
|
return runWithTimerContext(ctx, "RunnerCombinedOutputWithSudoContext", name, args,
|
||||||
|
func(r Runner, c context.Context, n string, a ...string) ([]byte, error) {
|
||||||
runner := GetRunner()
|
return r.CombinedOutputWithSudoContext(c, n, a...)
|
||||||
|
})
|
||||||
output, err := runner.CombinedOutputWithSudoContext(ctx, name, args...)
|
|
||||||
timer.FinishWithContext(ctx, err)
|
|
||||||
|
|
||||||
return output, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockRunner is a simple mock for Runner, used in unit tests.
|
// MockRunner is a simple mock for Runner, used in unit tests.
|
||||||
@@ -274,6 +284,17 @@ func (m *MockRunner) CombinedOutputWithSudo(name string, args ...string) ([]byte
|
|||||||
return m.CombinedOutput(name, args...)
|
return m.CombinedOutput(name, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// withContextCheck wraps an operation with context cancellation check.
|
||||||
|
// This helper consolidates the duplicate context cancellation pattern.
|
||||||
|
func withContextCheck(ctx context.Context, fn func() ([]byte, error)) ([]byte, error) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return fn()
|
||||||
|
}
|
||||||
|
|
||||||
// SetResponse sets a response for a command.
|
// SetResponse sets a response for a command.
|
||||||
func (m *MockRunner) SetResponse(cmd string, response []byte) {
|
func (m *MockRunner) SetResponse(cmd string, response []byte) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@@ -298,30 +319,50 @@ func (m *MockRunner) GetCalls() []string {
|
|||||||
return calls
|
return calls
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetupStandardResponses configures comprehensive standard responses for testing.
|
||||||
|
// This eliminates the need for repetitive SetResponse calls in individual tests.
|
||||||
|
func (m *MockRunner) SetupStandardResponses() {
|
||||||
|
StandardMockSetup(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupJailResponses configures responses for a specific jail.
|
||||||
|
// This is useful for tests that focus on a single jail's behavior.
|
||||||
|
func (m *MockRunner) SetupJailResponses(jail string) {
|
||||||
|
statusResponse := fmt.Sprintf("Status for the jail: %s\n|- Filter\n| |- Currently failed:\t0\n| "+
|
||||||
|
"|- Total failed:\t5\n| `- File list:\t/var/log/auth.log\n`- Actions\n "+
|
||||||
|
"|- Currently banned:\t1\n |- Total banned:\t2\n `- Banned IP list:\t192.168.1.100", jail)
|
||||||
|
|
||||||
|
m.SetResponse(fmt.Sprintf("fail2ban-client status %s", jail), []byte(statusResponse))
|
||||||
|
m.SetResponse(fmt.Sprintf("sudo fail2ban-client status %s", jail), []byte(statusResponse))
|
||||||
|
|
||||||
|
// Common ban/unban operations for the jail (use success status, not already-processed)
|
||||||
|
m.SetResponse(fmt.Sprintf("fail2ban-client set %s banip 192.168.1.100", jail), []byte(shared.Fail2BanStatusSuccess))
|
||||||
|
m.SetResponse(
|
||||||
|
fmt.Sprintf("sudo fail2ban-client set %s banip 192.168.1.100", jail),
|
||||||
|
[]byte(shared.Fail2BanStatusSuccess),
|
||||||
|
)
|
||||||
|
m.SetResponse(
|
||||||
|
fmt.Sprintf("fail2ban-client set %s unbanip 192.168.1.100", jail),
|
||||||
|
[]byte(shared.Fail2BanStatusSuccess),
|
||||||
|
)
|
||||||
|
m.SetResponse(
|
||||||
|
fmt.Sprintf("sudo fail2ban-client set %s unbanip 192.168.1.100", jail),
|
||||||
|
[]byte(shared.Fail2BanStatusSuccess),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// CombinedOutputWithContext returns a mocked response or error for a command with context support.
|
// CombinedOutputWithContext returns a mocked response or error for a command with context support.
|
||||||
func (m *MockRunner) CombinedOutputWithContext(ctx context.Context, name string, args ...string) ([]byte, error) {
|
func (m *MockRunner) CombinedOutputWithContext(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||||
// Check if context is canceled
|
return withContextCheck(ctx, func() ([]byte, error) {
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delegate to the non-context version for simplicity in tests
|
|
||||||
return m.CombinedOutput(name, args...)
|
return m.CombinedOutput(name, args...)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CombinedOutputWithSudoContext returns a mocked response for sudo commands with context support.
|
// CombinedOutputWithSudoContext returns a mocked response for sudo commands with context support.
|
||||||
func (m *MockRunner) CombinedOutputWithSudoContext(ctx context.Context, name string, args ...string) ([]byte, error) {
|
func (m *MockRunner) CombinedOutputWithSudoContext(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||||
// Check if context is canceled
|
return withContextCheck(ctx, func() ([]byte, error) {
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delegate to the non-context version for simplicity in tests
|
|
||||||
return m.CombinedOutputWithSudo(name, args...)
|
return m.CombinedOutputWithSudo(name, args...)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RealClient) fetchJailsWithContext(ctx context.Context) ([]string, error) {
|
func (c *RealClient) fetchJailsWithContext(ctx context.Context) ([]string, error) {
|
||||||
@@ -487,8 +528,12 @@ func (c *RealClient) StatusJailWithContext(ctx context.Context, jail string) (st
|
|||||||
return string(out), err
|
return string(out), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// BanIPWithContext bans an IP address in the specified jail with context support.
|
// executeIPActionWithContext executes a ban/unban IP action with validation and response parsing.
|
||||||
func (c *RealClient) BanIPWithContext(ctx context.Context, ip, jail string) (int, error) {
|
// It returns (0, nil) for success, (1, nil) if already processed, or an error.
|
||||||
|
func (c *RealClient) executeIPActionWithContext(
|
||||||
|
ctx context.Context,
|
||||||
|
ip, jail, action, errorTemplate string,
|
||||||
|
) (int, error) {
|
||||||
if err := CachedValidateIP(ctx, ip); err != nil {
|
if err := CachedValidateIP(ctx, ip); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -497,10 +542,9 @@ func (c *RealClient) BanIPWithContext(ctx context.Context, ip, jail string) (int
|
|||||||
}
|
}
|
||||||
|
|
||||||
currentRunner := GetRunner()
|
currentRunner := GetRunner()
|
||||||
|
out, err := currentRunner.CombinedOutputWithSudoContext(ctx, c.Path, shared.ActionSet, jail, action, ip)
|
||||||
out, err := currentRunner.CombinedOutputWithSudoContext(ctx, c.Path, shared.ActionSet, jail, shared.ActionBanIP, ip)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf(shared.ErrFailedToBanIP, ip, jail, err)
|
return 0, fmt.Errorf(errorTemplate, ip, jail, err)
|
||||||
}
|
}
|
||||||
code := strings.TrimSpace(string(out))
|
code := strings.TrimSpace(string(out))
|
||||||
if code == shared.Fail2BanStatusSuccess {
|
if code == shared.Fail2BanStatusSuccess {
|
||||||
@@ -512,36 +556,14 @@ func (c *RealClient) BanIPWithContext(ctx context.Context, ip, jail string) (int
|
|||||||
return 0, fmt.Errorf(shared.ErrUnexpectedOutput, code)
|
return 0, fmt.Errorf(shared.ErrUnexpectedOutput, code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BanIPWithContext bans an IP address in the specified jail with context support.
|
||||||
|
func (c *RealClient) BanIPWithContext(ctx context.Context, ip, jail string) (int, error) {
|
||||||
|
return c.executeIPActionWithContext(ctx, ip, jail, shared.ActionBanIP, shared.ErrFailedToBanIP)
|
||||||
|
}
|
||||||
|
|
||||||
// UnbanIPWithContext unbans an IP address from the specified jail with context support.
|
// UnbanIPWithContext unbans an IP address from the specified jail with context support.
|
||||||
func (c *RealClient) UnbanIPWithContext(ctx context.Context, ip, jail string) (int, error) {
|
func (c *RealClient) UnbanIPWithContext(ctx context.Context, ip, jail string) (int, error) {
|
||||||
if err := CachedValidateIP(ctx, ip); err != nil {
|
return c.executeIPActionWithContext(ctx, ip, jail, shared.ActionUnbanIP, shared.ErrFailedToUnbanIP)
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if err := CachedValidateJail(ctx, jail); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
currentRunner := GetRunner()
|
|
||||||
|
|
||||||
out, err := currentRunner.CombinedOutputWithSudoContext(
|
|
||||||
ctx,
|
|
||||||
c.Path,
|
|
||||||
shared.ActionSet,
|
|
||||||
jail,
|
|
||||||
shared.ActionUnbanIP,
|
|
||||||
ip,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf(shared.ErrFailedToUnbanIP, ip, jail, err)
|
|
||||||
}
|
|
||||||
code := strings.TrimSpace(string(out))
|
|
||||||
if code == shared.Fail2BanStatusSuccess {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
if code == shared.Fail2BanStatusAlreadyProcessed {
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
return 0, fmt.Errorf(shared.ErrUnexpectedOutput, code)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BannedInWithContext returns a list of jails where the specified IP address is currently banned with context support.
|
// BannedInWithContext returns a list of jails where the specified IP address is currently banned with context support.
|
||||||
|
|||||||
@@ -10,8 +10,7 @@ import (
|
|||||||
// TestRunnerConcurrentAccess tests that concurrent access to the runner
|
// TestRunnerConcurrentAccess tests that concurrent access to the runner
|
||||||
// is safe and doesn't cause race conditions.
|
// is safe and doesn't cause race conditions.
|
||||||
func TestRunnerConcurrentAccess(t *testing.T) {
|
func TestRunnerConcurrentAccess(t *testing.T) {
|
||||||
original := GetRunner()
|
defer WithTestRunner(t, GetRunner())()
|
||||||
defer SetRunner(original)
|
|
||||||
|
|
||||||
const numGoroutines = 100
|
const numGoroutines = 100
|
||||||
const numOperations = 50
|
const numOperations = 50
|
||||||
@@ -53,12 +52,9 @@ func TestRunnerConcurrentAccess(t *testing.T) {
|
|||||||
// TestRunnerCombinedOutputConcurrency tests that concurrent calls to
|
// TestRunnerCombinedOutputConcurrency tests that concurrent calls to
|
||||||
// RunnerCombinedOutput are safe.
|
// RunnerCombinedOutput are safe.
|
||||||
func TestRunnerCombinedOutputConcurrency(t *testing.T) {
|
func TestRunnerCombinedOutputConcurrency(t *testing.T) {
|
||||||
original := GetRunner()
|
|
||||||
defer SetRunner(original)
|
|
||||||
|
|
||||||
mockRunner := NewMockRunner()
|
mockRunner := NewMockRunner()
|
||||||
|
defer WithTestRunner(t, mockRunner)()
|
||||||
mockRunner.SetResponse("echo test", []byte("test output"))
|
mockRunner.SetResponse("echo test", []byte("test output"))
|
||||||
SetRunner(mockRunner)
|
|
||||||
|
|
||||||
const numGoroutines = 50
|
const numGoroutines = 50
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
@@ -120,12 +116,10 @@ func TestRunnerCombinedOutputWithSudoConcurrency(t *testing.T) {
|
|||||||
// TestMixedConcurrentOperations tests mixed concurrent operations including
|
// TestMixedConcurrentOperations tests mixed concurrent operations including
|
||||||
// setting runners and executing commands.
|
// setting runners and executing commands.
|
||||||
func TestMixedConcurrentOperations(t *testing.T) {
|
func TestMixedConcurrentOperations(t *testing.T) {
|
||||||
original := GetRunner()
|
|
||||||
defer SetRunner(original)
|
|
||||||
|
|
||||||
// Set up a single shared MockRunner with all required responses
|
// Set up a single shared MockRunner with all required responses
|
||||||
// This avoids race conditions from multiple goroutines setting different runners
|
// This avoids race conditions from multiple goroutines setting different runners
|
||||||
sharedMockRunner := NewMockRunner()
|
sharedMockRunner := NewMockRunner()
|
||||||
|
defer WithTestRunner(t, sharedMockRunner)()
|
||||||
|
|
||||||
// Set up responses for valid fail2ban commands to avoid validation errors
|
// Set up responses for valid fail2ban commands to avoid validation errors
|
||||||
sharedMockRunner.SetResponse("fail2ban-client status", []byte("Status: OK"))
|
sharedMockRunner.SetResponse("fail2ban-client status", []byte("Status: OK"))
|
||||||
@@ -135,8 +129,6 @@ func TestMixedConcurrentOperations(t *testing.T) {
|
|||||||
sharedMockRunner.SetResponse("sudo fail2ban-client status", []byte("Status: OK"))
|
sharedMockRunner.SetResponse("sudo fail2ban-client status", []byte("Status: OK"))
|
||||||
sharedMockRunner.SetResponse("sudo fail2ban-client -V", []byte("Version: 1.0.0"))
|
sharedMockRunner.SetResponse("sudo fail2ban-client -V", []byte("Version: 1.0.0"))
|
||||||
|
|
||||||
SetRunner(sharedMockRunner)
|
|
||||||
|
|
||||||
const numGoroutines = 30
|
const numGoroutines = 30
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
@@ -203,8 +195,7 @@ func TestMixedConcurrentOperations(t *testing.T) {
|
|||||||
// TestRunnerManagerLockOrdering verifies there are no deadlocks in the
|
// TestRunnerManagerLockOrdering verifies there are no deadlocks in the
|
||||||
// runner manager's lock ordering.
|
// runner manager's lock ordering.
|
||||||
func TestRunnerManagerLockOrdering(t *testing.T) {
|
func TestRunnerManagerLockOrdering(t *testing.T) {
|
||||||
original := GetRunner()
|
defer WithTestRunner(t, GetRunner())()
|
||||||
defer SetRunner(original)
|
|
||||||
|
|
||||||
// This test specifically looks for deadlocks by creating scenarios
|
// This test specifically looks for deadlocks by creating scenarios
|
||||||
// where multiple goroutines could potentially deadlock if locks
|
// where multiple goroutines could potentially deadlock if locks
|
||||||
@@ -245,13 +236,10 @@ func TestRunnerManagerLockOrdering(t *testing.T) {
|
|||||||
// TestRunnerStateConsistency verifies that the runner state remains
|
// TestRunnerStateConsistency verifies that the runner state remains
|
||||||
// consistent across concurrent operations.
|
// consistent across concurrent operations.
|
||||||
func TestRunnerStateConsistency(t *testing.T) {
|
func TestRunnerStateConsistency(t *testing.T) {
|
||||||
original := GetRunner()
|
|
||||||
defer SetRunner(original)
|
|
||||||
|
|
||||||
// Set initial state
|
// Set initial state
|
||||||
initialRunner := NewMockRunner()
|
initialRunner := NewMockRunner()
|
||||||
initialRunner.SetResponse("initial", []byte("initial response"))
|
initialRunner.SetResponse("initial", []byte("initial response"))
|
||||||
SetRunner(initialRunner)
|
defer WithTestRunner(t, initialRunner)()
|
||||||
|
|
||||||
const numReaders = 50
|
const numReaders = 50
|
||||||
const numWriters = 10
|
const numWriters = 10
|
||||||
|
|||||||
@@ -182,7 +182,10 @@ func TestBanIP(t *testing.T) {
|
|||||||
fmt.Errorf("command failed"),
|
fmt.Errorf("command failed"),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
mock.SetResponse(fmt.Sprintf("sudo fail2ban-client set %s banip %s", tt.jail, tt.ip), []byte(tt.mockResponse))
|
mock.SetResponse(
|
||||||
|
fmt.Sprintf("sudo fail2ban-client set %s banip %s", tt.jail, tt.ip),
|
||||||
|
[]byte(tt.mockResponse),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := NewClient(shared.DefaultLogDir, shared.DefaultFilterDir)
|
client, err := NewClient(shared.DefaultLogDir, shared.DefaultFilterDir)
|
||||||
|
|||||||
@@ -11,6 +11,32 @@ import (
|
|||||||
"github.com/ivuorinen/f2b/shared"
|
"github.com/ivuorinen/f2b/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// safeCloseFile closes a file and logs any error.
|
||||||
|
// This helper consolidates the duplicate close-and-log pattern.
|
||||||
|
func safeCloseFile(f *os.File, path string) {
|
||||||
|
if f == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if closeErr := f.Close(); closeErr != nil {
|
||||||
|
getLogger().WithError(closeErr).
|
||||||
|
WithField(shared.LogFieldFile, path).
|
||||||
|
Warn("Failed to close file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// safeCloseReader closes a reader and logs any error.
|
||||||
|
// This helper consolidates the duplicate close-and-log pattern for io.ReadCloser.
|
||||||
|
func safeCloseReader(r io.ReadCloser, path string) {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if closeErr := r.Close(); closeErr != nil {
|
||||||
|
getLogger().WithError(closeErr).
|
||||||
|
WithField(shared.LogFieldFile, path).
|
||||||
|
Warn("Failed to close reader")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GzipDetector provides utilities for detecting and handling gzip-compressed files
|
// GzipDetector provides utilities for detecting and handling gzip-compressed files
|
||||||
type GzipDetector struct{}
|
type GzipDetector struct{}
|
||||||
|
|
||||||
@@ -38,13 +64,7 @@ func (gd *GzipDetector) hasGzipMagicBytes(path string) (bool, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer safeCloseFile(f, path)
|
||||||
if closeErr := f.Close(); closeErr != nil {
|
|
||||||
getLogger().WithError(closeErr).
|
|
||||||
WithField(shared.LogFieldFile, path).
|
|
||||||
Warn("Failed to close file in gzip magic byte check")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var magic [2]byte
|
var magic [2]byte
|
||||||
n, err := f.Read(magic[:])
|
n, err := f.Read(magic[:])
|
||||||
@@ -70,11 +90,7 @@ func (gd *GzipDetector) OpenGzipAwareReader(path string) (io.ReadCloser, error)
|
|||||||
|
|
||||||
isGzip, err := gd.IsGzipFile(path)
|
isGzip, err := gd.IsGzipFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if closeErr := f.Close(); closeErr != nil {
|
safeCloseFile(f, path)
|
||||||
getLogger().WithError(closeErr).
|
|
||||||
WithField(shared.LogFieldFile, path).
|
|
||||||
Warn("Failed to close file during error handling")
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,21 +98,13 @@ func (gd *GzipDetector) OpenGzipAwareReader(path string) (io.ReadCloser, error)
|
|||||||
// For gzip files, we need to position at the beginning and create gzip reader
|
// For gzip files, we need to position at the beginning and create gzip reader
|
||||||
_, err = f.Seek(0, io.SeekStart)
|
_, err = f.Seek(0, io.SeekStart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if closeErr := f.Close(); closeErr != nil {
|
safeCloseFile(f, path)
|
||||||
getLogger().WithError(closeErr).
|
|
||||||
WithField(shared.LogFieldFile, path).
|
|
||||||
Warn("Failed to close file during seek error handling")
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
gz, err := gzip.NewReader(f)
|
gz, err := gzip.NewReader(f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if closeErr := f.Close(); closeErr != nil {
|
safeCloseFile(f, path)
|
||||||
getLogger().WithError(closeErr).
|
|
||||||
WithField(shared.LogFieldFile, path).
|
|
||||||
Warn("Failed to close file during gzip reader error handling")
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,11 +136,7 @@ func (gd *GzipDetector) CreateGzipAwareScannerWithBuffer(path string, maxLineSiz
|
|||||||
}
|
}
|
||||||
|
|
||||||
cleanup := func() {
|
cleanup := func() {
|
||||||
if err := reader.Close(); err != nil {
|
safeCloseReader(reader, path)
|
||||||
getLogger().WithError(err).
|
|
||||||
WithField(shared.LogFieldFile, path).
|
|
||||||
Warn("Failed to close reader during cleanup")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return scanner, cleanup, nil
|
return scanner, cleanup, nil
|
||||||
|
|||||||
@@ -792,20 +792,23 @@ func ValidateLogPath(ctx context.Context, path string, logDir string) (string, e
|
|||||||
return ValidatePathWithSecurity(path, config)
|
return ValidatePathWithSecurity(path, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateClientPath is a generic helper for client path validation.
|
||||||
|
// It reduces duplication between ValidateClientLogPath and ValidateClientFilterPath.
|
||||||
|
func validateClientPath(ctx context.Context, path string, configFn func() PathSecurityConfig) (string, error) {
|
||||||
|
_ = ctx // Context not currently used by ValidatePathWithSecurity
|
||||||
|
return ValidatePathWithSecurity(path, configFn())
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateClientLogPath validates log directory path for client initialization
|
// ValidateClientLogPath validates log directory path for client initialization
|
||||||
// Context parameter accepted for API consistency but not currently used
|
// Context parameter accepted for API consistency but not currently used
|
||||||
func ValidateClientLogPath(ctx context.Context, logDir string) (string, error) {
|
func ValidateClientLogPath(ctx context.Context, logDir string) (string, error) {
|
||||||
_ = ctx // Context not currently used by ValidatePathWithSecurity
|
return validateClientPath(ctx, logDir, CreateLogPathConfig)
|
||||||
config := CreateLogPathConfig()
|
|
||||||
return ValidatePathWithSecurity(logDir, config)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateClientFilterPath validates filter directory path for client initialization
|
// ValidateClientFilterPath validates filter directory path for client initialization
|
||||||
// Context parameter accepted for API consistency but not currently used
|
// Context parameter accepted for API consistency but not currently used
|
||||||
func ValidateClientFilterPath(ctx context.Context, filterDir string) (string, error) {
|
func ValidateClientFilterPath(ctx context.Context, filterDir string) (string, error) {
|
||||||
_ = ctx // Context not currently used by ValidatePathWithSecurity
|
return validateClientPath(ctx, filterDir, CreateFilterPathConfig)
|
||||||
config := CreateFilterPathConfig()
|
|
||||||
return ValidatePathWithSecurity(filterDir, config)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateFilterName validates a filter name for path traversal prevention.
|
// ValidateFilterName validates a filter name for path traversal prevention.
|
||||||
|
|||||||
@@ -75,18 +75,14 @@ func TestValidateFilterName(t *testing.T) {
|
|||||||
|
|
||||||
// TestGetLogLinesWrapper tests the GetLogLines wrapper function
|
// TestGetLogLinesWrapper tests the GetLogLines wrapper function
|
||||||
func TestGetLogLinesWrapper(t *testing.T) {
|
func TestGetLogLinesWrapper(t *testing.T) {
|
||||||
// Save and restore original runner
|
|
||||||
originalRunner := GetRunner()
|
|
||||||
defer SetRunner(originalRunner)
|
|
||||||
|
|
||||||
mockRunner := NewMockRunner()
|
mockRunner := NewMockRunner()
|
||||||
|
defer WithTestRunner(t, mockRunner)()
|
||||||
mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
||||||
mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong"))
|
mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong"))
|
mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong"))
|
||||||
mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
||||||
SetRunner(mockRunner)
|
|
||||||
|
|
||||||
// Create temporary log directory
|
// Create temporary log directory
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
@@ -107,9 +103,7 @@ func TestGetLogLinesWrapper(t *testing.T) {
|
|||||||
|
|
||||||
// TestBanIPWithContext tests the BanIPWithContext function
|
// TestBanIPWithContext tests the BanIPWithContext function
|
||||||
func TestBanIPWithContext(t *testing.T) {
|
func TestBanIPWithContext(t *testing.T) {
|
||||||
// Save and restore original runner
|
defer WithTestRunner(t, GetRunner())()
|
||||||
originalRunner := GetRunner()
|
|
||||||
defer SetRunner(originalRunner)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -160,18 +154,14 @@ func TestBanIPWithContext(t *testing.T) {
|
|||||||
|
|
||||||
// TestGetLogLinesWithLimitAndContext tests the GetLogLinesWithLimitAndContext function
|
// TestGetLogLinesWithLimitAndContext tests the GetLogLinesWithLimitAndContext function
|
||||||
func TestGetLogLinesWithLimitAndContext(t *testing.T) {
|
func TestGetLogLinesWithLimitAndContext(t *testing.T) {
|
||||||
// Save and restore original runner
|
|
||||||
originalRunner := GetRunner()
|
|
||||||
defer SetRunner(originalRunner)
|
|
||||||
|
|
||||||
mockRunner := NewMockRunner()
|
mockRunner := NewMockRunner()
|
||||||
|
defer WithTestRunner(t, mockRunner)()
|
||||||
mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
mockRunner.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
mockRunner.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0"))
|
||||||
mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong"))
|
mockRunner.SetResponse("fail2ban-client ping", []byte("Server replied: pong"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong"))
|
mockRunner.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong"))
|
||||||
mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
mockRunner.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
||||||
mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
mockRunner.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd"))
|
||||||
SetRunner(mockRunner)
|
|
||||||
|
|
||||||
// Create temporary log directory
|
// Create temporary log directory
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
|
|||||||
@@ -2,14 +2,47 @@ package fail2ban
|
|||||||
|
|
||||||
import "github.com/sirupsen/logrus"
|
import "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
// logrusAdapter wraps logrus to implement our decoupled LoggerInterface
|
// loggerCore provides common logging methods that delegate to a logrus.Entry.
|
||||||
type logrusAdapter struct {
|
// This type is embedded in both logrusAdapter and logrusEntryAdapter to
|
||||||
|
// eliminate duplicate method implementations.
|
||||||
|
type loggerCore struct {
|
||||||
entry *logrus.Entry
|
entry *logrus.Entry
|
||||||
}
|
}
|
||||||
|
|
||||||
// logrusEntryAdapter wraps logrus.Entry to implement LoggerEntry
|
// Debug logs a debug-level message.
|
||||||
|
func (c *loggerCore) Debug(args ...interface{}) { c.entry.Debug(args...) }
|
||||||
|
|
||||||
|
// Info logs an info-level message.
|
||||||
|
func (c *loggerCore) Info(args ...interface{}) { c.entry.Info(args...) }
|
||||||
|
|
||||||
|
// Warn logs a warning-level message.
|
||||||
|
func (c *loggerCore) Warn(args ...interface{}) { c.entry.Warn(args...) }
|
||||||
|
|
||||||
|
// Error logs an error-level message.
|
||||||
|
func (c *loggerCore) Error(args ...interface{}) { c.entry.Error(args...) }
|
||||||
|
|
||||||
|
// Debugf logs a formatted debug-level message.
|
||||||
|
func (c *loggerCore) Debugf(format string, args ...interface{}) { c.entry.Debugf(format, args...) }
|
||||||
|
|
||||||
|
// Infof logs a formatted info-level message.
|
||||||
|
func (c *loggerCore) Infof(format string, args ...interface{}) { c.entry.Infof(format, args...) }
|
||||||
|
|
||||||
|
// Warnf logs a formatted warning-level message.
|
||||||
|
func (c *loggerCore) Warnf(format string, args ...interface{}) { c.entry.Warnf(format, args...) }
|
||||||
|
|
||||||
|
// Errorf logs a formatted error-level message.
|
||||||
|
func (c *loggerCore) Errorf(format string, args ...interface{}) { c.entry.Errorf(format, args...) }
|
||||||
|
|
||||||
|
// logrusAdapter wraps logrus to implement our decoupled LoggerInterface.
|
||||||
|
// It embeds loggerCore to provide the 8 standard logging methods.
|
||||||
|
type logrusAdapter struct {
|
||||||
|
loggerCore // embeds Debug, Info, Warn, Error, Debugf, Infof, Warnf, Errorf
|
||||||
|
}
|
||||||
|
|
||||||
|
// logrusEntryAdapter wraps logrus.Entry to implement LoggerEntry.
|
||||||
|
// It embeds loggerCore to provide the 8 standard logging methods.
|
||||||
type logrusEntryAdapter struct {
|
type logrusEntryAdapter struct {
|
||||||
entry *logrus.Entry
|
loggerCore // embeds Debug, Info, Warn, Error, Debugf, Infof, Warnf, Errorf
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure logrusAdapter implements LoggerInterface
|
// Ensure logrusAdapter implements LoggerInterface
|
||||||
@@ -23,117 +56,37 @@ func NewLogrusAdapter(logger *logrus.Logger) LoggerInterface {
|
|||||||
if logger == nil {
|
if logger == nil {
|
||||||
logger = logrus.StandardLogger()
|
logger = logrus.StandardLogger()
|
||||||
}
|
}
|
||||||
return &logrusAdapter{entry: logrus.NewEntry(logger)}
|
return &logrusAdapter{loggerCore: loggerCore{entry: logrus.NewEntry(logger)}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithField implements LoggerInterface
|
// WithField implements LoggerInterface
|
||||||
func (l *logrusAdapter) WithField(key string, value interface{}) LoggerEntry {
|
func (l *logrusAdapter) WithField(key string, value interface{}) LoggerEntry {
|
||||||
return &logrusEntryAdapter{entry: l.entry.WithField(key, value)}
|
return &logrusEntryAdapter{loggerCore: loggerCore{entry: l.entry.WithField(key, value)}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithFields implements LoggerInterface
|
// WithFields implements LoggerInterface
|
||||||
func (l *logrusAdapter) WithFields(fields Fields) LoggerEntry {
|
func (l *logrusAdapter) WithFields(fields Fields) LoggerEntry {
|
||||||
return &logrusEntryAdapter{entry: l.entry.WithFields(logrus.Fields(fields))}
|
return &logrusEntryAdapter{loggerCore: loggerCore{entry: l.entry.WithFields(logrus.Fields(fields))}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithError implements LoggerInterface
|
// WithError implements LoggerInterface
|
||||||
func (l *logrusAdapter) WithError(err error) LoggerEntry {
|
func (l *logrusAdapter) WithError(err error) LoggerEntry {
|
||||||
return &logrusEntryAdapter{entry: l.entry.WithError(err)}
|
return &logrusEntryAdapter{loggerCore: loggerCore{entry: l.entry.WithError(err)}}
|
||||||
}
|
|
||||||
|
|
||||||
// Debug implements LoggerInterface
|
|
||||||
func (l *logrusAdapter) Debug(args ...interface{}) {
|
|
||||||
l.entry.Debug(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Info implements LoggerInterface
|
|
||||||
func (l *logrusAdapter) Info(args ...interface{}) {
|
|
||||||
l.entry.Info(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Warn implements LoggerInterface
|
|
||||||
func (l *logrusAdapter) Warn(args ...interface{}) {
|
|
||||||
l.entry.Warn(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error implements LoggerInterface
|
|
||||||
func (l *logrusAdapter) Error(args ...interface{}) {
|
|
||||||
l.entry.Error(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debugf implements LoggerInterface
|
|
||||||
func (l *logrusAdapter) Debugf(format string, args ...interface{}) {
|
|
||||||
l.entry.Debugf(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Infof implements LoggerInterface
|
|
||||||
func (l *logrusAdapter) Infof(format string, args ...interface{}) {
|
|
||||||
l.entry.Infof(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Warnf implements LoggerInterface
|
|
||||||
func (l *logrusAdapter) Warnf(format string, args ...interface{}) {
|
|
||||||
l.entry.Warnf(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Errorf implements LoggerInterface
|
|
||||||
func (l *logrusAdapter) Errorf(format string, args ...interface{}) {
|
|
||||||
l.entry.Errorf(format, args...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoggerEntry implementation for logrusEntryAdapter
|
// LoggerEntry implementation for logrusEntryAdapter
|
||||||
|
|
||||||
// WithField implements LoggerEntry
|
// WithField implements LoggerEntry
|
||||||
func (e *logrusEntryAdapter) WithField(key string, value interface{}) LoggerEntry {
|
func (e *logrusEntryAdapter) WithField(key string, value interface{}) LoggerEntry {
|
||||||
return &logrusEntryAdapter{entry: e.entry.WithField(key, value)}
|
return &logrusEntryAdapter{loggerCore: loggerCore{entry: e.entry.WithField(key, value)}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithFields implements LoggerEntry
|
// WithFields implements LoggerEntry
|
||||||
func (e *logrusEntryAdapter) WithFields(fields Fields) LoggerEntry {
|
func (e *logrusEntryAdapter) WithFields(fields Fields) LoggerEntry {
|
||||||
return &logrusEntryAdapter{entry: e.entry.WithFields(logrus.Fields(fields))}
|
return &logrusEntryAdapter{loggerCore: loggerCore{entry: e.entry.WithFields(logrus.Fields(fields))}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithError implements LoggerEntry
|
// WithError implements LoggerEntry
|
||||||
func (e *logrusEntryAdapter) WithError(err error) LoggerEntry {
|
func (e *logrusEntryAdapter) WithError(err error) LoggerEntry {
|
||||||
return &logrusEntryAdapter{entry: e.entry.WithError(err)}
|
return &logrusEntryAdapter{loggerCore: loggerCore{entry: e.entry.WithError(err)}}
|
||||||
}
|
|
||||||
|
|
||||||
// Debug implements LoggerEntry
|
|
||||||
func (e *logrusEntryAdapter) Debug(args ...interface{}) {
|
|
||||||
e.entry.Debug(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Info implements LoggerEntry
|
|
||||||
func (e *logrusEntryAdapter) Info(args ...interface{}) {
|
|
||||||
e.entry.Info(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Warn implements LoggerEntry
|
|
||||||
func (e *logrusEntryAdapter) Warn(args ...interface{}) {
|
|
||||||
e.entry.Warn(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error implements LoggerEntry
|
|
||||||
func (e *logrusEntryAdapter) Error(args ...interface{}) {
|
|
||||||
e.entry.Error(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debugf implements LoggerEntry
|
|
||||||
func (e *logrusEntryAdapter) Debugf(format string, args ...interface{}) {
|
|
||||||
e.entry.Debugf(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Infof implements LoggerEntry
|
|
||||||
func (e *logrusEntryAdapter) Infof(format string, args ...interface{}) {
|
|
||||||
e.entry.Infof(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Warnf implements LoggerEntry
|
|
||||||
func (e *logrusEntryAdapter) Warnf(format string, args ...interface{}) {
|
|
||||||
e.entry.Warnf(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Errorf implements LoggerEntry
|
|
||||||
func (e *logrusEntryAdapter) Errorf(format string, args ...interface{}) {
|
|
||||||
e.entry.Errorf(format, args...)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -377,11 +377,7 @@ func readLogFile(path string) ([]byte, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer safeCloseReader(reader, cleanPath)
|
||||||
if cerr := reader.Close(); cerr != nil {
|
|
||||||
getLogger().WithError(cerr).Error("failed to close log file")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return io.ReadAll(reader)
|
return io.ReadAll(reader)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,17 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// setNestedMapValue sets a value in a nested map[string]map[string]T structure with mutex protection.
|
||||||
|
// It initializes the inner map if nil.
|
||||||
|
func setNestedMapValue[T any](mu *sync.Mutex, mp map[string]map[string]T, jail, ip string, value T) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if mp[jail] == nil {
|
||||||
|
mp[jail] = make(map[string]T)
|
||||||
|
}
|
||||||
|
mp[jail][ip] = value
|
||||||
|
}
|
||||||
|
|
||||||
// MockClient is a stateful, thread-safe mock implementation of the Client interface for testing.
|
// MockClient is a stateful, thread-safe mock implementation of the Client interface for testing.
|
||||||
type MockClient struct {
|
type MockClient struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -286,42 +297,22 @@ func (m *MockClient) Reset() {
|
|||||||
|
|
||||||
// SetBanError configures an error to return for BanIP(ip, jail).
|
// SetBanError configures an error to return for BanIP(ip, jail).
|
||||||
func (m *MockClient) SetBanError(jail, ip string, err error) {
|
func (m *MockClient) SetBanError(jail, ip string, err error) {
|
||||||
m.mu.Lock()
|
setNestedMapValue(&m.mu, m.BanErrors, jail, ip, err)
|
||||||
defer m.mu.Unlock()
|
|
||||||
if m.BanErrors[jail] == nil {
|
|
||||||
m.BanErrors[jail] = make(map[string]error)
|
|
||||||
}
|
|
||||||
m.BanErrors[jail][ip] = err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetBanResult configures a result code to return for BanIP(ip, jail).
|
// SetBanResult configures a result code to return for BanIP(ip, jail).
|
||||||
func (m *MockClient) SetBanResult(jail, ip string, result int) {
|
func (m *MockClient) SetBanResult(jail, ip string, result int) {
|
||||||
m.mu.Lock()
|
setNestedMapValue(&m.mu, m.BanResults, jail, ip, result)
|
||||||
defer m.mu.Unlock()
|
|
||||||
if m.BanResults[jail] == nil {
|
|
||||||
m.BanResults[jail] = make(map[string]int)
|
|
||||||
}
|
|
||||||
m.BanResults[jail][ip] = result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUnbanError configures an error to return for UnbanIP(ip, jail).
|
// SetUnbanError configures an error to return for UnbanIP(ip, jail).
|
||||||
func (m *MockClient) SetUnbanError(jail, ip string, err error) {
|
func (m *MockClient) SetUnbanError(jail, ip string, err error) {
|
||||||
m.mu.Lock()
|
setNestedMapValue(&m.mu, m.UnbanErrors, jail, ip, err)
|
||||||
defer m.mu.Unlock()
|
|
||||||
if m.UnbanErrors[jail] == nil {
|
|
||||||
m.UnbanErrors[jail] = make(map[string]error)
|
|
||||||
}
|
|
||||||
m.UnbanErrors[jail][ip] = err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUnbanResult configures a result code to return for UnbanIP(ip, jail).
|
// SetUnbanResult configures a result code to return for UnbanIP(ip, jail).
|
||||||
func (m *MockClient) SetUnbanResult(jail, ip string, result int) {
|
func (m *MockClient) SetUnbanResult(jail, ip string, result int) {
|
||||||
m.mu.Lock()
|
setNestedMapValue(&m.mu, m.UnbanResults, jail, ip, result)
|
||||||
defer m.mu.Unlock()
|
|
||||||
if m.UnbanResults[jail] == nil {
|
|
||||||
m.UnbanResults[jail] = make(map[string]int)
|
|
||||||
}
|
|
||||||
m.UnbanResults[jail][ip] = result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetStatusJailData configures the status data for a specific jail.
|
// SetStatusJailData configures the status data for a specific jail.
|
||||||
|
|||||||
@@ -262,6 +262,24 @@ func assertContainsText(t *testing.T, lines []string, text string) {
|
|||||||
t.Errorf("Expected to find '%s' in results", text)
|
t.Errorf("Expected to find '%s' in results", text)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithTestRunner sets a test runner and returns a cleanup function.
|
||||||
|
// Usage: defer fail2ban.WithTestRunner(t, mockRunner)()
|
||||||
|
func WithTestRunner(t TestingInterface, runner Runner) func() {
|
||||||
|
t.Helper()
|
||||||
|
original := GetRunner()
|
||||||
|
SetRunner(runner)
|
||||||
|
return func() { SetRunner(original) }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTestSudoChecker sets a test sudo checker and returns a cleanup function.
|
||||||
|
// Usage: defer fail2ban.WithTestSudoChecker(t, mockChecker)()
|
||||||
|
func WithTestSudoChecker(t TestingInterface, checker SudoChecker) func() {
|
||||||
|
t.Helper()
|
||||||
|
original := GetSudoChecker()
|
||||||
|
SetSudoChecker(checker)
|
||||||
|
return func() { SetSudoChecker(original) }
|
||||||
|
}
|
||||||
|
|
||||||
// StandardMockSetup configures comprehensive standard responses for MockRunner
|
// StandardMockSetup configures comprehensive standard responses for MockRunner
|
||||||
// This eliminates the need for repetitive SetResponse calls in individual tests
|
// This eliminates the need for repetitive SetResponse calls in individual tests
|
||||||
func StandardMockSetup(mockRunner *MockRunner) {
|
func StandardMockSetup(mockRunner *MockRunner) {
|
||||||
|
|||||||
@@ -36,17 +36,7 @@ func NewTimeParsingCache(layout string) (*TimeParsingCache, error) {
|
|||||||
|
|
||||||
// ParseTime parses a time string with bounded caching for performance
|
// ParseTime parses a time string with bounded caching for performance
|
||||||
func (tpc *TimeParsingCache) ParseTime(timeStr string) (time.Time, error) {
|
func (tpc *TimeParsingCache) ParseTime(timeStr string) (time.Time, error) {
|
||||||
// Check cache first
|
return tpc.parseCache.ParseWithLayout(timeStr, tpc.layout)
|
||||||
if cached, ok := tpc.parseCache.Load(timeStr); ok {
|
|
||||||
return cached, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse and cache
|
|
||||||
t, err := time.Parse(tpc.layout, timeStr)
|
|
||||||
if err == nil {
|
|
||||||
tpc.parseCache.Store(timeStr, t)
|
|
||||||
}
|
|
||||||
return t, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildTimeString efficiently builds a time string from date and time components
|
// BuildTimeString efficiently builds a time string from date and time components
|
||||||
|
|||||||
@@ -33,19 +33,10 @@ type LoggerEntry interface {
|
|||||||
Errorf(format string, args ...interface{})
|
Errorf(format string, args ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoggerInterface defines the top-level logging interface (decoupled from logrus)
|
// LoggerInterface defines the top-level logging interface (decoupled from logrus).
|
||||||
|
// It embeds LoggerEntry since both interfaces share the same method signatures.
|
||||||
type LoggerInterface interface {
|
type LoggerInterface interface {
|
||||||
WithField(key string, value interface{}) LoggerEntry
|
LoggerEntry
|
||||||
WithFields(fields Fields) LoggerEntry
|
|
||||||
WithError(err error) LoggerEntry
|
|
||||||
Debug(args ...interface{})
|
|
||||||
Info(args ...interface{})
|
|
||||||
Warn(args ...interface{})
|
|
||||||
Error(args ...interface{})
|
|
||||||
Debugf(format string, args ...interface{})
|
|
||||||
Infof(format string, args ...interface{})
|
|
||||||
Warnf(format string, args ...interface{})
|
|
||||||
Errorf(format string, args ...interface{})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogCollectionConfig configures log line collection behavior
|
// LogCollectionConfig configures log line collection behavior
|
||||||
|
|||||||
Reference in New Issue
Block a user