mirror of
https://github.com/ivuorinen/f2b.git
synced 2026-03-16 04:01:12 +00:00
622 lines
17 KiB
Go
622 lines
17 KiB
Go
// Package cmd provides common helper functions and utilities for CLI commands.
|
|
// This package contains shared functionality used across multiple f2b commands,
|
|
// including argument validation, error handling, and output formatting helpers.
|
|
package cmd
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/ivuorinen/f2b/shared"
|
|
|
|
"github.com/spf13/cobra"
|
|
|
|
"github.com/ivuorinen/f2b/fail2ban"
|
|
)
|
|
|
|
// createTimeoutContext creates a context with the configured command timeout.
|
|
// This helper consolidates the duplicate timeout handling pattern.
|
|
// If base is nil, context.Background() is used.
|
|
func createTimeoutContext(base context.Context, config *Config) (context.Context, context.CancelFunc) {
|
|
if base == nil {
|
|
base = context.Background()
|
|
}
|
|
timeout := shared.DefaultCommandTimeout
|
|
if config != nil && config.CommandTimeout > 0 {
|
|
timeout = config.CommandTimeout
|
|
}
|
|
// #nosec G118 -- cancel is returned to callers who are responsible for calling it
|
|
return context.WithTimeout(base, timeout)
|
|
}
|
|
|
|
// IsCI detects if we're running in a CI environment
|
|
func IsCI() bool {
|
|
return fail2ban.IsCI()
|
|
}
|
|
|
|
// IsTestEnvironment detects if we're running in a test environment
|
|
func IsTestEnvironment() bool {
|
|
return fail2ban.IsTestEnvironment()
|
|
}
|
|
|
|
// Command creation helpers
|
|
|
|
// NewCommand creates a new cobra command with standard setup
|
|
func NewCommand(use, short string, aliases []string, runE func(*cobra.Command, []string) error) *cobra.Command {
|
|
return &cobra.Command{
|
|
Use: use,
|
|
Short: short,
|
|
Aliases: aliases,
|
|
RunE: runE,
|
|
}
|
|
}
|
|
|
|
// NewContextualCommand creates a command with standardized context and logging setup
|
|
func NewContextualCommand(
|
|
use, short string,
|
|
aliases []string,
|
|
config *Config,
|
|
handler func(context.Context, *cobra.Command, []string) error,
|
|
) *cobra.Command {
|
|
return NewCommand(use, short, aliases, func(cmd *cobra.Command, args []string) error {
|
|
// Get the contextual logger
|
|
logger := GetContextualLogger()
|
|
|
|
// Create timeout context based on Cobra's context so signals/cancellations propagate
|
|
ctx, cancel := createTimeoutContext(cmd.Context(), config)
|
|
defer cancel()
|
|
|
|
// Extract command name from use string (first word)
|
|
cmdName := use
|
|
if spaceIndex := strings.Index(use, " "); spaceIndex != -1 {
|
|
cmdName = use[:spaceIndex]
|
|
}
|
|
|
|
// Add command context
|
|
ctx = WithCommand(ctx, cmdName)
|
|
|
|
// Log operation with timing
|
|
return logger.LogOperation(ctx, cmdName+"_command", func() error {
|
|
return handler(ctx, cmd, args)
|
|
})
|
|
})
|
|
}
|
|
|
|
// AddLogFlags adds common log-related flags to a command
|
|
func AddLogFlags(cmd *cobra.Command) {
|
|
cmd.Flags().IntP(shared.FlagLimit, "n", 0, "Show only the last N log lines")
|
|
}
|
|
|
|
// IsSkipCommand reports whether command can run without a fail2ban client.
// Those commands are service, version, test-filter, completion and help. The
// comparison is exact and case-sensitive.
func IsSkipCommand(command string) bool {
	switch command {
	case "service", "version", "test-filter", "completion", "help":
		return true
	default:
		return false
	}
}
|
|
|
|
// AddWatchFlags adds common watch-related flags to a command
|
|
func AddWatchFlags(cmd *cobra.Command, interval *time.Duration) {
|
|
cmd.Flags().DurationVarP(interval, shared.FlagInterval, "i", shared.DefaultPollingInterval, "Polling interval")
|
|
}
|
|
|
|
// Validation helpers
|
|
|
|
// ValidateIPArgument validates that an IP address is provided in args
|
|
func ValidateIPArgument(args []string) (string, error) {
|
|
return ValidateIPArgumentWithContext(context.Background(), args)
|
|
}
|
|
|
|
// ValidateIPArgumentWithContext validates that an IP address is provided in args with context support
|
|
func ValidateIPArgumentWithContext(ctx context.Context, args []string) (string, error) {
|
|
if len(args) < 1 {
|
|
return "", fmt.Errorf("IP address required")
|
|
}
|
|
ip := args[0]
|
|
// Validate the IP address
|
|
if err := fail2ban.CachedValidateIP(ctx, ip); err != nil {
|
|
return "", err
|
|
}
|
|
return ip, nil
|
|
}
|
|
|
|
// ValidateServiceAction returns nil if action is a supported service action:
// start, stop, restart, status, reload, enable or disable. Otherwise it returns
// an error that lists the valid actions. Matching is exact and case-sensitive.
func ValidateServiceAction(action string) error {
	// A switch tests membership in this fixed set without building a map on every call.
	switch action {
	case "start", "stop", "restart", "status", "reload", "enable", "disable":
		return nil
	}
	return fmt.Errorf(
		"invalid service action: %s. Valid actions: start, stop, restart, status, reload, enable, disable",
		action,
	)
}
|
|
|
|
// GetJailsFromArgs gets jail list from arguments or client
|
|
func GetJailsFromArgs(client fail2ban.Client, args []string, startIndex int) ([]string, error) {
|
|
if len(args) > startIndex {
|
|
return []string{strings.ToLower(args[startIndex])}, nil
|
|
}
|
|
|
|
jails, err := client.ListJails()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return jails, nil
|
|
}
|
|
|
|
// GetJailsFromArgsWithContext gets jail list from arguments or client with timeout context
|
|
func GetJailsFromArgsWithContext(
|
|
ctx context.Context,
|
|
client fail2ban.Client,
|
|
args []string,
|
|
startIndex int,
|
|
) ([]string, error) {
|
|
if len(args) > startIndex {
|
|
return []string{strings.ToLower(args[startIndex])}, nil
|
|
}
|
|
|
|
jails, err := client.ListJailsWithContext(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return jails, nil
|
|
}
|
|
|
|
// ParseOptionalArgs returns a new slice of exactly count elements holding the
// first count entries of args.
//
// Positions beyond len(args) are left as "", so callers can index optional
// positional arguments without bounds checks. Extra args are ignored. The
// result never aliases args. A count of zero or less yields an empty slice;
// previously a negative count made make() panic.
func ParseOptionalArgs(args []string, count int) []string {
	if count <= 0 {
		return []string{}
	}
	result := make([]string, count)
	copy(result, args) // copies min(count, len(args)) elements
	return result
}
|
|
|
|
// Error handling helpers
|
|
|
|
// HandleClientError handles client errors with consistent formatting
|
|
func HandleClientError(err error) error {
|
|
if err != nil {
|
|
PrintError(err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type (
	// errorPatternMatch pairs a set of substrings with the remediation hint to
	// attach when any of them occurs in an error message.
	errorPatternMatch struct {
		// patterns are searched for in the lower-cased error text.
		patterns []string

		// remediation is the hint handed to the error constructor on a match.
		remediation string
	}

	// errorTypePattern pairs a set of substrings with the handler to delegate
	// to when any of them occurs in a non-contextual error's message.
	errorTypePattern struct {
		// patterns are searched for in the lower-cased error text.
		patterns []string

		// handler is the category-specific handler invoked on a match.
		handler func(error) error
	}
)
|
|
|
|
// errorTypePatterns defines patterns for inferring error types from non-contextual errors
|
|
var errorTypePatterns = []errorTypePattern{
|
|
{
|
|
patterns: []string{"invalid", "required", "malformed", "format"},
|
|
handler: HandleValidationError,
|
|
},
|
|
{
|
|
patterns: []string{"permission", "sudo", "unauthorized", "forbidden"},
|
|
handler: HandlePermissionError,
|
|
},
|
|
{
|
|
patterns: []string{"not found", "not running", "connection", "timeout"},
|
|
handler: HandleSystemError,
|
|
},
|
|
}
|
|
|
|
// handleCategorizedError is a shared helper for handling categorized errors with pattern matching
|
|
func handleCategorizedError(
|
|
err error,
|
|
category fail2ban.ErrorCategory,
|
|
patternMatches []errorPatternMatch,
|
|
createError func(error, string) error,
|
|
) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
// Check if it's already a contextual error of this category
|
|
var contextErr *fail2ban.ContextualError
|
|
if errors.As(err, &contextErr) && contextErr.GetCategory() == category {
|
|
PrintError(err)
|
|
return err
|
|
}
|
|
|
|
// Check for pattern matches
|
|
errMsg := strings.ToLower(err.Error())
|
|
for _, pm := range patternMatches {
|
|
for _, pattern := range pm.patterns {
|
|
if strings.Contains(errMsg, pattern) {
|
|
newErr := createError(err, pm.remediation)
|
|
PrintError(newErr)
|
|
return newErr
|
|
}
|
|
}
|
|
}
|
|
|
|
return HandleClientError(err)
|
|
}
|
|
|
|
// HandleValidationError specifically handles validation errors with clearer messaging
|
|
func HandleValidationError(err error) error {
|
|
return handleCategorizedError(
|
|
err,
|
|
fail2ban.ErrorCategoryValidation,
|
|
[]errorPatternMatch{
|
|
{
|
|
patterns: []string{"invalid", "required"},
|
|
remediation: "Check your input parameters and try again. Use --help for usage information.",
|
|
},
|
|
},
|
|
func(err error, remediation string) error {
|
|
return fail2ban.NewValidationError(err.Error(), remediation)
|
|
},
|
|
)
|
|
}
|
|
|
|
// HandlePermissionError specifically handles permission/sudo errors with helpful hints
|
|
func HandlePermissionError(err error) error {
|
|
return handleCategorizedError(
|
|
err,
|
|
fail2ban.ErrorCategoryPermission,
|
|
[]errorPatternMatch{
|
|
{
|
|
patterns: []string{"permission denied", "sudo"},
|
|
remediation: "Try running with sudo privileges or check that fail2ban service is running.",
|
|
},
|
|
},
|
|
func(err error, remediation string) error {
|
|
return fail2ban.NewPermissionError(err.Error(), remediation)
|
|
},
|
|
)
|
|
}
|
|
|
|
// HandleSystemError specifically handles system-level errors with diagnostic hints
|
|
func HandleSystemError(err error) error {
|
|
return handleCategorizedError(
|
|
err,
|
|
fail2ban.ErrorCategorySystem,
|
|
[]errorPatternMatch{
|
|
{
|
|
patterns: []string{"not found", "command not found"},
|
|
remediation: "Ensure fail2ban is installed and fail2ban-client is in your PATH.",
|
|
},
|
|
{
|
|
patterns: []string{"not running", "connection refused"},
|
|
remediation: "Start the fail2ban service: sudo systemctl start fail2ban",
|
|
},
|
|
},
|
|
func(err error, remediation string) error {
|
|
return fail2ban.NewSystemError(err.Error(), remediation, err)
|
|
},
|
|
)
|
|
}
|
|
|
|
// HandleErrorWithContext automatically chooses the appropriate error handler based on error context
|
|
func HandleErrorWithContext(err error) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
// Check if it's already a contextual error and route accordingly
|
|
var contextErr *fail2ban.ContextualError
|
|
if errors.As(err, &contextErr) {
|
|
switch contextErr.GetCategory() {
|
|
case fail2ban.ErrorCategoryValidation:
|
|
return HandleValidationError(err)
|
|
case fail2ban.ErrorCategoryPermission:
|
|
return HandlePermissionError(err)
|
|
case fail2ban.ErrorCategorySystem:
|
|
return HandleSystemError(err)
|
|
default:
|
|
return HandleClientError(err)
|
|
}
|
|
}
|
|
|
|
// For non-contextual errors, try to infer the type from patterns
|
|
errMsg := strings.ToLower(err.Error())
|
|
for _, ep := range errorTypePatterns {
|
|
for _, pattern := range ep.patterns {
|
|
if strings.Contains(errMsg, pattern) {
|
|
return ep.handler(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Default to generic client error handling
|
|
return HandleClientError(err)
|
|
}
|
|
|
|
// Output helpers
|
|
|
|
// OutputResults outputs results in the specified format
|
|
func OutputResults(cmd *cobra.Command, results interface{}, config *Config) {
|
|
if config != nil && config.Format == JSONFormat {
|
|
PrintOutputTo(GetCmdOutput(cmd), results, JSONFormat)
|
|
} else {
|
|
PrintOutputTo(GetCmdOutput(cmd), results, PlainFormat)
|
|
}
|
|
}
|
|
|
|
// InterpretBanStatus interprets ban operation status codes
|
|
func InterpretBanStatus(code int, operation string) string {
|
|
switch operation {
|
|
case shared.MetricsBan:
|
|
if code == 1 {
|
|
return "Already banned"
|
|
}
|
|
return "Banned"
|
|
case shared.MetricsUnban:
|
|
if code == 1 {
|
|
return "Already unbanned"
|
|
}
|
|
return "Unbanned"
|
|
default:
|
|
return "Unknown"
|
|
}
|
|
}
|
|
|
|
// Operation result types
|
|
|
|
// OperationResult records the outcome of a ban or unban on a single jail.
// The field order and JSON tags define the serialized form
// {"ip":...,"jail":...,"status":...}, so keep them stable.
type OperationResult struct {
	// IP is the address that was acted on.
	IP string `json:"ip"`

	// Jail is the jail the operation targeted.
	Jail string `json:"jail"`

	// Status is the human-readable outcome, as produced by InterpretBanStatus.
	Status string `json:"status"`
}
|
|
|
|
// OperationType defines a ban or unban operation with its associated metadata
|
|
type OperationType struct {
|
|
// MetricsType is the metrics key for this operation (e.g., shared.MetricsBan)
|
|
MetricsType string
|
|
// Message is the log message for this operation (e.g., shared.MsgBanResult)
|
|
Message string
|
|
// Operation is the function to execute without context
|
|
Operation func(client fail2ban.Client, ip, jail string) (int, error)
|
|
// OperationCtx is the function to execute with context
|
|
OperationCtx func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error)
|
|
}
|
|
|
|
// BanOperationType defines the ban operation
|
|
var BanOperationType = OperationType{
|
|
MetricsType: shared.MetricsBan,
|
|
Message: shared.MsgBanResult,
|
|
Operation: func(c fail2ban.Client, ip, jail string) (int, error) {
|
|
return c.BanIP(ip, jail)
|
|
},
|
|
OperationCtx: func(ctx context.Context, c fail2ban.Client, ip, jail string) (int, error) {
|
|
return c.BanIPWithContext(ctx, ip, jail)
|
|
},
|
|
}
|
|
|
|
// UnbanOperationType defines the unban operation
|
|
var UnbanOperationType = OperationType{
|
|
MetricsType: shared.MetricsUnban,
|
|
Message: shared.MsgUnbanResult,
|
|
Operation: func(c fail2ban.Client, ip, jail string) (int, error) {
|
|
return c.UnbanIP(ip, jail)
|
|
},
|
|
OperationCtx: func(ctx context.Context, c fail2ban.Client, ip, jail string) (int, error) {
|
|
return c.UnbanIPWithContext(ctx, ip, jail)
|
|
},
|
|
}
|
|
|
|
// ProcessOperation processes operations across multiple jails using the specified operation type
|
|
func ProcessOperation(
|
|
client fail2ban.Client,
|
|
ip string,
|
|
jails []string,
|
|
opType OperationType,
|
|
) ([]OperationResult, error) {
|
|
results := make([]OperationResult, 0, len(jails))
|
|
|
|
for _, jail := range jails {
|
|
code, err := opType.Operation(client, ip, jail)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
status := InterpretBanStatus(code, opType.MetricsType)
|
|
Logger.WithFields(map[string]interface{}{
|
|
"ip": ip,
|
|
"jail": jail,
|
|
"status": status,
|
|
}).Info(opType.Message)
|
|
|
|
results = append(results, OperationResult{
|
|
IP: ip,
|
|
Jail: jail,
|
|
Status: status,
|
|
})
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// ProcessOperationWithContext processes operations across multiple jails with timeout context
|
|
func ProcessOperationWithContext(
|
|
ctx context.Context,
|
|
client fail2ban.Client,
|
|
ip string,
|
|
jails []string,
|
|
opType OperationType,
|
|
) ([]OperationResult, error) {
|
|
logger := GetContextualLogger()
|
|
results := make([]OperationResult, 0, len(jails))
|
|
|
|
for _, jail := range jails {
|
|
// Add jail to context for this operation
|
|
jailCtx := WithJail(ctx, jail)
|
|
|
|
// Time the operation
|
|
start := time.Now()
|
|
code, err := opType.OperationCtx(jailCtx, client, ip, jail)
|
|
duration := time.Since(start)
|
|
|
|
if err != nil {
|
|
// Log the failed operation with timing
|
|
logger.LogBanOperation(jailCtx, opType.MetricsType, ip, jail, false, duration)
|
|
return nil, err
|
|
}
|
|
|
|
status := InterpretBanStatus(code, opType.MetricsType)
|
|
|
|
// Log the successful operation with timing
|
|
logger.LogBanOperation(jailCtx, opType.MetricsType, ip, jail, true, duration)
|
|
|
|
// Log the operation-specific message (ban vs unban)
|
|
Logger.WithFields(map[string]interface{}{
|
|
"ip": ip,
|
|
"jail": jail,
|
|
"status": status,
|
|
}).Info(opType.Message)
|
|
|
|
results = append(results, OperationResult{
|
|
IP: ip,
|
|
Jail: jail,
|
|
Status: status,
|
|
})
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// ProcessBanOperation processes ban operations across multiple jails
|
|
func ProcessBanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) {
|
|
return ProcessOperation(client, ip, jails, BanOperationType)
|
|
}
|
|
|
|
// ProcessBanOperationWithContext processes ban operations across multiple jails with timeout context
|
|
func ProcessBanOperationWithContext(
|
|
ctx context.Context,
|
|
client fail2ban.Client,
|
|
ip string,
|
|
jails []string,
|
|
) ([]OperationResult, error) {
|
|
return ProcessOperationWithContext(ctx, client, ip, jails, BanOperationType)
|
|
}
|
|
|
|
// ProcessUnbanOperation processes unban operations across multiple jails
|
|
func ProcessUnbanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) {
|
|
return ProcessOperation(client, ip, jails, UnbanOperationType)
|
|
}
|
|
|
|
// ProcessUnbanOperationWithContext processes unban operations across multiple jails with timeout context
|
|
func ProcessUnbanOperationWithContext(
|
|
ctx context.Context,
|
|
client fail2ban.Client,
|
|
ip string,
|
|
jails []string,
|
|
) ([]OperationResult, error) {
|
|
return ProcessOperationWithContext(ctx, client, ip, jails, UnbanOperationType)
|
|
}
|
|
|
|
// Argument validation helpers
|
|
|
|
// RequireArguments returns nil when args holds at least n elements. Otherwise it
// returns an error whose text is exactly errorMsg.
func RequireArguments(args []string, n int, errorMsg string) error {
	if len(args) >= n {
		return nil
	}
	return errors.New(errorMsg)
}
|
|
|
|
// RequireNonEmptyArgument checks that an argument is not empty
|
|
func RequireNonEmptyArgument(arg, name string) error {
|
|
if IsEmptyString(arg) {
|
|
return fmt.Errorf("%s cannot be empty", name)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Status output helpers
|
|
|
|
// FormatBannedResult renders a one-line summary of where ip is banned.
// An empty or nil jails slice yields "IP <ip> is not banned". Otherwise the
// jails are listed in fmt's %v slice form, e.g. "[sshd nginx]".
func FormatBannedResult(ip string, jails []string) string {
	if len(jails) > 0 {
		return fmt.Sprintf("IP %s is banned in: %v", ip, jails)
	}
	return fmt.Sprintf("IP %s is not banned", ip)
}
|
|
|
|
// FormatStatusResult returns status unchanged when jail is "". Otherwise it
// prefixes status with a "Status for <jail>:" header line.
func FormatStatusResult(jail, status string) string {
	if jail != "" {
		return fmt.Sprintf("Status for %s:\n%s", jail, status)
	}
	return status
}
|
|
|
|
// String processing helpers
|
|
|
|
// TrimmedString returns s with leading and trailing Unicode whitespace removed
// (strings.TrimSpace). An empty or all-whitespace s yields "".
func TrimmedString(s string) string {
	trimmed := strings.TrimSpace(s)
	return trimmed
}
|
|
|
|
// IsEmptyString reports whether s is empty or consists only of whitespace.
func IsEmptyString(s string) bool {
	return len(strings.TrimSpace(s)) == 0
}
|
|
|
|
// NonEmptyString reports whether s has any non-whitespace content. It is the
// negation of IsEmptyString.
func NonEmptyString(s string) bool {
	return len(strings.TrimSpace(s)) > 0
}
|
|
|
|
// Error handling helpers
|
|
|
|
// WrapError prefixes err with "<operation> failed: " and wraps it with %w, so
// err stays reachable through errors.Is and errors.As. A nil err yields nil.
func WrapError(err error, operation string) error {
	if err != nil {
		return fmt.Errorf("%s failed: %w", operation, err)
	}
	return nil
}
|
|
|
|
// WrapErrorf wraps err with a caller-formatted prefix. The prefix is format
// applied to args, followed by ": " and err's message. err is wrapped with %w,
// so it stays reachable through errors.Is and errors.As. A nil err yields nil.
//
// args is never modified. Previously, append(args, err) could write err into the
// backing array of a slice the caller passed with spare capacity (f(e, fmt, s...)).
func WrapErrorf(err error, format string, args ...interface{}) error {
	if err == nil {
		return nil
	}
	// Copy into a fresh slice so the caller's variadic slice is never aliased.
	allArgs := make([]interface{}, 0, len(args)+1)
	allArgs = append(allArgs, args...)
	allArgs = append(allArgs, err)
	// A single fmt.Errorf call both formats the prefix and wraps err.
	return fmt.Errorf(format+": %w", allArgs...)
}
|
|
|
|
// Command output helpers
|
|
|
|
// TrimmedOutput converts raw command output to a string and strips leading and
// trailing whitespace, such as the trailing newline many commands print.
// nil or empty output yields "".
func TrimmedOutput(output []byte) string {
	text := string(output)
	return strings.TrimSpace(text)
}
|