mirror of
https://github.com/ivuorinen/f2b.git
synced 2026-01-26 03:13:58 +00:00
feat!: Go rewrite (#9)
* Go rewrite * chore(cr): apply suggestions Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Ismo Vuorinen <ismo@ivuorinen.net> * 📝 CodeRabbit Chat: Add NoOpClient to fail2ban and initialize when skip flag is true * 📝 CodeRabbit Chat: Fix malformed if-else structure and add no-op client for skip-only commands * 📝 CodeRabbit Chat: Fix malformed if-else structure and add no-op client for skip-only commands * fix(main): correct no-op branch syntax (#10) * chore(gitignore): ignore env and binary files (#11) * chore(config): remove indent_size for go files (#12) * feat(cli): inject version via ldflags (#13) * fix(security): validate filter parameter to prevent path traversal (#15) * chore(repo): anchor ignore for build artifacts (#16) * chore(ci): use golangci-lint action (#17) * feat(fail2ban): expose GetLogDir (#19) * test(cmd): improve IP mock validation (#20) * chore(ci): update golanglint * fix(ci): golanglint * fix(ci): correct args indentation in pr-lint workflow (#21) * fix(ci): avoid duplicate releases (#22) * refactor(fail2ban): remove test check from OSRunner (#23) * refactor(fail2ban): make log and filter dirs configurable (#24) * fix(ci): create single release per tag (#14) Signed-off-by: Ismo Vuorinen <ismo@ivuorinen.net> * chore(dev): add codex setup script (#27) * chore(lint): enable staticcheck (#26) * chore(ci): verify golangci config (#28) * refactor(cmd): centralize env config (#29) * chore(dev): add pre-commit config (#30) * fix(ci): disable cgo in cross compile (#31) * fix(ci): fail on formatting issues (#32) * feat(cmd): add context to logs watch (#33) * chore: fixes, roadmap, claude.md, linting * chore: fixes, linting * fix(ci): gh actions update, fixes and tweaks * chore: use reviewdog actionlint * chore: use wow-rp-addons/actions-editorconfig-check * chore: combine agent instructions, add comments, fixes * chore: linting, fixes, go revive * chore(deps): update pre-commit hooks * chore: bump go to 1.21, pin workflows * fix: install tools in lint.yml * fix: sudo timeout * fix: service command injection * fix: memory exhaustion with large logs * fix: enhanced path traversal and file security vulns * fix: race conditions * fix: context support * chore: simplify fail2ban/ code * feat: major refactoring with GoReleaser integration and code consolidation - Add GoReleaser configuration for automated multi-platform releases - Support for Linux, macOS, Windows, and BSD builds - Docker images, Homebrew tap, and Linux packages (.deb, .rpm, .apk) - GitHub Actions workflow for release automation - Consolidate duplicate code and improve architecture - Extract common command helpers to cmd/helpers.go (~230 lines) - Remove duplicate MockClient implementation from tests (~250 lines) - Create context wrapper helpers in fail2ban/context_helpers.go - Standardize error messages in fail2ban/errors.go - Enhance validation and security - Add proper IP address validation with fail2ban.ValidateIP - Fix path traversal and command injection vulnerabilities - Improve thread-safety in MockClient with consistent ordering - Optimize documentation - Reduce CLAUDE.md from 190 to 81 lines (57% reduction) - Reduce TODO.md from 633 to 93 lines (85% reduction) - Move README.md to root directory with installation instructions - Improve test reliability - Fix race conditions and test flakiness - Add sorting to ensure deterministic test output - Enhance MockClient with configurable behavior * feat: comprehensive code quality improvements and documentation reorganization This commit represents a major overhaul of code quality, documentation structure, and development tooling: **Documentation & Structure:** - Move CODE_OF_CONDUCT.md from .github to root directory - Reorganize documentation with dedicated docs/ directory - Create comprehensive architecture, security, and testing documentation - Update all references and cross-links for new documentation structure **Code Quality & Linting:** - Add 120-character line length limit across all files via EditorConfig - Enable comprehensive linting with golines, lll, usetesting, gosec, and revive - Fix all 86 revive linter issues (unused parameters, missing export comments) - Resolve security issues (file permissions 0644 → 0600, gosec warnings) - Replace deprecated os.Setenv with t.Setenv in all tests - Configure golangci-lint with auto-fix capabilities and formatter integration **Development Tooling:** - Enhance pre-commit configuration with additional hooks and formatters - Update GoReleaser configuration with improved YAML formatting - Improve GitHub workflows and issue templates for CLI-specific context - Add comprehensive Makefile with proper dependency checking **Testing & Security:** - Standardize mock patterns and context wrapper implementations - Enhance error handling with centralized error constants - Improve concurrent access testing for thread safety * perf: implement major performance optimizations with comprehensive test coverage This commit introduces three significant performance improvements along with complete linting compliance and robust test coverage: **Performance Optimizations:** 1. **Time Parsing Cache (8.6x improvement)** - Add TimeParsingCache with sync.Map for caching parsed times - Implement object pooling for string builders to reduce allocations - Create optimized BanRecordParser with pooled string slices 2. **Gzip Detection Consolidation (55x improvement)** - Consolidate ~100 lines of duplicate gzip detection logic - Fast-path extension checking before magic byte detection - Unified GzipDetector with comprehensive file handling utilities 3. **Parallel Processing (2.5-5.0x improvement)** - Generic WorkerPool implementation for concurrent operations - Smart fallback to sequential processing for single operations - Context-aware cancellation support for long-running tasks - Applied to ban/unban operations across multiple jails **New Files Added:** - fail2ban/time_parser.go: Cached time parsing with global instances - fail2ban/ban_record_parser.go: Optimized ban record parsing - fail2ban/gzip_detection.go: Unified gzip handling utilities - fail2ban/parallel_processing.go: Generic parallel processing framework - cmd/parallel_operations.go: Command-level parallel operation support **Code Quality & Linting:** - Resolve all golangci-lint issues (0 remaining) - Add proper #nosec annotations for legitimate file operations - Implement sentinel errors replacing nil/nil anti-pattern - Fix context parameter handling and error checking **Comprehensive Test Coverage:** - 500+ lines of new tests with benchmarks validating all improvements - Concurrent access testing for thread safety - Edge case handling and error condition testing - Performance benchmarks demonstrating measured improvements **Modified Files:** - fail2ban/fail2ban.go: Integration with new optimized parsers - fail2ban/logs.go: Use consolidated gzip detection (-91 lines) - cmd/ban.go & cmd/unban.go: Add conditional parallel processing * test: comprehensive test infrastructure overhaul with real test data Major improvements to test code quality and organization: • Added comprehensive test data infrastructure with 6 anonymized log files • Extracted common test helpers reducing ~200 lines to ~50 reusable functions • Enhanced ban record parser tests with real production log patterns • Improved gzip detection tests with actual compressed test data • Added integration tests for full log processing and concurrent operations • Updated .gitignore to allow testdata log files while excluding others • Updated TODO.md to reflect completed test infrastructure improvements * fix: comprehensive security hardening and critical bug fixes Security Enhancements: - Add command injection protection with allowlist validation for all external commands - Add security documentation to gzip functions warning about path traversal risks - Complete TODO.md security audit - all critical vulnerabilities addressed Bug Fixes: - Fix negative index access vulnerability in parallel operations (prevent panic) - Fix parsing inconsistency between BannedIn and BannedInWithContext functions - Fix nil error handling in concurrent log reading tests - Fix benchmark error simulation to measure actual performance vs error paths Implementation Details: - Add ValidateCommand() with allowlist for fail2ban-client, fail2ban-regex, service, systemctl, sudo - Integrate command validation into all OSRunner methods before execution - Replace manual string parsing with ParseBracketedList() for consistency - Add bounds checking (index >= 0) to prevent negative array access - Replace nil error with descriptive error message in concurrent error channels - Update banFunc in benchmark to return success instead of permanent errors Test Coverage: - Add comprehensive security validation tests with injection attempt patterns - Add parallel operations safety tests with index validation - Add parsing consistency tests between context/non-context functions - Add error handling demonstration tests for concurrent operations - Add gzip function security requirement documentation tests * perf: implement ultra-optimized log and ban record parsing with significant performance gains Major performance improvements to core fail2ban processing with comprehensive benchmarking: Performance Achievements: • Ban record parsing: 15% faster, 39% less memory, 45% fewer allocations • Log processing: 27% faster, 64% less memory, 32% fewer allocations • Cache performance: 624x faster cache hits with zero allocations • String pooling: 4.7x improvement with zero memory allocations Core Optimizations: • Object pooling (sync.Pool) for string slices, scanner buffers, and line buffers • Comprehensive caching (sync.Map) for gzip detection, file info, and path patterns • Fast path optimizations with extension-based gzip detection • Byte-level operations to reduce string allocations in filtering • Ultra-optimized parsers with smart field parsing and efficient time handling New Files: • fail2ban/ban_record_parser_optimized.go - High-performance ban record parser • fail2ban/log_performance_optimized.go - Ultra-optimized log processor with caching • fail2ban/ban_record_parser_benchmark_test.go - Ban record parsing benchmarks • fail2ban/log_performance_benchmark_test.go - Log performance benchmarks • fail2ban/ban_record_parser_compatibility_test.go - Compatibility verification tests Updated: • fail2ban/fail2ban.go - Integration with ultra-optimized parsers • TODO.md - Marked performance optimization tasks as completed * fix(ci): install dev dependencies for pre-commit * refactor: streamline pre-commit config and extract test helpers - Replace local hooks with upstream pre-commit repositories for better maintainability - Add new hooks: shellcheck, shfmt, checkov for enhanced code quality - Extract common test helpers into dedicated test_helpers.go to reduce duplication - Add warning logs for unreadable log files in fail2ban and logs packages - Remove hard-coded GID checks in sudo.go for better cross-platform portability - Update golangci-lint installation method in Makefile * fix(security): path traversal, log file validation * feat: complete pre-release modernization with comprehensive testing - Remove all deprecated legacy functions and dead code paths - Add security hardening with sanitized error messages - Implement comprehensive performance benchmarks and security audit tests - Mark all pre-release modernization tasks as completed (10/10) - Update project documentation to reflect full completion status * fix(ci): linting, and update gosec install source * feat: implement comprehensive test framework with 60-70% code reduction Major test infrastructure modernization: - Create fluent CommandTestBuilder framework for streamlined test creation - Add MockClientBuilder pattern for advanced mock configuration - Standardize table test field naming (expectedOut→wantOutput, expectError→wantError) - Consolidate test code: 3,796 insertions, 3,104 deletions (net +692 lines with enhanced functionality) Framework achievements: - 168+ tests passing with zero regressions - 5 cmd test files fully migrated to new framework - 63 field name standardizations applied - Advanced mock patterns with fluent interface File organization improvements: - Rename all test files with consistent prefixes (cmd_*, fail2ban_*, main_*) - Split monolithic test files into focused, maintainable modules - Eliminate cmd_test.go (622 lines) and main_test.go (825 lines) - Create specialized test files for better organization Documentation enhancements: - Update docs/testing.md with complete framework documentation - Optimize TODO.md from 231→72 lines (69% token reduction) - Add comprehensive migration guides and best practices Test framework components: - command_test_framework.go: Core fluent interface implementation - MockClientBuilder: Advanced mock configuration with builder pattern - table_test_standards.go: Standardized field naming conventions - Enhanced test helpers with error checking consolidation * chore: fixes, .go-version, linting * fix(ci) editorconfig in .pre-commit-config.yaml * fix: too broad gitignore * chore: update fail2ban/fail2ban_path_security_test.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Ismo Vuorinen <ismo@ivuorinen.net> * chore: code review fixes * chore: code review fixes * fix: more code review fixes * fix: more code review fixes * feat: cleanup, fixes, testing * chore: minor config file updates - Add quotes to F2B_TIMEOUT value in .env.example for clarity - Remove testdata log exception from .gitignore (simplified) * feat: implement comprehensive monitoring with structured logging and metrics - Add structured logging with context propagation throughout codebase - Implement ContextualLogger with request tracking and operation timing - Add context values for operation, IP, jail, command, and request ID - Integrate with existing logrus logging infrastructure - Add request/response timing metrics collection - Create comprehensive Metrics system with atomic counters - Track command executions, ban/unban operations, and client operations - Implement latency distribution buckets for performance analysis - Add validation cache hit/miss tracking - Enhance ban/unban commands with structured logging - Add LogOperation wrapper for automatic timing and context - Log individual jail operations with success/failure status - Integrate metrics recording with ban/unban operations - Add new 'metrics' command to expose collected metrics - Support both plain text and JSON output formats - Display system metrics (uptime, memory, goroutines) - Show operation counts, failures, and average latencies - Include latency distribution histograms - Update test infrastructure - Add tests for metrics command - Fix test helper to support persistent flags - Ensure all tests pass with new logging This completes the high-priority performance monitoring and structured logging requirements from TODO.md, providing comprehensive operational visibility into the f2b application. * docs: update TODO.md to reflect completed monitoring work - Mark structured logging and timing metrics as completed - Update test coverage stats (cmd/ improved from 66.4% to 76.8%) - Add completed infrastructure section for today's work - Update current status date and add monitoring to health indicators * feat: complete TODO.md technical debt cleanup Complete all remaining TODO.md tasks with comprehensive implementation: ## 🎯 Validation Caching Implementation - Thread-safe validation cache with sync.RWMutex protection - MetricsRecorder interface to avoid circular dependencies - Cached validation for IP, jail, filter, and command validation - Integration with existing metrics system for cache hit/miss tracking - 100% test coverage for caching functionality ## 🔧 Constants Extraction - Fail2Ban status codes: Fail2BanStatusSuccess, Fail2BanStatusAlreadyProcessed - Command constants: Fail2BanClientCommand, Fail2BanRegexCommand, Fail2BanServerCommand - File permissions: DefaultFilePermissions (0600), DefaultDirectoryPermissions (0750) - Timeout limits: MaxCommandTimeout, MaxFileTimeout, MaxParallelTimeout - Updated all references throughout codebase to use named constants ## 📊 Test Coverage Improvement - Increased fail2ban package coverage from 62.0% to 70.3% (target: 70%+) - Added 6 new comprehensive test files with 200+ additional test cases - Coverage improvements across all major components: - Context helpers, validation cache, mock clients, OS runner methods - Error constructors, timing operations, cache statistics - Thread safety and concurrency testing ## 🛠️ Code Quality & Fixes - Fixed all linting issues (golangci-lint, revive, errcheck) - Resolved unused parameter warnings and error handling - Fixed timing-dependent test failures in worker pool cancellation - Enhanced thread safety in validation caching ## 📈 Final Metrics - Overall test coverage: 72.4% (up from ~65%) - fail2ban package: 70.3% (exceeds 70% target) - cmd package: 76.9% - Zero TODO/FIXME/HACK comments in production code - 100% linting compliance * fix: resolve test framework issues and update documentation - Remove unnecessary defer/recover block in comprehensive_framework_test.go - Fix compilation error in command_test_framework.go variable redeclaration - Update TODO.md to reflect all 12 completed code quality fixes - Clean up dead code and improve test maintainability - Fix linting issues: error handling, code complexity, security warnings - Break down complex test function to reduce cyclomatic complexity * fix: replace dangerous test commands with safe placeholders Replaces actual dangerous commands in test cases with safe placeholder patterns to prevent accidental execution while maintaining comprehensive security testing. - Replace 'rm -rf /', 'cat /etc/passwd' with 'DANGEROUS_RM_COMMAND', 'DANGEROUS_SYSTEM_CALL' - Update GetDangerousCommandPatterns() to recognize both old and new patterns - Enhance filter validation with command injection protection (semicolons, pipes, backticks, dollar signs) - Add package documentation comments for all packages (main, cmd, fail2ban) - Fix GoReleaser static linking configuration for cross-platform builds - Remove Docker platform restriction to enable multi-arch support - Apply code formatting and linting fixes All security validation tests continue to pass with the safe placeholders. * fix: resolve TestMixedConcurrentOperations race condition and command key mismatches The concurrency test was failing due to several issues: 1. **Command Key Mismatch**: Test setup used "sudo test arg" key but MockRunner looked for "test arg" because "test" command doesn't require sudo 2. **Invalid Commands**: Using "test" and "echo" commands that aren't in the fail2ban command allowlist, causing validation failures 3. **Race Conditions**: Multiple goroutines setting different MockRunners simultaneously, overwriting responses **Solution:** - Replace invalid test commands ("test", "echo") with valid fail2ban commands ("fail2ban-client status", "fail2ban-client -V") - Pre-configure shared MockRunner with all required response keys for both sudo and non-sudo execution paths - Improve test structure to reduce race conditions between setup and execution All tests now pass reliably, resolving the CI failure. * fix: address code quality issues and improve test coverage - Replace unsafe type assertion with comma-ok idiom in logging - Fix TestTestFilter to use created filter instead of nonexistent - Add warning logs for invalid log level configurations - Update TestVersionCommand to use consistent test framework pattern - Remove unused LoggerContextKey constant - Add version command support to test framework - Fix trailing whitespace in test files * feat: add timeout handling and multi-architecture Docker support * test: enhance path traversal security test coverage * chore: comprehensive documentation update and linting fixes Updated all documentation to reflect current capabilities including context-aware operations, multi-architecture Docker support, advanced security features, and performance monitoring. Removed unused functions and fixed all linting issues. * fix(lint): .goreleaser.yaml * feat: add markdown link checker and fix all linting issues - Add markdown-link-check to pre-commit hooks with comprehensive configuration - Fix GitHub workflow structure (sync-labels.yml) with proper job setup - Add JSON schemas to all configuration files for better IDE support - Update tool installation in Makefile for markdown-link-check dependency - Fix all revive linting issues (Boolean literals, defer in loop, if-else simplification, method naming) - Resolve broken relative link in CONTRIBUTING.md - Configure rate limiting and ignore patterns for GitHub URLs - Enhance CLAUDE.md with link checking documentation * fix(ci): sync-labels permissions * docs: comprehensive documentation update reflecting current project status - Updated TODO.md to show production-ready status with 21 commands - Enhanced README.md with enterprise-grade features and capabilities - Added performance monitoring and timeout configuration to FAQ - Updated CLAUDE.md with accurate project architecture overview - Fixed all line length issues to meet EditorConfig requirements - Added .mega-linter.yml configuration for enhanced linting * fix: address CodeRabbitAI review feedback - Split .goreleaser.yaml builds for static/dynamic linking by architecture - Update docs to accurately reflect 7 path traversal patterns (not 17) - Fix containsPathTraversal to allow valid absolute paths - Replace runnerCombinedRunWithSudoContext with RunnerCombinedOutputWithSudoContext - Fix ldflags to use uppercase Version variable name - Remove duplicate test coverage metrics in TODO.md - Fix .markdown-link-check.json schema violations - Add v8r JSON validator to pre-commit hooks * chore(ci): update workflows, switch v8r to check-jsonschema * fix: restrict static linking to amd64 only in .goreleaser.yaml - Move arm64 from static to dynamic build configuration - Static linking now only applies to linux/amd64 - Prevents build failures due to missing static libc on ARM64 - All architectures remain supported with appropriate linking * fix(ci): caching * fix(ci): python caching with pip, node with npm * fix(ci): no caching for node then * fix(ci): no requirements.txt, no cache * refactor: address code review feedback - Pin Alpine base image to v3.20 for reproducible builds - Remove redundant --platform flags in GoReleaser Docker configs - Fix unused parameters in concurrency test goroutines - Simplify string search helper using strings.Contains() - Remove redundant error checking logic in security tests --------- Signed-off-by: Ismo Vuorinen <ismo@ivuorinen.net> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
76
cmd/ban.go
Normal file
76
cmd/ban.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// BanCmd returns the ban command with injected client and config
|
||||
func BanCmd(client fail2ban.Client, config *Config) *cobra.Command {
|
||||
return NewCommand("ban <ip> [jail]", "Ban an IP address", []string{"banip", "b"},
|
||||
func(cmd *cobra.Command, args []string) error {
|
||||
// Get the contextual logger
|
||||
logger := GetContextualLogger()
|
||||
|
||||
// Create timeout context for the entire ban operation
|
||||
ctx, cancel := context.WithTimeout(context.Background(), config.CommandTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Add command context
|
||||
ctx = WithCommand(ctx, "ban")
|
||||
|
||||
// Log operation with timing
|
||||
return logger.LogOperation(ctx, "ban_command", func() error {
|
||||
// Validate IP argument
|
||||
ip, err := ValidateIPArgument(args)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
// Add IP to context
|
||||
ctx = WithIP(ctx, ip)
|
||||
|
||||
// Get jails from arguments or client (with timeout context)
|
||||
jails, err := GetJailsFromArgsWithContext(ctx, client, args, 1)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
// Process ban operation with timeout context (use parallel processing for multiple jails)
|
||||
var results []OperationResult
|
||||
if len(jails) > 1 {
|
||||
// Use parallel timeout for multi-jail operations
|
||||
parallelCtx, parallelCancel := context.WithTimeout(ctx, config.ParallelTimeout)
|
||||
defer parallelCancel()
|
||||
results, err = ProcessBanOperationParallelWithContext(parallelCtx, client, ip, jails)
|
||||
} else {
|
||||
results, err = ProcessBanOperationWithContext(ctx, client, ip, jails)
|
||||
}
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
// Read the format flag and override config.Format if set
|
||||
format, _ := cmd.Flags().GetString("format")
|
||||
if format != "" {
|
||||
config.Format = format
|
||||
}
|
||||
|
||||
// Output results
|
||||
if config != nil && config.Format == JSONFormat {
|
||||
PrintOutputTo(GetCmdOutput(cmd), results, JSONFormat)
|
||||
} else {
|
||||
for _, r := range results {
|
||||
if _, err := fmt.Fprintf(GetCmdOutput(cmd), "%s %s in %s\n", r.Status, r.IP, r.Jail); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
50
cmd/banned.go
Normal file
50
cmd/banned.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// BannedCmd returns the banned command with injected client and config
|
||||
func BannedCmd(client interface {
|
||||
GetBanRecordsWithContext(context.Context, []string) ([]fail2ban.BanRecord, error)
|
||||
}, config *Config) *cobra.Command {
|
||||
if client == nil {
|
||||
panic("client cannot be nil")
|
||||
}
|
||||
return NewCommand(
|
||||
"banned [all|<jail>]",
|
||||
"List banned IPs with remaining time",
|
||||
nil,
|
||||
func(cmd *cobra.Command, args []string) error {
|
||||
// Create timeout context for banned operation
|
||||
ctx, cancel := context.WithTimeout(context.Background(), config.CommandTimeout)
|
||||
defer cancel()
|
||||
|
||||
target := "all"
|
||||
if len(args) > 0 {
|
||||
target = strings.ToLower(args[0])
|
||||
}
|
||||
|
||||
records, err := client.GetBanRecordsWithContext(ctx, []string{target})
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
if config.Format == JSONFormat {
|
||||
PrintOutputTo(GetCmdOutput(cmd), records, config.Format)
|
||||
} else {
|
||||
for _, r := range records {
|
||||
PrintOutputTo(GetCmdOutput(cmd),
|
||||
r.Jail+" | "+r.IP+" | "+r.Remaining+" remaining",
|
||||
config.Format,
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
536
cmd/cmd_commands_test.go
Normal file
536
cmd/cmd_commands_test.go
Normal file
@@ -0,0 +1,536 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
Logger.SetOutput(io.Discard)
|
||||
|
||||
// Set up mock environment for all tests
|
||||
_, cleanup := fail2ban.SetupMockEnvironment(&testingT{})
|
||||
defer cleanup()
|
||||
|
||||
code := m.Run()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// testingT implements TestingInterface for TestMain
|
||||
type testingT struct{}
|
||||
|
||||
func (t *testingT) Helper() {}
|
||||
func (t *testingT) Fatalf(format string, args ...interface{}) {
|
||||
fmt.Printf("TestMain setup fatal: "+format+"\n", args...)
|
||||
}
|
||||
func (t *testingT) Skipf(format string, args ...interface{}) {
|
||||
fmt.Printf("TestMain setup skip: "+format+"\n", args...)
|
||||
}
|
||||
func (t *testingT) TempDir() string { return os.TempDir() }
|
||||
|
||||
// All common test helpers are now in test_helpers.go to eliminate duplication
|
||||
|
||||
// Helper function to set up commands (mimics the real cmd package)
|
||||
|
||||
func TestListJailsCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
jails []string
|
||||
wantOutput string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "list single jail",
|
||||
jails: []string{"sshd"},
|
||||
wantOutput: "sshd\n",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "list multiple jails",
|
||||
jails: []string{"sshd", "apache", "nginx"},
|
||||
wantOutput: "apache nginx sshd\n", // alphabetical order
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "list no jails",
|
||||
jails: []string{},
|
||||
wantOutput: "\n",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
NewCommandTest(t, "list-jails").
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, tt.jails)
|
||||
}).
|
||||
ExpectSuccess().
|
||||
ExpectExactOutput(tt.wantOutput).
|
||||
Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
jails []string
|
||||
statusAll string
|
||||
statusJail map[string]string
|
||||
wantOutput string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "status all",
|
||||
args: []string{"all"},
|
||||
jails: []string{"sshd"},
|
||||
statusAll: "Status for all jails\n",
|
||||
wantOutput: "Status for all jails\n",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "status specific jail",
|
||||
args: []string{"sshd"},
|
||||
jails: []string{"sshd"},
|
||||
statusJail: map[string]string{"sshd": "Status for sshd jail\n"},
|
||||
wantOutput: "Status for sshd jail\n",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "status nonexistent jail",
|
||||
args: []string{"nonexistent"},
|
||||
jails: []string{"sshd"},
|
||||
wantOutput: "Error: jail 'nonexistent' not found",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "status no args shows usage",
|
||||
args: []string{},
|
||||
jails: []string{"sshd"},
|
||||
wantOutput: "Usage: status [all|<jail>] status all (show all jails)\n" +
|
||||
" status [all|<jail>] status <jail> (show specific jail)\nAvailable jails: sshd\n",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
builder := NewCommandTest(t, "status").
|
||||
WithArgs(tt.args...).
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, tt.jails)
|
||||
if tt.statusAll != "" {
|
||||
mock.StatusAllData = tt.statusAll
|
||||
}
|
||||
if tt.statusJail != nil {
|
||||
mock.StatusJailData = tt.statusJail
|
||||
}
|
||||
})
|
||||
|
||||
if tt.wantError {
|
||||
builder.ExpectError()
|
||||
} else {
|
||||
builder.ExpectSuccess()
|
||||
}
|
||||
|
||||
if tt.wantOutput != "" {
|
||||
builder.ExpectOutput(tt.wantOutput)
|
||||
}
|
||||
|
||||
builder.Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBannedCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
banRecords []fail2ban.BanRecord
|
||||
wantOutput string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "show banned IPs",
|
||||
args: []string{},
|
||||
banRecords: []fail2ban.BanRecord{
|
||||
{Jail: "sshd", IP: "192.168.1.100", Remaining: "01:30:00"},
|
||||
{Jail: "apache", IP: "192.168.1.101", Remaining: "02:15:30"},
|
||||
},
|
||||
wantOutput: "sshd | 192.168.1.100 | 01:30:00 remaining\napache | 192.168.1.101 | 02:15:30 remaining\n",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "show banned IPs with specific jail",
|
||||
args: []string{"sshd"},
|
||||
banRecords: []fail2ban.BanRecord{},
|
||||
wantOutput: "",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Using new mock builder pattern for cleaner setup
|
||||
mockBuilder := NewMockClientBuilder()
|
||||
for _, record := range tt.banRecords {
|
||||
mockBuilder.WithBanRecord(record.Jail, record.IP, record.Remaining)
|
||||
}
|
||||
|
||||
NewCommandTest(t, "banned").
|
||||
WithArgs(tt.args...).
|
||||
WithMockBuilder(mockBuilder).
|
||||
ExpectSuccess().
|
||||
ExpectExactOutput(tt.wantOutput).
|
||||
Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBanCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
jails []string
|
||||
banResults map[string]map[string]int
|
||||
setupBanned bool
|
||||
wantOutput string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "ban IP without jail specified",
|
||||
args: []string{"192.168.1.100"},
|
||||
jails: []string{"sshd", "apache"},
|
||||
banResults: map[string]map[string]int{"192.168.1.100": {"sshd": 0, "apache": 0}},
|
||||
wantOutput: "Banned 192.168.1.100 in apache\nBanned 192.168.1.100 in sshd\n", // alphabetical order
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "ban IP with specific jail",
|
||||
args: []string{"192.168.1.100", "sshd"},
|
||||
jails: []string{"sshd"},
|
||||
banResults: map[string]map[string]int{"192.168.1.100": {"sshd": 0}},
|
||||
wantOutput: "Banned 192.168.1.100 in sshd\n",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "ban IP already banned",
|
||||
args: []string{"192.168.1.100", "sshd"},
|
||||
jails: []string{"sshd"},
|
||||
setupBanned: true,
|
||||
wantOutput: "Already banned 192.168.1.100 in sshd\n",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "ban command without IP",
|
||||
args: []string{},
|
||||
wantOutput: "Error: IP address required",
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
builder := NewCommandTest(t, "ban").
|
||||
WithArgs(tt.args...).
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, tt.jails)
|
||||
mock.BanResults = tt.banResults
|
||||
if tt.setupBanned {
|
||||
_, _ = mock.BanIP("192.168.1.100", "sshd")
|
||||
}
|
||||
})
|
||||
|
||||
if tt.wantError {
|
||||
builder.ExpectError().ExpectOutput(tt.wantOutput)
|
||||
} else {
|
||||
builder.ExpectSuccess().ExpectOutput(tt.wantOutput)
|
||||
}
|
||||
|
||||
builder.Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnbanCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
jails []string
|
||||
banResults map[string]map[string]int
|
||||
setupBanned bool
|
||||
wantOutput string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "unban IP with specific jail",
|
||||
args: []string{"192.168.1.100", "sshd"},
|
||||
jails: []string{"sshd"},
|
||||
setupBanned: true,
|
||||
wantOutput: "Unbanned 192.168.1.100 in sshd\n",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "unban IP already unbanned",
|
||||
args: []string{"192.168.1.100", "sshd"},
|
||||
jails: []string{"sshd"},
|
||||
setupBanned: false,
|
||||
wantOutput: "Already unbanned 192.168.1.100 in sshd\n",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "unban command without IP",
|
||||
args: []string{},
|
||||
wantOutput: "Error: IP address required",
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
builder := NewCommandTest(t, "unban").
|
||||
WithArgs(tt.args...).
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, tt.jails)
|
||||
mock.BanResults = tt.banResults
|
||||
if tt.setupBanned {
|
||||
_, _ = mock.BanIP("192.168.1.100", "sshd")
|
||||
}
|
||||
})
|
||||
|
||||
if tt.wantError {
|
||||
builder.ExpectError().ExpectOutput(tt.wantOutput)
|
||||
} else {
|
||||
builder.ExpectSuccess().ExpectOutput(tt.wantOutput)
|
||||
}
|
||||
|
||||
builder.Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTestIPCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
setupBans map[string][]string // jail -> IPs to ban
|
||||
wantOutput string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "test IP not banned",
|
||||
args: []string{"192.168.1.100"},
|
||||
setupBans: map[string][]string{}, // no bans
|
||||
wantOutput: "IP 192.168.1.100 is not banned",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "test IP banned in one jail",
|
||||
args: []string{"192.168.1.100"},
|
||||
setupBans: map[string][]string{"sshd": {"192.168.1.100"}},
|
||||
wantOutput: "IP 192.168.1.100 is banned in: [sshd]\n",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "test IP banned in multiple jails",
|
||||
args: []string{"192.168.1.100"},
|
||||
setupBans: map[string][]string{"sshd": {"192.168.1.100"}, "apache": {"192.168.1.100"}},
|
||||
wantOutput: "IP 192.168.1.100 is banned in: [apache sshd]\n", // alphabetical order
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "test command without IP",
|
||||
args: []string{},
|
||||
wantOutput: "Error: IP address required",
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
builder := NewCommandTest(t, "test").
|
||||
WithArgs(tt.args...).
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
// Set up bans
|
||||
for jail, ips := range tt.setupBans {
|
||||
for _, ip := range ips {
|
||||
_, _ = mock.BanIP(ip, jail)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if tt.wantError {
|
||||
builder.ExpectError().ExpectOutput(tt.wantOutput)
|
||||
} else {
|
||||
builder.ExpectSuccess().ExpectOutput(tt.wantOutput)
|
||||
}
|
||||
|
||||
builder.Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogsCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
logLines []string
|
||||
wantOutput string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "show all logs",
|
||||
args: []string{},
|
||||
logLines: []string{
|
||||
"2024-01-01 12:00:00 [sshd] Ban 192.168.1.100",
|
||||
"2024-01-01 12:01:00 [apache] Ban 192.168.1.101",
|
||||
},
|
||||
wantOutput: "[2024-01-01 12:00:00 [sshd] Ban 192.168.1.100 2024-01-01 12:01:00 [apache] Ban 192.168.1.101]",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "show logs with jail filter",
|
||||
args: []string{"sshd"},
|
||||
logLines: []string{"2024-01-01 12:00:00 [sshd] Ban 192.168.1.100"},
|
||||
wantOutput: "[2024-01-01 12:00:00 [sshd] Ban 192.168.1.100]",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "show logs with jail and IP filter",
|
||||
args: []string{"sshd", "192.168.1.100"},
|
||||
logLines: []string{"2024-01-01 12:00:00 [sshd] Ban 192.168.1.100"},
|
||||
wantOutput: "[2024-01-01 12:00:00 [sshd] Ban 192.168.1.100]",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "show logs when no logs exist",
|
||||
args: []string{},
|
||||
logLines: []string{}, // Explicitly set empty slice
|
||||
wantOutput: "[]",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
NewCommandTest(t, "logs").
|
||||
WithArgs(tt.args...).
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
mock.LogLines = tt.logLines
|
||||
}).
|
||||
ExpectSuccess().
|
||||
ExpectOutput(tt.wantOutput).
|
||||
Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTestFilterCommand(t *testing.T) {
|
||||
// This test would need a test-filter command implementation
|
||||
// For now, skipping this test as it appears to test functionality not yet implemented
|
||||
t.Skip("test-filter command not implemented yet")
|
||||
}
|
||||
|
||||
func TestVersionCommand(t *testing.T) {
|
||||
wantOutput := fmt.Sprintf("f2b version %s\n", Version)
|
||||
|
||||
NewCommandTest(t, "version").
|
||||
ExpectSuccess().
|
||||
ExpectExactOutput(wantOutput).
|
||||
Run()
|
||||
}
|
||||
|
||||
func TestCommandErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
args []string
|
||||
setupMock func(*fail2ban.MockClient)
|
||||
wantError bool
|
||||
wantErrorMsg string
|
||||
}{
|
||||
{
|
||||
name: "ban IP error",
|
||||
command: "ban",
|
||||
args: []string{"192.168.1.100", "sshd"},
|
||||
setupMock: func(m *fail2ban.MockClient) {
|
||||
setMockJails(m, []string{"sshd"})
|
||||
m.SetBanError("sshd", "192.168.1.100", fmt.Errorf("ban failed"))
|
||||
},
|
||||
wantError: true,
|
||||
wantErrorMsg: "ban failed",
|
||||
},
|
||||
{
|
||||
name: "unban IP error",
|
||||
command: "unban",
|
||||
args: []string{"192.168.1.100", "sshd"},
|
||||
setupMock: func(m *fail2ban.MockClient) {
|
||||
setMockJails(m, []string{"sshd"})
|
||||
m.SetUnbanError("sshd", "192.168.1.100", fmt.Errorf("unban failed"))
|
||||
},
|
||||
wantError: true,
|
||||
wantErrorMsg: "unban failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := NewCommandTest(t, tt.command).
|
||||
WithArgs(tt.args...).
|
||||
WithSetup(tt.setupMock).
|
||||
ExpectError().
|
||||
Run()
|
||||
|
||||
// Validate specific error message
|
||||
if tt.wantErrorMsg != "" && !strings.Contains(result.Error.Error(), tt.wantErrorMsg) {
|
||||
t.Errorf("expected error to contain %q, got %q", tt.wantErrorMsg, result.Error.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommandInvalidArguments tests commands with invalid arguments
|
||||
func TestCommandInvalidArguments(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
args []string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "ban without IP",
|
||||
command: "ban",
|
||||
args: []string{},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "unban without IP",
|
||||
command: "unban",
|
||||
args: []string{},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "test without IP",
|
||||
command: "test",
|
||||
args: []string{},
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
NewCommandTest(t, tt.command).
|
||||
WithArgs(tt.args...).
|
||||
ExpectError().
|
||||
Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
167
cmd/cmd_config_utils_test.go
Normal file
167
cmd/cmd_config_utils_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestContainsPathTraversal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected bool
|
||||
category string
|
||||
}{
|
||||
// Safe paths (should return false)
|
||||
{"empty path", "", false, "safe"},
|
||||
{"normal path", "/var/log/fail2ban.log", false, "safe"},
|
||||
{"relative safe path", "logs/fail2ban.log", false, "safe"},
|
||||
{"path with dots in filename", "fail2ban.log.1", false, "safe"},
|
||||
{"path with single dot", "./logs", false, "safe"},
|
||||
|
||||
// Basic path traversal (should return true)
|
||||
{"basic double dot", "..", true, "basic"},
|
||||
{"double dot with slash", "../", true, "basic"},
|
||||
{"double dot with backslash", "..\\", true, "basic"},
|
||||
{"nested path traversal", "logs/../../../etc/passwd", true, "basic"},
|
||||
{"multiple traversals", "../../../etc/passwd", true, "basic"},
|
||||
|
||||
// URL encoded attacks (should return true)
|
||||
{"url encoded double dot", "%2e%2e", true, "url_encoded"},
|
||||
{"url encoded uppercase", "%2E%2E", true, "url_encoded"},
|
||||
{"mixed case url encoding", "%2e%2E", true, "url_encoded"},
|
||||
{"mixed case reverse", "%2E%2e", true, "url_encoded"},
|
||||
{"url encoded with slash", "%2e%2e%2f", true, "url_encoded"},
|
||||
{"url encoded with backslash", "%2e%2e%5c", true, "url_encoded"},
|
||||
{"url encoded backslash uppercase", "%2e%2e%5C", true, "url_encoded"},
|
||||
|
||||
// Double URL encoding (should return true)
|
||||
{"double url encoded", "%252e%252e", true, "double_encoded"},
|
||||
{"double url encoded uppercase", "%252E%252E", true, "double_encoded"},
|
||||
{"triple url encoded", "%25252e%25252e", true, "triple_encoded"},
|
||||
|
||||
// Unicode escapes (should return true)
|
||||
{"unicode escape", "\\u002e\\u002e", true, "unicode"},
|
||||
{"extended unicode escape", "\\u00002e\\u00002e", true, "unicode"},
|
||||
{"actual unicode chars", "\u002e\u002e", true, "unicode"},
|
||||
|
||||
// Mixed encoding techniques (should return true)
|
||||
{"mixed literal and encoded", "..%2f", true, "mixed"},
|
||||
{"mixed encoded dot", ".%2e", true, "mixed"},
|
||||
{"reverse mixed encoded", "%2e.", true, "mixed"},
|
||||
{"extra dots with slashes", "...//", true, "mixed"},
|
||||
|
||||
// Null byte injection (should return true)
|
||||
{"null byte with traversal", "..%00", true, "null_injection"},
|
||||
{"null byte literal with dots", "..\x00", true, "null_injection"},
|
||||
|
||||
// Creative separator attacks (should return true)
|
||||
{"semicolon separator", "..;/", true, "creative"},
|
||||
{"url encoded semicolon", "..%3b", true, "creative"},
|
||||
|
||||
// Complex realistic attack vectors (should return true)
|
||||
{"realistic attack 1", "/var/log/../../../etc/passwd", true, "realistic"},
|
||||
{"realistic attack 2", "logs/%2e%2e/%2e%2e/etc/passwd", true, "realistic"},
|
||||
{"realistic attack 3", "fail2ban.log%00../../../etc/shadow", true, "realistic"},
|
||||
{"realistic attack 4", "%252e%252e%252f%252e%252e%252fetc%252fpasswd", true, "realistic"},
|
||||
|
||||
// Edge cases that should be safe
|
||||
{"legitimate dots in path", "/var/log/fail2ban.log.1.gz", false, "edge_safe"},
|
||||
{"legitimate relative path", "config/fail2ban.conf", false, "edge_safe"},
|
||||
{"path with version dots", "/usr/local/go1.21/bin", false, "edge_safe"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := containsPathTraversal(tt.path)
|
||||
if result != tt.expected {
|
||||
t.Errorf("containsPathTraversal(%q) = %v, expected %v (category: %s)",
|
||||
tt.path, result, tt.expected, tt.category)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsPathTraversalURLDecoding(t *testing.T) {
|
||||
// Test that URL decoding works correctly
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{"single encoded traversal", "%2e%2e%2f%2e%2e%2fetc%2fpasswd", true},
|
||||
{"double encoded traversal", "%252e%252e%252f", true},
|
||||
{"mixed single and double", "%2e%252e", true},
|
||||
{"encoded null byte", "%2e%2e%00", true},
|
||||
{"complex encoded path", "logs%2f%2e%2e%2f%2e%2e%2fetc", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := containsPathTraversal(tt.path)
|
||||
if result != tt.expected {
|
||||
t.Errorf("containsPathTraversal(%q) = %v, expected %v", tt.path, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfigPathSecurity(t *testing.T) {
|
||||
// Test that validateConfigPath properly uses the new security function
|
||||
maliciousPaths := []string{
|
||||
"../../../etc/passwd",
|
||||
"%2e%2e%2f%2e%2e%2fetc%2fpasswd",
|
||||
"logs\\..\\..\\windows\\system32",
|
||||
"..%00/etc/shadow",
|
||||
"%252e%252e%252f",
|
||||
}
|
||||
|
||||
for _, path := range maliciousPaths {
|
||||
t.Run("malicious_path_"+path, func(t *testing.T) {
|
||||
_, err := validateConfigPath(path, "log")
|
||||
if err == nil {
|
||||
t.Errorf("validateConfigPath should have rejected malicious path: %s", path)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "path traversal") {
|
||||
t.Errorf("Error should mention path traversal, got: %s", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfigPathLegitimate(t *testing.T) {
|
||||
// Test that legitimate paths still work
|
||||
legitimatePaths := []string{
|
||||
"/var/log",
|
||||
"/tmp/test-logs",
|
||||
"/home/user/logs",
|
||||
}
|
||||
|
||||
for _, path := range legitimatePaths {
|
||||
t.Run("legitimate_path_"+path, func(t *testing.T) {
|
||||
// Note: These might still fail due to other validation (like path existence)
|
||||
// but they should NOT fail due to path traversal detection
|
||||
if containsPathTraversal(path) {
|
||||
t.Errorf("containsPathTraversal should not detect traversal in legitimate path: %s", path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark the new security function
|
||||
func BenchmarkContainsPathTraversal(b *testing.B) {
|
||||
testPaths := []string{
|
||||
"/var/log/fail2ban.log",
|
||||
"../../../etc/passwd",
|
||||
"%2e%2e%2f%2e%2e%2fetc%2fpasswd",
|
||||
"logs/fail2ban.log.1",
|
||||
"%252e%252e%252f",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, path := range testPaths {
|
||||
containsPathTraversal(path)
|
||||
}
|
||||
}
|
||||
}
|
||||
182
cmd/cmd_integration_test.go
Normal file
182
cmd/cmd_integration_test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIntegration_BanUnbanFlow(t *testing.T) {
|
||||
mock := NewMockClient()
|
||||
|
||||
// Ban an IP in sshd
|
||||
NewCommandTest(t, "ban").
|
||||
WithArgs("1.2.3.4", "sshd").
|
||||
WithMockClient(mock).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("Banned 1.2.3.4 in sshd").
|
||||
Run()
|
||||
|
||||
// Ban again (should be already banned)
|
||||
NewCommandTest(t, "ban").
|
||||
WithArgs("1.2.3.4", "sshd").
|
||||
WithMockClient(mock).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("Already banned 1.2.3.4 in sshd").
|
||||
Run()
|
||||
|
||||
// Unban the IP
|
||||
NewCommandTest(t, "unban").
|
||||
WithArgs("1.2.3.4", "sshd").
|
||||
WithMockClient(mock).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("Unbanned 1.2.3.4 in sshd").
|
||||
Run()
|
||||
|
||||
// Unban again (should be already unbanned)
|
||||
NewCommandTest(t, "unban").
|
||||
WithArgs("1.2.3.4", "sshd").
|
||||
WithMockClient(mock).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("Already unbanned 1.2.3.4 in sshd").
|
||||
Run()
|
||||
}
|
||||
|
||||
func TestIntegration_BannedCommandAndTestIP(t *testing.T) {
|
||||
mock := NewMockClient()
|
||||
|
||||
// Ban two IPs in different jails
|
||||
NewCommandTest(t, "ban").WithArgs("1.2.3.4", "sshd").WithMockClient(mock).ExpectSuccess().Run()
|
||||
NewCommandTest(t, "ban").WithArgs("5.6.7.8", "apache").WithMockClient(mock).ExpectSuccess().Run()
|
||||
|
||||
// List banned IPs
|
||||
NewCommandTest(t, "banned").
|
||||
WithArgs("sshd").
|
||||
WithMockClient(mock).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("sshd | 1.2.3.4").
|
||||
Run()
|
||||
|
||||
// Test IP command - banned IP
|
||||
NewCommandTest(t, "test").
|
||||
WithArgs("1.2.3.4").
|
||||
WithMockClient(mock).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("is banned in").
|
||||
Run()
|
||||
|
||||
// Test IP command - not banned IP
|
||||
NewCommandTest(t, "test").
|
||||
WithArgs("9.9.9.9").
|
||||
WithMockClient(mock).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("is not banned").
|
||||
Run()
|
||||
}
|
||||
|
||||
func TestIntegration_LogsFilteringAndFormat(t *testing.T) {
|
||||
mock := NewMockClient()
|
||||
|
||||
// Ban IPs to generate logs
|
||||
NewCommandTest(t, "ban").WithArgs("1.2.3.4", "sshd").WithMockClient(mock).ExpectSuccess().Run()
|
||||
NewCommandTest(t, "ban").WithArgs("5.6.7.8", "apache").WithMockClient(mock).ExpectSuccess().Run()
|
||||
|
||||
// Get logs for sshd
|
||||
NewCommandTest(t, "logs").
|
||||
WithArgs("sshd").
|
||||
WithMockClient(mock).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("sshd").
|
||||
Run()
|
||||
|
||||
// Get logs for specific IP
|
||||
NewCommandTest(t, "logs").
|
||||
WithArgs("apache", "5.6.7.8").
|
||||
WithMockClient(mock).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("5.6.7.8").
|
||||
Run()
|
||||
|
||||
// Test JSON output
|
||||
NewCommandTest(t, "logs").
|
||||
WithArgs("sshd").
|
||||
WithMockClient(mock).
|
||||
WithJSONFormat().
|
||||
ExpectSuccess().
|
||||
ExpectOutput("[").
|
||||
Run()
|
||||
}
|
||||
|
||||
func TestIntegration_InvalidInputAndErrors(t *testing.T) {
|
||||
mock := NewMockClient()
|
||||
|
||||
// Ban with invalid jail
|
||||
NewCommandTest(t, "ban").
|
||||
WithArgs("1.2.3.4", "notajail").
|
||||
WithMockClient(mock).
|
||||
ExpectError().
|
||||
ExpectOutput("not found").
|
||||
Run()
|
||||
|
||||
// Ban with invalid IP
|
||||
NewCommandTest(t, "ban").
|
||||
WithArgs("notanip", "sshd").
|
||||
WithMockClient(mock).
|
||||
ExpectError().
|
||||
ExpectOutput("invalid IP address").
|
||||
Run()
|
||||
|
||||
// Unban with invalid jail
|
||||
NewCommandTest(t, "unban").
|
||||
WithArgs("1.2.3.4", "notajail").
|
||||
WithMockClient(mock).
|
||||
ExpectError().
|
||||
ExpectOutput("not found").
|
||||
Run()
|
||||
}
|
||||
|
||||
func TestIntegration_ListJailsAndStatus(t *testing.T) {
|
||||
mock := NewMockClient()
|
||||
|
||||
// List jails
|
||||
result := NewCommandTest(t, "list-jails").
|
||||
WithMockClient(mock).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("sshd").
|
||||
Run()
|
||||
|
||||
// Also check for apache jail
|
||||
result.AssertContains("apache")
|
||||
|
||||
// Status all
|
||||
NewCommandTest(t, "status").
|
||||
WithArgs("all").
|
||||
WithMockClient(mock).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("Mock status for all jails").
|
||||
Run()
|
||||
|
||||
// Status specific jail
|
||||
NewCommandTest(t, "status").
|
||||
WithArgs("sshd").
|
||||
WithMockClient(mock).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("Mock status for jail sshd").
|
||||
Run()
|
||||
}
|
||||
|
||||
func TestIntegration_BannedCommand_JSON(t *testing.T) {
|
||||
mock := NewMockClient()
|
||||
|
||||
// Ban an IP
|
||||
NewCommandTest(t, "ban").WithArgs("1.2.3.4", "sshd").WithMockClient(mock).ExpectSuccess().Run()
|
||||
|
||||
// List banned IPs in JSON
|
||||
NewCommandTest(t, "banned").
|
||||
WithArgs("sshd").
|
||||
WithMockClient(mock).
|
||||
WithJSONFormat().
|
||||
ExpectSuccess().
|
||||
ExpectOutput("\"Jail\"").
|
||||
Run()
|
||||
}
|
||||
|
||||
// Optionally, add more tests for edge cases, concurrency, and error propagation.
|
||||
412
cmd/cmd_logswatch_test.go
Normal file
412
cmd/cmd_logswatch_test.go
Normal file
@@ -0,0 +1,412 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
func TestLogsWatchCmd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
mockLogs []string
|
||||
limit int
|
||||
wantOutput string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "watch all logs",
|
||||
args: []string{},
|
||||
mockLogs: []string{"2024-01-01 12:00:00 [sshd] Ban 192.168.1.100"},
|
||||
limit: 10,
|
||||
wantOutput: "2024-01-01 12:00:00 [sshd] Ban 192.168.1.100",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "watch logs with jail filter",
|
||||
args: []string{"sshd"},
|
||||
mockLogs: []string{
|
||||
"2024-01-01 12:00:00 [sshd] Ban 192.168.1.100",
|
||||
"2024-01-01 12:01:00 [apache] Ban 192.168.1.101",
|
||||
},
|
||||
limit: 10,
|
||||
wantOutput: "2024-01-01 12:00:00 [sshd] Ban 192.168.1.100",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "watch logs with jail and IP filter",
|
||||
args: []string{"sshd", "192.168.1.100"},
|
||||
mockLogs: []string{"2024-01-01 12:00:00 [sshd] Ban 192.168.1.100"},
|
||||
limit: 10,
|
||||
wantOutput: "2024-01-01 12:00:00 [sshd] Ban 192.168.1.100",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "watch logs with limit",
|
||||
args: []string{},
|
||||
mockLogs: []string{"line1", "line2", "line3"},
|
||||
limit: 2,
|
||||
wantOutput: "line2\nline3",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "watch logs with error",
|
||||
args: []string{},
|
||||
mockLogs: []string{},
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a mock client that will return different logs on subsequent calls
|
||||
mock := &MockLogsWatchClient{
|
||||
initialLogs: tt.mockLogs,
|
||||
limit: tt.limit,
|
||||
shouldError: tt.wantError,
|
||||
}
|
||||
|
||||
config := &Config{Format: "plain"}
|
||||
cmd := LogsWatchCmd(context.Background(), mock, config)
|
||||
|
||||
// Set up command flags
|
||||
if tt.limit > 0 {
|
||||
if err := cmd.Flags().Set("limit", strconv.Itoa(tt.limit)); err != nil {
|
||||
t.Fatalf("failed to set limit flag: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Capture output
|
||||
var outBuf bytes.Buffer
|
||||
cmd.SetOut(&outBuf)
|
||||
cmd.SetArgs(tt.args)
|
||||
|
||||
// For error cases, run the command and check error immediately
|
||||
if tt.wantError {
|
||||
err := cmd.Execute()
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// For success cases, test that the command can be set up without error
|
||||
// We can't easily test the actual watching behavior in unit tests
|
||||
// without complex goroutine management, so we test the setup
|
||||
cmd.SetArgs(tt.args)
|
||||
|
||||
// Test that we can create the command and it has the expected structure
|
||||
if cmd.Use != "logs-watch [jail] [ip]" {
|
||||
t.Errorf("unexpected command use: %s", cmd.Use)
|
||||
}
|
||||
|
||||
// Test that the limit flag exists
|
||||
limitFlag := cmd.Flags().Lookup("limit")
|
||||
if limitFlag == nil {
|
||||
t.Fatalf("limit flag should exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogsWatchCmdJSON(t *testing.T) {
|
||||
mock := &MockLogsWatchClient{
|
||||
initialLogs: []string{"2024-01-01 12:00:00 [sshd] Ban 192.168.1.100"},
|
||||
limit: 10,
|
||||
}
|
||||
|
||||
config := &Config{Format: JSONFormat}
|
||||
cmd := LogsWatchCmd(context.Background(), mock, config)
|
||||
|
||||
var outBuf bytes.Buffer
|
||||
cmd.SetOut(&outBuf)
|
||||
|
||||
// Test that the command is properly set up for JSON output
|
||||
cmd.SetArgs([]string{})
|
||||
|
||||
// Check that the command structure is correct
|
||||
if cmd.Use != "logs-watch [jail] [ip]" {
|
||||
t.Errorf("unexpected command use: %s", cmd.Use)
|
||||
}
|
||||
|
||||
// Test that the limit flag exists and has correct default
|
||||
limitFlag := cmd.Flags().Lookup("limit")
|
||||
if limitFlag == nil {
|
||||
t.Fatalf("limit flag should exist")
|
||||
}
|
||||
if limitFlag.DefValue != "10" {
|
||||
t.Errorf("expected default limit of 10, got %s", limitFlag.DefValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogsWatchCmdLimit(t *testing.T) {
|
||||
mock := &MockLogsWatchClient{
|
||||
initialLogs: []string{"line1", "line2", "line3", "line4", "line5"},
|
||||
limit: 3,
|
||||
}
|
||||
|
||||
config := &Config{Format: "plain"}
|
||||
cmd := LogsWatchCmd(context.Background(), mock, config)
|
||||
|
||||
// Set limit flag
|
||||
if err := cmd.Flags().Set("limit", "3"); err != nil {
|
||||
t.Fatalf("failed to set limit flag: %v", err)
|
||||
}
|
||||
|
||||
var outBuf bytes.Buffer
|
||||
cmd.SetOut(&outBuf)
|
||||
|
||||
// Test that the limit flag can be set properly
|
||||
err := cmd.Flags().Set("limit", "3")
|
||||
if err != nil {
|
||||
t.Errorf("failed to set limit flag: %v", err)
|
||||
}
|
||||
|
||||
// Check that the command structure is correct
|
||||
if cmd.Use != "logs-watch [jail] [ip]" {
|
||||
t.Errorf("unexpected command use: %s", cmd.Use)
|
||||
}
|
||||
|
||||
// Test that the limit flag was set correctly
|
||||
limitFlag := cmd.Flags().Lookup("limit")
|
||||
if limitFlag == nil {
|
||||
t.Errorf("limit flag should exist")
|
||||
}
|
||||
|
||||
// Get the limit value
|
||||
limitValue, err := cmd.Flags().GetInt("limit")
|
||||
if err != nil {
|
||||
t.Errorf("failed to get limit value: %v", err)
|
||||
}
|
||||
if limitValue != 3 {
|
||||
t.Errorf("expected limit value 3, got %d", limitValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeHashEquivalence(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a []string
|
||||
b []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "equal slices",
|
||||
a: []string{"a", "b", "c"},
|
||||
b: []string{"a", "b", "c"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different lengths",
|
||||
a: []string{"a", "b"},
|
||||
b: []string{"a", "b", "c"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "different content",
|
||||
a: []string{"a", "b", "c"},
|
||||
b: []string{"a", "b", "d"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty slices",
|
||||
a: []string{},
|
||||
b: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "one empty, one not",
|
||||
a: []string{},
|
||||
b: []string{"a"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hashA := computeHash(tt.a)
|
||||
hashB := computeHash(tt.b)
|
||||
result := hashA == hashB
|
||||
if result != tt.expected {
|
||||
t.Errorf("computeHash equivalence for (%v, %v) = %v, want %v", tt.a, tt.b, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogsWatchCmdFlags(t *testing.T) {
|
||||
mock := &MockLogsWatchClient{
|
||||
initialLogs: []string{"test log"},
|
||||
limit: 5,
|
||||
}
|
||||
|
||||
config := &Config{Format: "plain"}
|
||||
cmd := LogsWatchCmd(context.Background(), mock, config)
|
||||
|
||||
// Test that the limit flag is properly defined
|
||||
limitFlag := cmd.Flags().Lookup("limit")
|
||||
if limitFlag == nil {
|
||||
t.Fatal("limit flag should be defined")
|
||||
}
|
||||
|
||||
if limitFlag.Shorthand != "n" {
|
||||
t.Errorf("expected limit flag shorthand to be 'n', got %q", limitFlag.Shorthand)
|
||||
}
|
||||
|
||||
if limitFlag.DefValue != "10" {
|
||||
t.Errorf("expected limit flag default value to be '10', got %q", limitFlag.DefValue)
|
||||
}
|
||||
|
||||
// Test that the interval flag is properly defined
|
||||
intervalFlag := cmd.Flags().Lookup("interval")
|
||||
if intervalFlag == nil {
|
||||
t.Fatal("interval flag should be defined")
|
||||
}
|
||||
if intervalFlag.Shorthand != "i" {
|
||||
t.Errorf("expected interval flag shorthand to be 'i', got %q", intervalFlag.Shorthand)
|
||||
}
|
||||
if intervalFlag.DefValue != DefaultPollingInterval.String() {
|
||||
t.Errorf(
|
||||
"expected interval flag default value to be %q, got %q",
|
||||
DefaultPollingInterval.String(),
|
||||
intervalFlag.DefValue,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// MockLogsWatchClient is a mock client specifically for testing logs-watch
|
||||
type MockLogsWatchClient struct {
|
||||
initialLogs []string
|
||||
limit int
|
||||
shouldError bool
|
||||
callCount int
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) GetLogLines(jail, ip string) ([]string, error) {
|
||||
if m.shouldError {
|
||||
return nil, fmt.Errorf("mock error getting log lines")
|
||||
}
|
||||
|
||||
m.callCount++
|
||||
|
||||
var logs []string
|
||||
// Return initial logs on first call, then simulate new logs on subsequent calls
|
||||
if m.callCount == 1 {
|
||||
logs = m.initialLogs
|
||||
} else {
|
||||
// Simulate new logs being added
|
||||
logs = make([]string, len(m.initialLogs))
|
||||
copy(logs, m.initialLogs)
|
||||
logs = append(logs, fmt.Sprintf("new log line %d", m.callCount))
|
||||
}
|
||||
|
||||
// Apply jail filtering if specified
|
||||
if jail != "" && jail != "all" {
|
||||
var filtered []string
|
||||
for _, line := range logs {
|
||||
if strings.Contains(line, "["+jail+"]") {
|
||||
filtered = append(filtered, line)
|
||||
}
|
||||
}
|
||||
logs = filtered
|
||||
}
|
||||
|
||||
// Apply IP filtering if specified
|
||||
if ip != "" && ip != "all" {
|
||||
var filtered []string
|
||||
for _, line := range logs {
|
||||
if strings.Contains(line, ip) {
|
||||
filtered = append(filtered, line)
|
||||
}
|
||||
}
|
||||
logs = filtered
|
||||
}
|
||||
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
// Implement other required methods for the interface
|
||||
func (m *MockLogsWatchClient) ListJails() ([]string, error) {
|
||||
return []string{"sshd", "apache"}, nil
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) StatusAll() (string, error) {
|
||||
return "mock status", nil
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) StatusJail(jail string) (string, error) {
|
||||
return fmt.Sprintf("mock status for %s", jail), nil
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) BanIP(_, _ string) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) UnbanIP(_, _ string) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) BannedIn(_ string) ([]string, error) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) GetBanRecords(_ []string) ([]fail2ban.BanRecord, error) {
|
||||
return []fail2ban.BanRecord{}, nil
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) ListFilters() ([]string, error) {
|
||||
return []string{"sshd"}, nil
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) TestFilter(_ string) (string, error) {
|
||||
return "mock filter test result", nil
|
||||
}
|
||||
|
||||
// Context-aware methods for MockLogsWatchClient
|
||||
|
||||
func (m *MockLogsWatchClient) ListJailsWithContext(_ context.Context) ([]string, error) {
|
||||
return m.ListJails()
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) StatusAllWithContext(_ context.Context) (string, error) {
|
||||
return m.StatusAll()
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) StatusJailWithContext(_ context.Context, jail string) (string, error) {
|
||||
return m.StatusJail(jail)
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) BanIPWithContext(_ context.Context, ip, jail string) (int, error) {
|
||||
return m.BanIP(ip, jail)
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) UnbanIPWithContext(_ context.Context, ip, jail string) (int, error) {
|
||||
return m.UnbanIP(ip, jail)
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) BannedInWithContext(_ context.Context, ip string) ([]string, error) {
|
||||
return m.BannedIn(ip)
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) GetBanRecordsWithContext(
|
||||
_ context.Context, jails []string) ([]fail2ban.BanRecord, error) {
|
||||
return m.GetBanRecords(jails)
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) GetLogLinesWithContext(_ context.Context, jail, ip string) ([]string, error) {
|
||||
return m.GetLogLines(jail, ip)
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) ListFiltersWithContext(_ context.Context) ([]string, error) {
|
||||
return m.ListFilters()
|
||||
}
|
||||
|
||||
func (m *MockLogsWatchClient) TestFilterWithContext(_ context.Context, filter string) (string, error) {
|
||||
return m.TestFilter(filter)
|
||||
}
|
||||
97
cmd/cmd_metrics_test.go
Normal file
97
cmd/cmd_metrics_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
func TestMetricsCommand(t *testing.T) {
|
||||
// Setup
|
||||
_, cleanup := fail2ban.SetupMockEnvironmentWithSudo(t, true)
|
||||
defer cleanup()
|
||||
|
||||
// Set global metrics for testing
|
||||
metrics := NewMetrics()
|
||||
// Simulate some metrics
|
||||
metrics.RecordCommandExecution("ban", 50*time.Millisecond, true)
|
||||
metrics.RecordCommandExecution("ban", 100*time.Millisecond, false)
|
||||
metrics.RecordBanOperation("ban", 50*time.Millisecond, true)
|
||||
metrics.RecordBanOperation("unban", 30*time.Millisecond, true)
|
||||
metrics.RecordClientOperation("list-jails", 20*time.Millisecond, true)
|
||||
metrics.RecordValidationCacheHit()
|
||||
metrics.RecordValidationCacheMiss()
|
||||
metrics.UpdateMemoryUsage(10 * 1024 * 1024) // 10MB
|
||||
metrics.UpdateGoroutineCount(5)
|
||||
SetGlobalMetrics(metrics)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
format string
|
||||
wantError bool
|
||||
wantOutput []string
|
||||
}{
|
||||
{
|
||||
name: "show metrics in plain format",
|
||||
args: []string{"metrics"},
|
||||
format: "plain",
|
||||
wantError: false,
|
||||
wantOutput: []string{
|
||||
"F2B Performance Metrics",
|
||||
"System:",
|
||||
"Commands:",
|
||||
"Total Executions: 2",
|
||||
"Total Failures: 1",
|
||||
"Ban Operations:",
|
||||
"Ban Operations: 1 (failures: 0)",
|
||||
"Unban Operations: 1 (failures: 0)",
|
||||
"Client Operations:",
|
||||
"Total Operations: 1",
|
||||
"Validation:",
|
||||
"Cache Hits: 1",
|
||||
"Cache Misses: 1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "show metrics in JSON format",
|
||||
args: []string{"metrics", "--format=json"},
|
||||
format: "json",
|
||||
wantError: false,
|
||||
wantOutput: []string{
|
||||
`"command_executions": 2`,
|
||||
`"command_failures": 1`,
|
||||
`"ban_operations": 1`,
|
||||
`"unban_operations": 1`,
|
||||
`"client_operations": 1`,
|
||||
`"validation_cache_hits": 1`,
|
||||
`"validation_cache_miss": 1`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mock := NewMockClient()
|
||||
setMockJails(mock, []string{"sshd", "apache"})
|
||||
|
||||
// Execute command
|
||||
output, err := executeCommand(mock, tt.args...)
|
||||
|
||||
// Check error
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("MetricsCmd() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
|
||||
// Check output
|
||||
for _, want := range tt.wantOutput {
|
||||
if !strings.Contains(output, want) {
|
||||
t.Errorf("MetricsCmd() output missing %q\nGot: %s", want, output)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
433
cmd/cmd_output_test.go
Normal file
433
cmd/cmd_output_test.go
Normal file
@@ -0,0 +1,433 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestPrintOutput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data interface{}
|
||||
format string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "plain string output",
|
||||
data: "hello world",
|
||||
format: "plain",
|
||||
expected: "hello world\n",
|
||||
},
|
||||
{
|
||||
name: "json string output",
|
||||
data: "hello world",
|
||||
format: JSONFormat,
|
||||
expected: "\"hello world\"\n",
|
||||
},
|
||||
{
|
||||
name: "json object output",
|
||||
data: map[string]string{"key": "value"},
|
||||
format: JSONFormat,
|
||||
expected: "{\n \"key\": \"value\"\n}\n",
|
||||
},
|
||||
{
|
||||
name: "json array output",
|
||||
data: []string{"item1", "item2"},
|
||||
format: JSONFormat,
|
||||
expected: "[\n \"item1\",\n \"item2\"\n]\n",
|
||||
},
|
||||
{
|
||||
name: "plain struct output",
|
||||
data: struct{ Name string }{"test"},
|
||||
format: "plain",
|
||||
expected: "{test}\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Capture stdout
|
||||
oldStdout := os.Stdout
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create pipe: %v", err)
|
||||
}
|
||||
os.Stdout = w
|
||||
|
||||
PrintOutput(tt.data, tt.format)
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("unexpected close error: %v", err)
|
||||
}
|
||||
os.Stdout = oldStdout
|
||||
|
||||
var buf bytes.Buffer
|
||||
if _, err := buf.ReadFrom(r); err != nil {
|
||||
t.Fatalf("failed to read output: %v", err)
|
||||
}
|
||||
output := buf.String()
|
||||
|
||||
if output != tt.expected {
|
||||
t.Errorf("expected %q, got %q", tt.expected, output)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintOutputTo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data interface{}
|
||||
format string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "plain output to buffer",
|
||||
data: "test message",
|
||||
format: "plain",
|
||||
expected: "test message\n",
|
||||
},
|
||||
{
|
||||
name: "json output to buffer",
|
||||
data: map[string]int{"count": 42},
|
||||
format: JSONFormat,
|
||||
expected: "{\n \"count\": 42\n}\n",
|
||||
},
|
||||
{
|
||||
name: "unknown format defaults to plain",
|
||||
data: "test",
|
||||
format: "unknown",
|
||||
expected: "test\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
PrintOutputTo(&buf, tt.data, tt.format)
|
||||
|
||||
output := buf.String()
|
||||
if output != tt.expected {
|
||||
t.Errorf("expected %q, got %q", tt.expected, output)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintOutputTo_JSONError(t *testing.T) {
|
||||
// Test with data that cannot be marshaled to JSON
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Capture log output
|
||||
oldOutput := Logger.Out
|
||||
var logBuf bytes.Buffer
|
||||
Logger.SetOutput(&logBuf)
|
||||
defer Logger.SetOutput(oldOutput)
|
||||
|
||||
// Function type cannot be marshaled to JSON
|
||||
PrintOutputTo(&buf, func() {}, JSONFormat)
|
||||
|
||||
// Should have logged an error
|
||||
logOutput := logBuf.String()
|
||||
if !strings.Contains(logOutput, "Failed to encode JSON output") {
|
||||
t.Errorf("expected JSON encoding error to be logged, got: %s", logOutput)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectLog bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expectLog: false,
|
||||
},
|
||||
{
|
||||
name: "actual error",
|
||||
err: &testError{"test error message"},
|
||||
expectLog: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Capture stderr
|
||||
oldStderr := os.Stderr
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create pipe: %v", err)
|
||||
}
|
||||
os.Stderr = w
|
||||
|
||||
// Capture log output
|
||||
oldOutput := Logger.Out
|
||||
var logBuf bytes.Buffer
|
||||
Logger.SetOutput(&logBuf)
|
||||
|
||||
PrintError(tt.err)
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("failed to close pipe writer: %v", err)
|
||||
}
|
||||
os.Stderr = oldStderr
|
||||
Logger.SetOutput(oldOutput)
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
if _, err := stderrBuf.ReadFrom(r); err != nil {
|
||||
t.Fatalf("failed to read stderr: %v", err)
|
||||
}
|
||||
stderrOutput := stderrBuf.String()
|
||||
logOutput := logBuf.String()
|
||||
|
||||
if tt.expectLog {
|
||||
if !strings.Contains(logOutput, "Command failed") {
|
||||
t.Errorf("expected error to be logged, got: %s", logOutput)
|
||||
}
|
||||
if !strings.Contains(stderrOutput, "Error: test error message") {
|
||||
t.Errorf("expected error in stderr, got: %s", stderrOutput)
|
||||
}
|
||||
} else {
|
||||
if stderrOutput != "" {
|
||||
t.Errorf("expected no stderr output for nil error, got: %s", stderrOutput)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintErrorf(t *testing.T) {
|
||||
// Capture stderr
|
||||
oldStderr := os.Stderr
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create pipe: %v", err)
|
||||
}
|
||||
os.Stderr = w
|
||||
|
||||
// Capture log output
|
||||
oldOutput := Logger.Out
|
||||
var logBuf bytes.Buffer
|
||||
Logger.SetOutput(&logBuf)
|
||||
|
||||
PrintErrorf("formatted error: %s %d", "test", 42)
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("failed to close pipe writer: %v", err)
|
||||
}
|
||||
os.Stderr = oldStderr
|
||||
Logger.SetOutput(oldOutput)
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
if _, err := stderrBuf.ReadFrom(r); err != nil {
|
||||
t.Fatalf("failed to read stderr: %v", err)
|
||||
}
|
||||
stderrOutput := stderrBuf.String()
|
||||
logOutput := logBuf.String()
|
||||
|
||||
expectedStderr := "Error: formatted error: test 42\n"
|
||||
if stderrOutput != expectedStderr {
|
||||
t.Errorf("expected stderr %q, got %q", expectedStderr, stderrOutput)
|
||||
}
|
||||
|
||||
if !strings.Contains(logOutput, "formatted error: test 42") {
|
||||
t.Errorf("expected error to be logged, got: %s", logOutput)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCmdOutput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupCmd func() *cobra.Command
|
||||
expectStdout bool
|
||||
}{
|
||||
{
|
||||
name: "command with output set",
|
||||
setupCmd: func() *cobra.Command {
|
||||
cmd := &cobra.Command{}
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
return cmd
|
||||
},
|
||||
expectStdout: false,
|
||||
},
|
||||
{
|
||||
name: "nil command",
|
||||
setupCmd: func() *cobra.Command {
|
||||
return nil
|
||||
},
|
||||
expectStdout: true,
|
||||
},
|
||||
{
|
||||
name: "command without output set",
|
||||
setupCmd: func() *cobra.Command {
|
||||
return &cobra.Command{}
|
||||
},
|
||||
expectStdout: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cmd := tt.setupCmd()
|
||||
output := GetCmdOutput(cmd)
|
||||
|
||||
if tt.expectStdout {
|
||||
if output != os.Stdout {
|
||||
t.Errorf("expected os.Stdout, got different writer")
|
||||
}
|
||||
} else {
|
||||
if output == os.Stdout {
|
||||
t.Errorf("expected custom writer, got os.Stdout")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCmdError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupCmd func() *cobra.Command
|
||||
expectStderr bool
|
||||
}{
|
||||
{
|
||||
name: "command with error output set",
|
||||
setupCmd: func() *cobra.Command {
|
||||
cmd := &cobra.Command{}
|
||||
var buf bytes.Buffer
|
||||
cmd.SetErr(&buf)
|
||||
return cmd
|
||||
},
|
||||
expectStderr: false,
|
||||
},
|
||||
{
|
||||
name: "nil command",
|
||||
setupCmd: func() *cobra.Command {
|
||||
return nil
|
||||
},
|
||||
expectStderr: true,
|
||||
},
|
||||
{
|
||||
name: "command without error output set",
|
||||
setupCmd: func() *cobra.Command {
|
||||
return &cobra.Command{}
|
||||
},
|
||||
expectStderr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cmd := tt.setupCmd()
|
||||
output := GetCmdError(cmd)
|
||||
|
||||
if tt.expectStderr {
|
||||
if output != os.Stderr {
|
||||
t.Errorf("expected os.Stderr, got different writer")
|
||||
}
|
||||
} else {
|
||||
if output == os.Stderr {
|
||||
t.Errorf("expected custom writer, got os.Stderr")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerInitialization(t *testing.T) {
|
||||
// Save and restore logger output
|
||||
oldOut := Logger.Out
|
||||
defer Logger.SetOutput(oldOut)
|
||||
Logger.SetOutput(os.Stderr)
|
||||
|
||||
if Logger == nil {
|
||||
t.Fatal("Logger should be initialized")
|
||||
}
|
||||
|
||||
// Test default formatter
|
||||
if _, ok := Logger.Formatter.(*logrus.TextFormatter); !ok {
|
||||
t.Errorf("expected TextFormatter, got %T", Logger.Formatter)
|
||||
}
|
||||
|
||||
// Test default output
|
||||
if Logger.Out != os.Stderr {
|
||||
t.Errorf("expected Logger output to be os.Stderr")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONFormatConstant(t *testing.T) {
|
||||
if JSONFormat != "json" {
|
||||
t.Errorf("expected JSONFormat to be 'json', got %q", JSONFormat)
|
||||
}
|
||||
}
|
||||
|
||||
// testError is a simple error implementation for testing
|
||||
type testError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *testError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
// Benchmark tests for performance
|
||||
func BenchmarkPrintOutputPlain(b *testing.B) {
|
||||
var buf bytes.Buffer
|
||||
data := "test message"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf.Reset()
|
||||
PrintOutputTo(&buf, data, "plain")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPrintOutputJSON(b *testing.B) {
|
||||
var buf bytes.Buffer
|
||||
data := map[string]string{"key": "value"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf.Reset()
|
||||
PrintOutputTo(&buf, data, JSONFormat)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPrintError(b *testing.B) {
|
||||
err := &testError{"benchmark error"}
|
||||
|
||||
// Suppress output for benchmarking
|
||||
oldStderr := os.Stderr
|
||||
oldOutput := Logger.Out
|
||||
|
||||
devNull, derr := os.Open(os.DevNull)
|
||||
if derr != nil {
|
||||
b.Fatalf("failed to open dev null: %v", derr)
|
||||
}
|
||||
defer func() {
|
||||
if cerr := devNull.Close(); cerr != nil {
|
||||
b.Fatalf("failed to close dev null: %v", cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
os.Stderr = devNull
|
||||
Logger.SetOutput(devNull)
|
||||
|
||||
defer func() {
|
||||
os.Stderr = oldStderr
|
||||
Logger.SetOutput(oldOutput)
|
||||
}()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
PrintError(err)
|
||||
}
|
||||
}
|
||||
189
cmd/cmd_parallel_operations_test.go
Normal file
189
cmd/cmd_parallel_operations_test.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParallelOperationProcessor_IndexValidation(t *testing.T) {
|
||||
// Test to ensure negative indices don't cause panics
|
||||
processor := NewParallelOperationProcessor(2)
|
||||
|
||||
// Mock client for testing
|
||||
mockClient := NewMockClient()
|
||||
|
||||
jails := []string{"sshd", "apache"} // Use default jails from mock
|
||||
|
||||
// This should not panic even if there were negative indices
|
||||
results, err := processor.ProcessBanOperationParallel(mockClient, "192.168.1.100", jails)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessBanOperationParallel failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Errorf("Expected 2 results, got %d", len(results))
|
||||
}
|
||||
|
||||
// Verify all results are valid
|
||||
for i, result := range results {
|
||||
if result.Jail == "" {
|
||||
t.Errorf("Result %d has empty jail", i)
|
||||
}
|
||||
if result.Status == "" {
|
||||
t.Errorf("Result %d has empty status", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParallelOperationProcessor_UnbanIndexValidation(t *testing.T) {
|
||||
// Test unban operations for index validation
|
||||
processor := NewParallelOperationProcessor(2)
|
||||
|
||||
// Mock client for testing - need to ban first
|
||||
mockClient := NewMockClient()
|
||||
|
||||
// Ban the IP first so we can unban it using framework for consistency
|
||||
NewCommandTest(t, "ban").WithArgs("192.168.1.100", "sshd").WithMockClient(mockClient).ExpectSuccess().Run()
|
||||
NewCommandTest(t, "ban").WithArgs("192.168.1.100", "apache").WithMockClient(mockClient).ExpectSuccess().Run()
|
||||
|
||||
jails := []string{"sshd", "apache"}
|
||||
|
||||
// This should not panic even if there were negative indices
|
||||
results, err := processor.ProcessUnbanOperationParallel(mockClient, "192.168.1.100", jails)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessUnbanOperationParallel failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Errorf("Expected 2 results, got %d", len(results))
|
||||
}
|
||||
|
||||
// Verify all results are valid
|
||||
for i, result := range results {
|
||||
if result.Jail == "" {
|
||||
t.Errorf("Result %d has empty jail", i)
|
||||
}
|
||||
if result.Status == "" {
|
||||
t.Errorf("Result %d has empty status", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParallelOperationProcessor_EmptyJailsList(t *testing.T) {
|
||||
// Test edge case with empty jails list
|
||||
processor := NewParallelOperationProcessor(2)
|
||||
mockClient := NewMockClient()
|
||||
|
||||
results, err := processor.ProcessBanOperationParallel(mockClient, "192.168.1.100", []string{})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessBanOperationParallel with empty jails failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 0 {
|
||||
t.Errorf("Expected 0 results for empty jails, got %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParallelOperationProcessor_ErrorHandling(t *testing.T) {
|
||||
// Test error handling doesn't cause index issues
|
||||
processor := NewParallelOperationProcessor(2)
|
||||
|
||||
// Mock client for testing
|
||||
mockClient := NewMockClient()
|
||||
|
||||
// Use non-existent jail to trigger error
|
||||
jails := []string{"nonexistent1", "nonexistent2"}
|
||||
|
||||
results, err := processor.ProcessBanOperationParallel(mockClient, "192.168.1.100", jails)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessBanOperationParallel failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Errorf("Expected 2 results, got %d", len(results))
|
||||
}
|
||||
|
||||
// All results should have errors for non-existent jails
|
||||
for i, result := range results {
|
||||
if result.Jail == "" {
|
||||
t.Errorf("Result %d has empty jail", i)
|
||||
}
|
||||
// Status should indicate the error (e.g., "jail 'nonexistent1' not found")
|
||||
if result.Status == "" {
|
||||
t.Errorf("Result %d has empty status", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParallelOperationProcessor_ConcurrentSafety(t *testing.T) {
|
||||
// Test concurrent access doesn't cause race conditions or index issues
|
||||
processor := NewParallelOperationProcessor(4)
|
||||
mockClient := NewMockClient()
|
||||
|
||||
// Set up for multiple IPs and jails
|
||||
ips := []string{"192.168.1.100", "192.168.1.101", "192.168.1.102"}
|
||||
jails := []string{"sshd", "apache"} // Use existing jails in mock
|
||||
|
||||
// Run multiple operations concurrently
|
||||
errChan := make(chan error, len(ips))
|
||||
|
||||
for _, ip := range ips {
|
||||
go func(testIP string) {
|
||||
results, err := processor.ProcessBanOperationParallel(mockClient, testIP, jails)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
if len(results) != len(jails) {
|
||||
errChan <- errors.New("incorrect number of results")
|
||||
return
|
||||
}
|
||||
errChan <- nil
|
||||
}(ip)
|
||||
}
|
||||
|
||||
// Check all operations completed successfully
|
||||
for i := 0; i < len(ips); i++ {
|
||||
if err := <-errChan; err != nil {
|
||||
t.Errorf("Concurrent operation %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewParallelOperationProcessor(t *testing.T) {
|
||||
// Test processor creation with various worker counts
|
||||
tests := []struct {
|
||||
name string
|
||||
workerCount int
|
||||
expectCPU bool
|
||||
}{
|
||||
{"positive worker count", 4, false},
|
||||
{"zero worker count uses CPU count", 0, true},
|
||||
{"negative worker count uses CPU count", -1, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
processor := NewParallelOperationProcessor(tt.workerCount)
|
||||
|
||||
if processor == nil {
|
||||
t.Fatal("NewParallelOperationProcessor returned nil")
|
||||
}
|
||||
|
||||
if tt.expectCPU {
|
||||
// Should use CPU count when invalid worker count provided
|
||||
if processor.workerCount <= 0 {
|
||||
t.Error("Worker count should be positive when using CPU count")
|
||||
}
|
||||
} else {
|
||||
if processor.workerCount != tt.workerCount {
|
||||
t.Errorf("Expected worker count %d, got %d", tt.workerCount, processor.workerCount)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
784
cmd/cmd_root_test.go
Normal file
784
cmd/cmd_root_test.go
Normal file
@@ -0,0 +1,784 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
func TestParseLogLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
level string
|
||||
expected logrus.Level
|
||||
}{
|
||||
{
|
||||
name: "debug level",
|
||||
level: "debug",
|
||||
expected: logrus.DebugLevel,
|
||||
},
|
||||
{
|
||||
name: "info level",
|
||||
level: "info",
|
||||
expected: logrus.InfoLevel,
|
||||
},
|
||||
{
|
||||
name: "warn level",
|
||||
level: "warn",
|
||||
expected: logrus.WarnLevel,
|
||||
},
|
||||
{
|
||||
name: "warning level",
|
||||
level: "warning",
|
||||
expected: logrus.WarnLevel,
|
||||
},
|
||||
{
|
||||
name: "error level",
|
||||
level: "error",
|
||||
expected: logrus.ErrorLevel,
|
||||
},
|
||||
{
|
||||
name: "fatal level",
|
||||
level: "fatal",
|
||||
expected: logrus.FatalLevel,
|
||||
},
|
||||
{
|
||||
name: "panic level",
|
||||
level: "panic",
|
||||
expected: logrus.PanicLevel,
|
||||
},
|
||||
{
|
||||
name: "unknown level defaults to info",
|
||||
level: "unknown",
|
||||
expected: logrus.InfoLevel,
|
||||
},
|
||||
{
|
||||
name: "empty level defaults to info",
|
||||
level: "",
|
||||
expected: logrus.InfoLevel,
|
||||
},
|
||||
{
|
||||
name: "uppercase level",
|
||||
level: "DEBUG",
|
||||
expected: logrus.InfoLevel, // case sensitive, so falls back to default
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseLogLevel(tt.level)
|
||||
if result != tt.expected {
|
||||
t.Errorf("parseLogLevel(%q) = %v, want %v", tt.level, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigDefaults(t *testing.T) {
|
||||
// Test that Config struct has reasonable defaults
|
||||
config := Config{}
|
||||
|
||||
// Initially empty
|
||||
if config.LogDir != "" {
|
||||
t.Errorf("expected empty LogDir, got %q", config.LogDir)
|
||||
}
|
||||
if config.FilterDir != "" {
|
||||
t.Errorf("expected empty FilterDir, got %q", config.FilterDir)
|
||||
}
|
||||
if config.Format != "" {
|
||||
t.Errorf("expected empty Format, got %q", config.Format)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvironmentVariableSetup(t *testing.T) {
|
||||
// Save original environment
|
||||
// Set up environment variables using t.Setenv for automatic cleanup
|
||||
t.Setenv("F2B_LOG_DIR", os.Getenv("F2B_LOG_DIR"))
|
||||
t.Setenv("F2B_FILTER_DIR", os.Getenv("F2B_FILTER_DIR"))
|
||||
t.Setenv("F2B_LOG_LEVEL", os.Getenv("F2B_LOG_LEVEL"))
|
||||
t.Setenv("F2B_LOG_FILE", os.Getenv("F2B_LOG_FILE"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
envVar string
|
||||
envValue string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "F2B_LOG_DIR environment variable",
|
||||
envVar: "F2B_LOG_DIR",
|
||||
envValue: "/custom/log/dir",
|
||||
expected: "/custom/log/dir",
|
||||
},
|
||||
{
|
||||
name: "F2B_FILTER_DIR environment variable",
|
||||
envVar: "F2B_FILTER_DIR",
|
||||
envValue: "/custom/filter/dir",
|
||||
expected: "/custom/filter/dir",
|
||||
},
|
||||
{
|
||||
name: "F2B_LOG_LEVEL environment variable",
|
||||
envVar: "F2B_LOG_LEVEL",
|
||||
envValue: "debug",
|
||||
expected: "debug",
|
||||
},
|
||||
{
|
||||
name: "F2B_LOG_FILE environment variable",
|
||||
envVar: "F2B_LOG_FILE",
|
||||
envValue: "/tmp/f2b.log",
|
||||
expected: "/tmp/f2b.log",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set environment variable using t.Setenv for automatic cleanup
|
||||
t.Setenv(tt.envVar, tt.envValue)
|
||||
|
||||
// Get the value
|
||||
result := os.Getenv(tt.envVar)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigStructure(t *testing.T) {
|
||||
config := Config{
|
||||
LogDir: "/test/log",
|
||||
FilterDir: "/test/filter",
|
||||
Format: "json",
|
||||
}
|
||||
|
||||
if config.LogDir != "/test/log" {
|
||||
t.Errorf("expected LogDir '/test/log', got %q", config.LogDir)
|
||||
}
|
||||
if config.FilterDir != "/test/filter" {
|
||||
t.Errorf("expected FilterDir '/test/filter', got %q", config.FilterDir)
|
||||
}
|
||||
if config.Format != "json" {
|
||||
t.Errorf("expected Format 'json', got %q", config.Format)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionCmdStructure(t *testing.T) {
|
||||
cmd := completionCmd()
|
||||
|
||||
if cmd.Use != "completion [bash|zsh|fish|powershell]" {
|
||||
t.Errorf("unexpected completion command Use: %q", cmd.Use)
|
||||
}
|
||||
|
||||
if cmd.Short != "Generate shell completion scripts" {
|
||||
t.Errorf("unexpected completion command Short: %q", cmd.Short)
|
||||
}
|
||||
|
||||
expectedValidArgs := []string{"bash", "zsh", "fish", "powershell"}
|
||||
if len(cmd.ValidArgs) != len(expectedValidArgs) {
|
||||
t.Errorf("expected %d ValidArgs, got %d", len(expectedValidArgs), len(cmd.ValidArgs))
|
||||
}
|
||||
|
||||
for i, expected := range expectedValidArgs {
|
||||
if i >= len(cmd.ValidArgs) || cmd.ValidArgs[i] != expected {
|
||||
t.Errorf("expected ValidArgs[%d] = %q, got %q", i, expected, cmd.ValidArgs[i])
|
||||
}
|
||||
}
|
||||
|
||||
if !cmd.DisableFlagsInUseLine {
|
||||
t.Errorf("expected DisableFlagsInUseLine to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalVariables(t *testing.T) {
|
||||
// Test that global variables are properly initialized
|
||||
if rootCmd == nil {
|
||||
t.Fatal("rootCmd should be initialized")
|
||||
}
|
||||
|
||||
if rootCmd.Use != "f2b" {
|
||||
t.Errorf("expected rootCmd.Use to be 'f2b', got %q", rootCmd.Use)
|
||||
}
|
||||
|
||||
if rootCmd.Short != "Fail2Ban CLI helper" {
|
||||
t.Errorf("expected rootCmd.Short to be 'Fail2Ban CLI helper', got %q", rootCmd.Short)
|
||||
}
|
||||
|
||||
expectedLong := "Fail2Ban CLI tool implemented in Go using Cobra."
|
||||
if rootCmd.Long != expectedLong {
|
||||
t.Errorf("expected rootCmd.Long to be %q, got %q", expectedLong, rootCmd.Long)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkParseLogLevel benchmarks the log level parsing function
|
||||
func BenchmarkParseLogLevel(b *testing.B) {
|
||||
levels := []string{"debug", "info", "warn", "error", "unknown"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
level := levels[i%len(levels)]
|
||||
parseLogLevel(level)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultValues tests the default values used in the configuration
|
||||
func TestDefaultValues(t *testing.T) {
|
||||
// Clear environment variables for this test using t.Setenv
|
||||
t.Setenv("F2B_LOG_DIR", "")
|
||||
t.Setenv("F2B_FILTER_DIR", "")
|
||||
|
||||
// Test default values when environment variables are not set
|
||||
logDir := os.Getenv("F2B_LOG_DIR")
|
||||
if logDir != "" {
|
||||
t.Errorf("expected empty F2B_LOG_DIR, got %q", logDir)
|
||||
}
|
||||
|
||||
filterDir := os.Getenv("F2B_FILTER_DIR")
|
||||
if filterDir != "" {
|
||||
t.Errorf("expected empty F2B_FILTER_DIR, got %q", filterDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupClient func() fail2ban.Client
|
||||
config Config
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "successful execution with mock client",
|
||||
setupClient: func() fail2ban.Client {
|
||||
return fail2ban.NewMockClient()
|
||||
},
|
||||
config: Config{
|
||||
LogDir: "/tmp/test",
|
||||
FilterDir: "/tmp/filters",
|
||||
Format: "plain",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "execution with json format",
|
||||
setupClient: func() fail2ban.Client {
|
||||
return fail2ban.NewMockClient()
|
||||
},
|
||||
config: Config{
|
||||
LogDir: "/var/log",
|
||||
FilterDir: "/etc/fail2ban/filter.d",
|
||||
Format: "json",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client := tt.setupClient()
|
||||
|
||||
// Capture stdout to prevent output during tests
|
||||
oldStdout := os.Stdout
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create pipe: %v", err)
|
||||
}
|
||||
os.Stdout = w
|
||||
|
||||
// Set up a simple test command that will exit quickly
|
||||
originalArgs := os.Args
|
||||
os.Args = []string{"f2b", "version"}
|
||||
|
||||
err = Execute(client, tt.config)
|
||||
|
||||
// Restore stdout
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("failed to close writer: %v", err)
|
||||
}
|
||||
os.Stdout = oldStdout
|
||||
os.Args = originalArgs
|
||||
|
||||
// Read and discard output
|
||||
var buf bytes.Buffer
|
||||
if _, err := buf.ReadFrom(r); err != nil {
|
||||
t.Fatalf("failed to read output: %v", err)
|
||||
}
|
||||
|
||||
AssertError(t, err, tt.wantError, tt.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteWithRealCommands(t *testing.T) {
|
||||
// Test that Execute properly adds all commands
|
||||
client := fail2ban.NewMockClient()
|
||||
config := Config{
|
||||
LogDir: "/tmp",
|
||||
FilterDir: "/tmp",
|
||||
Format: "plain",
|
||||
}
|
||||
|
||||
// Create a new root command to test command registration
|
||||
originalRootCmd := rootCmd
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "f2b",
|
||||
Short: "Fail2Ban CLI helper",
|
||||
Long: "Fail2Ban CLI tool implemented in Go using Cobra.",
|
||||
}
|
||||
|
||||
// Capture stdout
|
||||
oldStdout := os.Stdout
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create pipe: %v", err)
|
||||
}
|
||||
os.Stdout = w
|
||||
|
||||
originalArgs := os.Args
|
||||
os.Args = []string{"f2b", "help"}
|
||||
|
||||
err = Execute(client, config)
|
||||
|
||||
// Restore
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("failed to close writer: %v", err)
|
||||
}
|
||||
os.Stdout = oldStdout
|
||||
os.Args = originalArgs
|
||||
rootCmd = originalRootCmd
|
||||
|
||||
// Read output
|
||||
var buf bytes.Buffer
|
||||
if _, err := buf.ReadFrom(r); err != nil {
|
||||
t.Fatalf("failed to read output: %v", err)
|
||||
}
|
||||
output := buf.String()
|
||||
|
||||
AssertError(t, err, false, "root help command")
|
||||
|
||||
// Check that help output contains expected commands
|
||||
expectedCommands := []string{
|
||||
"list-jails",
|
||||
"status",
|
||||
"banned",
|
||||
"ban",
|
||||
"unban",
|
||||
"test",
|
||||
"logs",
|
||||
"logs-watch",
|
||||
"service",
|
||||
"version",
|
||||
"test-filter",
|
||||
"completion",
|
||||
}
|
||||
for _, cmd := range expectedCommands {
|
||||
if !strings.Contains(output, cmd) {
|
||||
t.Errorf("expected help output to contain command %q", cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionCmdExecution(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
wantOutput string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "bash completion",
|
||||
args: []string{"bash"},
|
||||
wantOutput: "__start_f2b",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "zsh completion",
|
||||
args: []string{"zsh"},
|
||||
wantOutput: "#compdef f2b",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "fish completion",
|
||||
args: []string{"fish"},
|
||||
wantOutput: "complete -c f2b",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "powershell completion",
|
||||
args: []string{"powershell"},
|
||||
wantOutput: "Register-ArgumentCompleter",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "unsupported shell",
|
||||
args: []string{"unsupported"},
|
||||
wantError: true, // Cobra returns an error for invalid args due to OnlyValidArgs
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Framework doesn't support completion cmd yet, so keeping manual approach:
|
||||
// Create a proper root command structure for the test
|
||||
testRoot := &cobra.Command{
|
||||
Use: "f2b",
|
||||
Short: "Test root command",
|
||||
}
|
||||
|
||||
// Add mock client for commands that need it
|
||||
mockClient := NewMockClient()
|
||||
testConfig := Config{Format: "plain"}
|
||||
|
||||
// Add all the f2b subcommands to create a realistic structure
|
||||
testRoot.AddCommand(ListJailsCmd(mockClient, &testConfig))
|
||||
testRoot.AddCommand(StatusCmd(mockClient, &testConfig))
|
||||
testRoot.AddCommand(BannedCmd(mockClient, &testConfig))
|
||||
testRoot.AddCommand(BanCmd(mockClient, &testConfig))
|
||||
testRoot.AddCommand(UnbanCmd(mockClient, &testConfig))
|
||||
testRoot.AddCommand(TestIPCmd(mockClient, &testConfig))
|
||||
testRoot.AddCommand(LogsCmd(mockClient, &testConfig))
|
||||
testRoot.AddCommand(LogsWatchCmd(context.Background(), mockClient, &testConfig))
|
||||
testRoot.AddCommand(ServiceCmd(&testConfig))
|
||||
testRoot.AddCommand(VersionCmd(&testConfig))
|
||||
testRoot.AddCommand(TestFilterCmd(mockClient, &testConfig))
|
||||
testRoot.AddCommand(completionCmd())
|
||||
|
||||
// Execute the completion command via the root
|
||||
// Capture stdout
|
||||
var outBuf bytes.Buffer
|
||||
testRoot.SetOut(&outBuf)
|
||||
|
||||
// Capture stderr
|
||||
var errBuf bytes.Buffer
|
||||
testRoot.SetErr(&errBuf)
|
||||
|
||||
args := append([]string{"completion"}, tt.args...)
|
||||
testRoot.SetArgs(args)
|
||||
err := testRoot.Execute()
|
||||
|
||||
AssertError(t, err, tt.wantError, tt.name)
|
||||
|
||||
output := outBuf.String() + errBuf.String()
|
||||
if tt.wantOutput != "" && !tt.wantError {
|
||||
// Check for substring anywhere in the output, ignoring leading/trailing whitespace
|
||||
if !strings.Contains(output, tt.wantOutput) {
|
||||
t.Errorf("expected output to contain %q, got %q", tt.wantOutput, strings.TrimSpace(output))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitFunctionCoverage(t *testing.T) {
|
||||
// Test that init function sets up flags correctly
|
||||
// We can't directly test init() but we can test its effects
|
||||
|
||||
// Test that persistent flags are set
|
||||
if rootCmd.PersistentFlags().Lookup("log-dir") == nil {
|
||||
t.Errorf("expected log-dir persistent flag to be set")
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Lookup("filter-dir") == nil {
|
||||
t.Errorf("expected filter-dir persistent flag to be set")
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Lookup("format") == nil {
|
||||
t.Errorf("expected format persistent flag to be set")
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Lookup("log-file") == nil {
|
||||
t.Errorf("expected log-file persistent flag to be set")
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Lookup("log-level") == nil {
|
||||
t.Errorf("expected log-level persistent flag to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistentPreRun(t *testing.T) {
|
||||
// Test the PersistentPreRun function
|
||||
if rootCmd.PersistentPreRun == nil {
|
||||
t.Errorf("expected PersistentPreRun to be set")
|
||||
return
|
||||
}
|
||||
|
||||
// Create a temporary log file
|
||||
tmpFile, err := os.CreateTemp(t.TempDir(), "f2b-test-*.log")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp file: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Remove(tmpFile.Name()); err != nil {
|
||||
t.Fatalf("failed to remove temp file: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
t.Fatalf("failed to close temp file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test with log file flag
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().String("log-file", tmpFile.Name(), "test log file")
|
||||
cmd.Flags().String("log-level", "debug", "test log level")
|
||||
|
||||
// Save original logger output
|
||||
originalOutput := Logger.Out
|
||||
|
||||
// Run PersistentPreRun
|
||||
rootCmd.PersistentPreRun(cmd, []string{})
|
||||
|
||||
// Restore original logger output
|
||||
Logger.SetOutput(originalOutput)
|
||||
|
||||
// Test log level parsing
|
||||
tests := []struct {
|
||||
name string
|
||||
logLevel string
|
||||
}{
|
||||
{"debug", "debug"},
|
||||
{"info", "info"},
|
||||
{"warn", "warn"},
|
||||
{"error", "error"},
|
||||
{"invalid", "invalid"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("log_level_"+tt.name, func(_ *testing.T) {
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().String("log-file", "", "")
|
||||
cmd.Flags().String("log-level", tt.logLevel, "")
|
||||
|
||||
// This should not panic
|
||||
rootCmd.PersistentPreRun(cmd, []string{})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistentPreRunWithInvalidLogFile(t *testing.T) {
|
||||
// Test PersistentPreRun with invalid log file path
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().String("log-file", "/invalid/path/to/logfile.log", "invalid log file")
|
||||
cmd.Flags().String("log-level", "info", "test log level")
|
||||
|
||||
// Capture stderr to check for error message
|
||||
oldStderr := os.Stderr
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create pipe: %v", err)
|
||||
}
|
||||
os.Stderr = w
|
||||
|
||||
// This should handle the error gracefully
|
||||
rootCmd.PersistentPreRun(cmd, []string{})
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("failed to close writer: %v", err)
|
||||
}
|
||||
os.Stderr = oldStderr
|
||||
|
||||
var buf bytes.Buffer
|
||||
if _, err := buf.ReadFrom(r); err != nil {
|
||||
t.Fatalf("failed to read output: %v", err)
|
||||
}
|
||||
output := buf.String()
|
||||
|
||||
// Should contain error message about failed to open log file
|
||||
if !strings.Contains(output, "Failed to open log file") {
|
||||
t.Errorf("expected error message about failed to open log file, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionCmdLongDescription(t *testing.T) {
|
||||
cmd := completionCmd()
|
||||
|
||||
// Test that the long description contains instructions for all shells
|
||||
expectedShells := []string{"Bash:", "Zsh:", "Fish:", "PowerShell:"}
|
||||
for _, shell := range expectedShells {
|
||||
if !strings.Contains(cmd.Long, shell) {
|
||||
t.Errorf("expected completion long description to contain %q", shell)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that it contains example commands
|
||||
expectedExamples := []string{
|
||||
"f2b completion bash",
|
||||
"f2b completion zsh",
|
||||
"f2b completion fish",
|
||||
"f2b completion powershell",
|
||||
}
|
||||
for _, example := range expectedExamples {
|
||||
if !strings.Contains(cmd.Long, example) {
|
||||
t.Errorf("expected completion long description to contain example %q", example)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalConfigVariable(t *testing.T) {
|
||||
// Test that global cfg variable can be accessed and modified
|
||||
originalCfg := cfg
|
||||
defer func() { cfg = originalCfg }()
|
||||
|
||||
cfg = Config{
|
||||
LogDir: "/test/log",
|
||||
FilterDir: "/test/filter",
|
||||
Format: "json",
|
||||
}
|
||||
|
||||
if cfg.LogDir != "/test/log" {
|
||||
t.Errorf("expected LogDir to be '/test/log', got %q", cfg.LogDir)
|
||||
}
|
||||
if cfg.FilterDir != "/test/filter" {
|
||||
t.Errorf("expected FilterDir to be '/test/filter', got %q", cfg.FilterDir)
|
||||
}
|
||||
if cfg.Format != "json" {
|
||||
t.Errorf("expected Format to be 'json', got %q", cfg.Format)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExecuteIntegration tests the Execute function with different command combinations
|
||||
func TestExecuteIntegration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
config Config
|
||||
setupEnv func()
|
||||
cleanup func()
|
||||
}{
|
||||
{
|
||||
name: "execute with environment variables",
|
||||
args: []string{"f2b", "version"},
|
||||
config: Config{
|
||||
LogDir: "/tmp/test",
|
||||
FilterDir: "/tmp/filters",
|
||||
Format: "plain",
|
||||
},
|
||||
setupEnv: func() {
|
||||
// Environment variables will be set using t.Setenv in test loop
|
||||
},
|
||||
cleanup: func() {
|
||||
// Cleanup handled automatically by t.Setenv
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Integration test requires manual approach:
|
||||
// Set up environment variables using t.Setenv for automatic cleanup
|
||||
if tt.config.LogDir != "" {
|
||||
t.Setenv("F2B_LOG_DIR", tt.config.LogDir)
|
||||
}
|
||||
if tt.config.FilterDir != "" {
|
||||
t.Setenv("F2B_FILTER_DIR", tt.config.FilterDir)
|
||||
}
|
||||
|
||||
client := fail2ban.NewMockClient()
|
||||
|
||||
// Capture output
|
||||
oldStdout := os.Stdout
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create pipe: %v", err)
|
||||
}
|
||||
os.Stdout = w
|
||||
|
||||
originalArgs := os.Args
|
||||
os.Args = tt.args
|
||||
|
||||
err = Execute(client, tt.config)
|
||||
|
||||
// Restore
|
||||
if closeErr := w.Close(); closeErr != nil {
|
||||
t.Fatalf("failed to close writer: %v", closeErr)
|
||||
}
|
||||
os.Stdout = oldStdout
|
||||
os.Args = originalArgs
|
||||
|
||||
// Read output
|
||||
var buf bytes.Buffer
|
||||
if _, readErr := buf.ReadFrom(r); readErr != nil {
|
||||
t.Fatalf("failed to read output: %v", readErr)
|
||||
}
|
||||
|
||||
AssertError(t, err, false, tt.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionCmdWithUnsupportedShell(t *testing.T) {
|
||||
cmd := completionCmd()
|
||||
|
||||
// Capture stderr to check for error message
|
||||
var errBuf bytes.Buffer
|
||||
cmd.SetErr(&errBuf)
|
||||
|
||||
cmd.SetArgs([]string{"invalid-shell"})
|
||||
err := cmd.Execute()
|
||||
|
||||
// Should return error due to Cobra's OnlyValidArgs validation
|
||||
if err == nil {
|
||||
t.Errorf("expected error for invalid shell type")
|
||||
}
|
||||
|
||||
// Error should mention invalid argument
|
||||
if !strings.Contains(err.Error(), "invalid argument") && !strings.Contains(err.Error(), "invalid") {
|
||||
t.Errorf("expected error message about invalid argument, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkParseLogLevelExtended(b *testing.B) {
|
||||
levels := []string{"debug", "info", "warn", "warning", "error", "fatal", "panic", "invalid", ""}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
level := levels[i%len(levels)]
|
||||
parseLogLevel(level)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExecute(b *testing.B) {
|
||||
client := fail2ban.NewMockClient()
|
||||
config := Config{
|
||||
LogDir: "/tmp",
|
||||
FilterDir: "/tmp",
|
||||
Format: "plain",
|
||||
}
|
||||
|
||||
// Suppress output
|
||||
oldStdout := os.Stdout
|
||||
devNull, err := os.Open(os.DevNull)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to open dev null: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if cerr := devNull.Close(); cerr != nil {
|
||||
b.Fatalf("failed to close dev null: %v", cerr)
|
||||
}
|
||||
}()
|
||||
os.Stdout = devNull
|
||||
|
||||
defer func() {
|
||||
os.Stdout = oldStdout
|
||||
}()
|
||||
|
||||
originalArgs := os.Args
|
||||
defer func() {
|
||||
os.Args = originalArgs
|
||||
}()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
os.Args = []string{"f2b", "version"}
|
||||
if err := Execute(client, config); err != nil {
|
||||
b.Fatalf("execute failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
293
cmd/cmd_service_test.go
Normal file
293
cmd/cmd_service_test.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
func TestServiceCmd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
mockResponse string
|
||||
mockError error
|
||||
wantOutput string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "service status",
|
||||
args: []string{"status"},
|
||||
mockResponse: "fail2ban is running",
|
||||
wantOutput: "fail2ban is running",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "service start",
|
||||
args: []string{"start"},
|
||||
mockResponse: "Starting fail2ban service",
|
||||
wantOutput: "Starting fail2ban service",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "service stop",
|
||||
args: []string{"stop"},
|
||||
mockResponse: "Stopping fail2ban service",
|
||||
wantOutput: "Stopping fail2ban service",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "service restart",
|
||||
args: []string{"restart"},
|
||||
mockResponse: "Restarting fail2ban service",
|
||||
wantOutput: "Restarting fail2ban service",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "no action provided",
|
||||
args: []string{},
|
||||
wantError: true, // Command should return error for missing action
|
||||
},
|
||||
{
|
||||
name: "invalid action",
|
||||
args: []string{"invalid"},
|
||||
wantError: true, // Command should return error for invalid action
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
builder := NewCommandTest(t, "service").
|
||||
WithArgs(tt.args...).
|
||||
WithServiceSetup(func(mock *fail2ban.MockRunner) {
|
||||
if tt.mockResponse != "" {
|
||||
command := "sudo service fail2ban " + strings.Join(tt.args, " ")
|
||||
mock.SetResponse(command, []byte(tt.mockResponse))
|
||||
}
|
||||
if tt.mockError != nil {
|
||||
command := "sudo service fail2ban " + strings.Join(tt.args, " ")
|
||||
mock.SetError(command, tt.mockError)
|
||||
}
|
||||
})
|
||||
|
||||
if tt.wantError {
|
||||
builder.ExpectError()
|
||||
} else {
|
||||
builder.ExpectSuccess()
|
||||
}
|
||||
|
||||
if tt.wantOutput != "" {
|
||||
builder.ExpectOutput(tt.wantOutput)
|
||||
}
|
||||
|
||||
builder.Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceCmdWithJSONFormat(t *testing.T) {
|
||||
NewCommandTest(t, "service").
|
||||
WithArgs("status").
|
||||
WithJSONFormat().
|
||||
WithServiceSetup(func(mock *fail2ban.MockRunner) {
|
||||
mock.SetResponse("sudo service fail2ban status", []byte("fail2ban is running"))
|
||||
}).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("fail2ban is running").
|
||||
Run()
|
||||
}
|
||||
|
||||
func TestServiceCmdErrorHandling(t *testing.T) {
|
||||
NewCommandTest(t, "service").
|
||||
WithArgs("status").
|
||||
WithServiceSetup(func(mock *fail2ban.MockRunner) {
|
||||
mock.SetError("sudo service fail2ban status", &testServiceError{"service failed"})
|
||||
}).
|
||||
ExpectError().
|
||||
Run()
|
||||
}
|
||||
|
||||
func TestServiceCmdValidActions(t *testing.T) {
|
||||
validActions := []string{"start", "stop", "restart", "status", "reload", "enable", "disable"}
|
||||
|
||||
for _, action := range validActions {
|
||||
t.Run("action_"+action, func(t *testing.T) {
|
||||
wantOutput := "Action " + action + " completed"
|
||||
NewCommandTest(t, "service").
|
||||
WithArgs(action).
|
||||
WithServiceSetup(func(mock *fail2ban.MockRunner) {
|
||||
command := "sudo service fail2ban " + action
|
||||
mock.SetResponse(command, []byte(wantOutput))
|
||||
}).
|
||||
ExpectSuccess().
|
||||
ExpectOutput(wantOutput).
|
||||
Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceCmdMultipleArgs(t *testing.T) {
|
||||
// Test that service command only uses first arg:
|
||||
NewCommandTest(t, "service").
|
||||
WithArgs("start", "extra").
|
||||
WithServiceSetup(func(mock *fail2ban.MockRunner) {
|
||||
mock.SetResponse("sudo service fail2ban start", []byte("Starting fail2ban service"))
|
||||
}).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("Starting fail2ban service").
|
||||
Run()
|
||||
}
|
||||
|
||||
func TestServiceCmdEmptyResponse(t *testing.T) {
|
||||
// Test that empty response is handled gracefully:
|
||||
NewCommandTest(t, "service").
|
||||
WithArgs("status").
|
||||
WithServiceSetup(func(mock *fail2ban.MockRunner) {
|
||||
mock.SetResponse("sudo service fail2ban status", []byte(""))
|
||||
}).
|
||||
ExpectSuccess().
|
||||
Run()
|
||||
}
|
||||
|
||||
// testServiceError implements error interface for testing
|
||||
type testServiceError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *testServiceError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
// TestServiceCmdSecurityValidation tests that command injection attempts are blocked
|
||||
func TestServiceCmdSecurityValidation(t *testing.T) {
|
||||
maliciousActions := []string{
|
||||
"start; touch /tmp/test",
|
||||
"status && whoami",
|
||||
"restart | curl example.com",
|
||||
"stop`whoami`",
|
||||
"status$(id)",
|
||||
"start'||'curl example.com",
|
||||
"../../../etc/passwd",
|
||||
"start\ntouch /tmp/test",
|
||||
"status\techo test",
|
||||
"reload;curl example.com",
|
||||
"enable & echo test",
|
||||
"disable || curl example.com",
|
||||
}
|
||||
|
||||
for _, maliciousAction := range maliciousActions {
|
||||
t.Run("malicious_action", func(t *testing.T) {
|
||||
// Test that malicious actions are rejected:
|
||||
result := NewCommandTest(t, "service").
|
||||
WithArgs(maliciousAction).
|
||||
WithServiceSetup(func(_ *fail2ban.MockRunner) {
|
||||
// No responses needed - command should be rejected before execution
|
||||
}).
|
||||
ExpectError(). // Command should return error for malicious actions
|
||||
Run()
|
||||
|
||||
// Verify error message is present in output
|
||||
if !strings.Contains(result.Output, "invalid service action") {
|
||||
t.Errorf(
|
||||
"expected error message for malicious action %q, got output: %q",
|
||||
maliciousAction,
|
||||
result.Output,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceCmdValidActionsOnly ensures only valid actions are accepted
|
||||
func TestServiceCmdValidActionsOnly(t *testing.T) {
|
||||
validActions := []string{"start", "stop", "restart", "status", "reload", "enable", "disable"}
|
||||
invalidActions := []string{"invalid", "badaction", "test", "debug", "config", "init"}
|
||||
|
||||
// Test valid actions
|
||||
for _, action := range validActions {
|
||||
t.Run("valid_action_"+action, func(t *testing.T) {
|
||||
wantOutput := "Action " + action + " completed"
|
||||
NewCommandTest(t, "service").
|
||||
WithArgs(action).
|
||||
WithServiceSetup(func(mock *fail2ban.MockRunner) {
|
||||
command := "sudo service fail2ban " + action
|
||||
mock.SetResponse(command, []byte(wantOutput))
|
||||
}).
|
||||
ExpectSuccess().
|
||||
ExpectOutput(wantOutput).
|
||||
Run()
|
||||
})
|
||||
}
|
||||
|
||||
// Test invalid actions
|
||||
for _, action := range invalidActions {
|
||||
t.Run("invalid_action_"+action, func(t *testing.T) {
|
||||
result := NewCommandTest(t, "service").
|
||||
WithArgs(action).
|
||||
WithServiceSetup(func(_ *fail2ban.MockRunner) {
|
||||
// No responses needed for invalid actions
|
||||
}).
|
||||
ExpectError(). // Command should return error for invalid actions
|
||||
Run()
|
||||
|
||||
// Verify error message is present
|
||||
if !strings.Contains(result.Output, "invalid service action") {
|
||||
t.Errorf("invalid action %q should show error message, got output: %q", action, result.Output)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkServiceCmd benchmarks the service command execution
|
||||
func BenchmarkServiceCmd(b *testing.B) {
|
||||
// Set up mock environment once
|
||||
_, cleanup := fail2ban.SetupMockEnvironment(b)
|
||||
defer cleanup()
|
||||
|
||||
// Get the mock runner and configure it
|
||||
mock := fail2ban.GetRunner().(*fail2ban.MockRunner)
|
||||
mock.SetResponse("sudo service fail2ban status", []byte("fail2ban is running"))
|
||||
|
||||
// Framework could be used here but benchmark needs manual approach for performance:
|
||||
config := &Config{Format: "plain"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cmd := ServiceCmd(config)
|
||||
oldStdout := os.Stdout
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
b.Fatalf("failed to create pipe: %v", err)
|
||||
}
|
||||
|
||||
os.Stdout = w
|
||||
|
||||
cmd.SetArgs([]string{"status"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
_ = w.Close()
|
||||
_ = r.Close()
|
||||
os.Stdout = oldStdout
|
||||
b.Fatalf("execute failed: %v", err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
_ = r.Close()
|
||||
os.Stdout = oldStdout
|
||||
b.Fatalf("failed to close writer: %v", err)
|
||||
}
|
||||
|
||||
var stdoutBuf bytes.Buffer
|
||||
if _, err := stdoutBuf.ReadFrom(r); err != nil {
|
||||
_ = r.Close()
|
||||
os.Stdout = oldStdout
|
||||
b.Fatalf("failed to read output: %v", err)
|
||||
}
|
||||
|
||||
// Clean up at end of iteration
|
||||
_ = r.Close()
|
||||
os.Stdout = oldStdout
|
||||
}
|
||||
}
|
||||
584
cmd/command_test_framework.go
Normal file
584
cmd/command_test_framework.go
Normal file
@@ -0,0 +1,584 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// CommandTestResult represents the result of a command execution
|
||||
type CommandTestResult struct {
|
||||
Output string
|
||||
Error error
|
||||
t *testing.T
|
||||
name string
|
||||
}
|
||||
|
||||
// CommandTestBuilder provides a fluent interface for testing commands
|
||||
type CommandTestBuilder struct {
|
||||
t *testing.T
|
||||
name string
|
||||
command string
|
||||
args []string
|
||||
mockClient *fail2ban.MockClient
|
||||
config *Config
|
||||
expectError bool
|
||||
expectedOut string
|
||||
exactMatch bool
|
||||
setupFunc func(*fail2ban.MockClient)
|
||||
environment *TestEnvironment
|
||||
}
|
||||
|
||||
// TestEnvironment manages test environment setup and cleanup
|
||||
type TestEnvironment struct {
|
||||
originalChecker fail2ban.SudoChecker
|
||||
originalRunner fail2ban.Runner
|
||||
originalStdout *os.File
|
||||
stdoutReader *os.File
|
||||
stdoutWriter *os.File
|
||||
cleanup []func()
|
||||
}
|
||||
|
||||
// NewTestEnvironment creates a new test environment manager
|
||||
func NewTestEnvironment() *TestEnvironment {
|
||||
return &TestEnvironment{
|
||||
cleanup: make([]func(), 0),
|
||||
}
|
||||
}
|
||||
|
||||
// WithPrivileges sets up sudo checker with specified privileges
|
||||
func (env *TestEnvironment) WithPrivileges(hasPrivileges bool) *TestEnvironment {
|
||||
env.originalChecker = fail2ban.GetSudoChecker()
|
||||
mockChecker := &fail2ban.MockSudoChecker{
|
||||
MockHasPrivileges: hasPrivileges,
|
||||
ExplicitPrivilegesSet: true,
|
||||
}
|
||||
fail2ban.SetSudoChecker(mockChecker)
|
||||
env.cleanup = append(env.cleanup, func() {
|
||||
fail2ban.SetSudoChecker(env.originalChecker)
|
||||
})
|
||||
return env
|
||||
}
|
||||
|
||||
// WithMockRunner sets up a mock runner with common responses
|
||||
func (env *TestEnvironment) WithMockRunner() *TestEnvironment {
|
||||
env.originalRunner = fail2ban.GetRunner()
|
||||
mockRunner := fail2ban.NewMockRunner()
|
||||
// Set up common responses
|
||||
mockRunner.SetResponse("fail2ban-client -V", []byte("fail2ban-client v0.11.2"))
|
||||
mockRunner.SetResponse("fail2ban-client ping", []byte("pong"))
|
||||
mockRunner.SetResponse(
|
||||
"fail2ban-client status",
|
||||
[]byte("Status\n|- Number of jail:\t2\n`- Jail list:\tsshd, apache"),
|
||||
)
|
||||
mockRunner.SetResponse("sudo service fail2ban status", []byte("● fail2ban.service - Fail2Ban Service"))
|
||||
fail2ban.SetRunner(mockRunner)
|
||||
|
||||
env.cleanup = append(env.cleanup, func() {
|
||||
fail2ban.SetRunner(env.originalRunner)
|
||||
})
|
||||
return env
|
||||
}
|
||||
|
||||
// WithStdoutCapture captures stdout for testing output
|
||||
func (env *TestEnvironment) WithStdoutCapture() *TestEnvironment {
|
||||
env.originalStdout = os.Stdout
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
// Return early with nil fields to indicate failure
|
||||
return env
|
||||
}
|
||||
env.stdoutReader = r
|
||||
env.stdoutWriter = w
|
||||
os.Stdout = w
|
||||
|
||||
env.cleanup = append(env.cleanup, func() {
|
||||
os.Stdout = env.originalStdout
|
||||
if env.stdoutWriter != nil {
|
||||
_ = env.stdoutWriter.Close()
|
||||
}
|
||||
if env.stdoutReader != nil {
|
||||
_ = env.stdoutReader.Close()
|
||||
}
|
||||
})
|
||||
return env
|
||||
}
|
||||
|
||||
// Cleanup restores the original environment
|
||||
func (env *TestEnvironment) Cleanup() {
|
||||
for i := len(env.cleanup) - 1; i >= 0; i-- {
|
||||
env.cleanup[i]()
|
||||
}
|
||||
}
|
||||
|
||||
// ReadStdout reads the captured stdout content
|
||||
func (env *TestEnvironment) ReadStdout() string {
|
||||
if env.stdoutWriter == nil || env.stdoutReader == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Close writer if not already closed
|
||||
if env.stdoutWriter != nil {
|
||||
_ = env.stdoutWriter.Close()
|
||||
env.stdoutWriter = nil // Prevent multiple closures
|
||||
}
|
||||
|
||||
// Use io.ReadAll for dynamic buffer reading
|
||||
if data, err := io.ReadAll(env.stdoutReader); err == nil {
|
||||
return string(data)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// NewCommandTest creates a new command test builder
|
||||
func NewCommandTest(t *testing.T, commandName string) *CommandTestBuilder {
|
||||
t.Helper()
|
||||
return &CommandTestBuilder{
|
||||
t: t,
|
||||
name: commandName,
|
||||
command: commandName,
|
||||
args: make([]string, 0),
|
||||
config: &Config{Format: "plain"},
|
||||
}
|
||||
}
|
||||
|
||||
// WithName sets the test name for better error reporting
|
||||
func (ctb *CommandTestBuilder) WithName(name string) *CommandTestBuilder {
|
||||
ctb.name = name
|
||||
return ctb
|
||||
}
|
||||
|
||||
// WithArgs sets the command arguments
|
||||
func (ctb *CommandTestBuilder) WithArgs(args ...string) *CommandTestBuilder {
|
||||
ctb.args = args
|
||||
return ctb
|
||||
}
|
||||
|
||||
// WithMockClient sets the mock client for the test
|
||||
func (ctb *CommandTestBuilder) WithMockClient(mock *fail2ban.MockClient) *CommandTestBuilder {
|
||||
ctb.mockClient = mock
|
||||
return ctb
|
||||
}
|
||||
|
||||
// WithJSONFormat sets the output format to JSON
|
||||
func (ctb *CommandTestBuilder) WithJSONFormat() *CommandTestBuilder {
|
||||
if ctb.config == nil {
|
||||
ctb.config = &Config{}
|
||||
}
|
||||
ctb.config.Format = JSONFormat
|
||||
return ctb
|
||||
}
|
||||
|
||||
// WithSetup provides a function to set up the mock client with specific data
|
||||
func (ctb *CommandTestBuilder) WithSetup(setupFunc func(*fail2ban.MockClient)) *CommandTestBuilder {
|
||||
ctb.setupFunc = setupFunc
|
||||
return ctb
|
||||
}
|
||||
|
||||
// WithServiceSetup provides a function to set up mock runner for service commands
|
||||
func (ctb *CommandTestBuilder) WithServiceSetup(setupFunc func(*fail2ban.MockRunner)) *CommandTestBuilder {
|
||||
ctb.setupFunc = func(_ *fail2ban.MockClient) {
|
||||
// Set up sudo checker
|
||||
mockChecker := &fail2ban.MockSudoChecker{
|
||||
MockHasPrivileges: true,
|
||||
ExplicitPrivilegesSet: true,
|
||||
}
|
||||
fail2ban.SetSudoChecker(mockChecker)
|
||||
|
||||
// Create and set up mock runner
|
||||
mockRunner := &fail2ban.MockRunner{
|
||||
Responses: make(map[string][]byte),
|
||||
Errors: make(map[string]error),
|
||||
}
|
||||
setupFunc(mockRunner)
|
||||
fail2ban.SetRunner(mockRunner)
|
||||
}
|
||||
return ctb
|
||||
}
|
||||
|
||||
// WithEnvironment sets the test environment
|
||||
func (ctb *CommandTestBuilder) WithEnvironment(env *TestEnvironment) *CommandTestBuilder {
|
||||
ctb.environment = env
|
||||
return ctb
|
||||
}
|
||||
|
||||
// ExpectError indicates that the command should fail
|
||||
func (ctb *CommandTestBuilder) ExpectError() *CommandTestBuilder {
|
||||
ctb.expectError = true
|
||||
return ctb
|
||||
}
|
||||
|
||||
// ExpectSuccess indicates that the command should succeed
|
||||
func (ctb *CommandTestBuilder) ExpectSuccess() *CommandTestBuilder {
|
||||
ctb.expectError = false
|
||||
return ctb
|
||||
}
|
||||
|
||||
// ExpectOutput sets the expected output substring
|
||||
func (ctb *CommandTestBuilder) ExpectOutput(expectedOut string) *CommandTestBuilder {
|
||||
ctb.expectedOut = expectedOut
|
||||
return ctb
|
||||
}
|
||||
|
||||
// ExpectExactOutput sets the expected output for exact matching
|
||||
func (ctb *CommandTestBuilder) ExpectExactOutput(expectedOut string) *CommandTestBuilder {
|
||||
ctb.expectedOut = expectedOut
|
||||
ctb.exactMatch = true
|
||||
return ctb
|
||||
}
|
||||
|
||||
// Run executes the command test and performs all validations
|
||||
func (ctb *CommandTestBuilder) Run() *CommandTestResult {
|
||||
ctb.t.Helper()
|
||||
|
||||
// Set up default mock client if none provided
|
||||
if ctb.mockClient == nil {
|
||||
ctb.mockClient = fail2ban.NewMockClient()
|
||||
}
|
||||
|
||||
// Apply setup function if provided
|
||||
if ctb.setupFunc != nil {
|
||||
ctb.setupFunc(ctb.mockClient)
|
||||
}
|
||||
|
||||
// Execute the command
|
||||
output, err := ctb.executeCommand()
|
||||
|
||||
// Create result
|
||||
result := &CommandTestResult{
|
||||
Output: output,
|
||||
Error: err,
|
||||
t: ctb.t,
|
||||
name: ctb.name,
|
||||
}
|
||||
|
||||
// Perform basic validations
|
||||
result.AssertError(ctb.expectError)
|
||||
|
||||
if ctb.expectedOut != "" {
|
||||
if ctb.exactMatch {
|
||||
result.AssertExactOutput(ctb.expectedOut)
|
||||
} else {
|
||||
result.AssertContains(ctb.expectedOut)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// executeCommand runs the actual command with the configured parameters
|
||||
func (ctb *CommandTestBuilder) executeCommand() (string, error) {
|
||||
var cmd *cobra.Command
|
||||
|
||||
switch ctb.command {
|
||||
case "ban":
|
||||
cmd = BanCmd(ctb.mockClient, ctb.config)
|
||||
case "unban":
|
||||
cmd = UnbanCmd(ctb.mockClient, ctb.config)
|
||||
case "status":
|
||||
cmd = StatusCmd(ctb.mockClient, ctb.config)
|
||||
case "list-jails":
|
||||
cmd = ListJailsCmd(ctb.mockClient, ctb.config)
|
||||
case "banned":
|
||||
cmd = BannedCmd(ctb.mockClient, ctb.config)
|
||||
case "test":
|
||||
cmd = TestIPCmd(ctb.mockClient, ctb.config)
|
||||
case "logs":
|
||||
cmd = LogsCmd(ctb.mockClient, ctb.config)
|
||||
case "service":
|
||||
cmd = ServiceCmd(ctb.config)
|
||||
case "version":
|
||||
cmd = VersionCmd(ctb.config)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown command: %s", ctb.command)
|
||||
}
|
||||
|
||||
// For service commands, we need to capture os.Stdout since PrintOutput writes directly to it
|
||||
if ctb.command == "service" {
|
||||
return ctb.executeServiceCommand(cmd)
|
||||
}
|
||||
|
||||
// Execute regular commands
|
||||
var outBuf, errBuf bytes.Buffer
|
||||
cmd.SetOut(&outBuf)
|
||||
cmd.SetErr(&errBuf)
|
||||
cmd.SetArgs(ctb.args)
|
||||
err := cmd.Execute()
|
||||
output := outBuf.String() + errBuf.String()
|
||||
|
||||
return output, err
|
||||
}
|
||||
|
||||
// executeServiceCommand handles service command execution with stdout/stderr capture
|
||||
func (ctb *CommandTestBuilder) executeServiceCommand(cmd *cobra.Command) (string, error) {
|
||||
// Capture os.Stdout since service command uses PrintOutput
|
||||
oldStdout := os.Stdout
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create stdout pipe: %w", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
|
||||
// Also capture os.Stderr since PrintError uses it
|
||||
oldStderr := os.Stderr
|
||||
stderrR, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
// Clean up stdout pipe before returning error
|
||||
_ = stdoutR.Close()
|
||||
_ = stdoutW.Close()
|
||||
os.Stdout = oldStdout
|
||||
return "", fmt.Errorf("failed to create stderr pipe: %w", err)
|
||||
}
|
||||
os.Stderr = stderrW
|
||||
|
||||
var cmdErrBuf bytes.Buffer
|
||||
cmd.SetErr(&cmdErrBuf)
|
||||
cmd.SetArgs(ctb.args)
|
||||
err = cmd.Execute()
|
||||
|
||||
// Close writers and restore
|
||||
if closeErr := stdoutW.Close(); closeErr != nil {
|
||||
os.Stdout = oldStdout
|
||||
os.Stderr = oldStderr
|
||||
return "", fmt.Errorf("failed to close stdout writer: %v", closeErr)
|
||||
}
|
||||
if closeErr := stderrW.Close(); closeErr != nil {
|
||||
os.Stdout = oldStdout
|
||||
os.Stderr = oldStderr
|
||||
return "", fmt.Errorf("failed to close stderr writer: %v", closeErr)
|
||||
}
|
||||
os.Stdout = oldStdout
|
||||
os.Stderr = oldStderr
|
||||
|
||||
// Read captured output
|
||||
var stdoutBuf bytes.Buffer
|
||||
if _, readErr := stdoutBuf.ReadFrom(stdoutR); readErr != nil {
|
||||
return "", fmt.Errorf("failed to read stdout: %v", readErr)
|
||||
}
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
if _, readErr := stderrBuf.ReadFrom(stderrR); readErr != nil {
|
||||
return "", fmt.Errorf("failed to read stderr: %v", readErr)
|
||||
}
|
||||
|
||||
output := stdoutBuf.String() + stderrBuf.String() + cmdErrBuf.String()
|
||||
return output, err
|
||||
}
|
||||
|
||||
// AssertError validates the error state
|
||||
func (result *CommandTestResult) AssertError(expectError bool) *CommandTestResult {
|
||||
result.t.Helper()
|
||||
if expectError && result.Error == nil {
|
||||
result.t.Fatalf("%s: expected error but got none", result.name)
|
||||
}
|
||||
if !expectError && result.Error != nil {
|
||||
result.t.Fatalf("%s: unexpected error: %v, output: %s", result.name, result.Error, result.Output)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AssertContains validates that output contains expected text
|
||||
func (result *CommandTestResult) AssertContains(expected string) *CommandTestResult {
|
||||
result.t.Helper()
|
||||
if !strings.Contains(result.Output, expected) {
|
||||
result.t.Fatalf("%s: expected output to contain %q, got: %s", result.name, expected, result.Output)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AssertNotContains validates that output does not contain specified text
|
||||
func (result *CommandTestResult) AssertNotContains(notExpected string) *CommandTestResult {
|
||||
result.t.Helper()
|
||||
if strings.Contains(result.Output, notExpected) {
|
||||
result.t.Fatalf("%s: expected output to not contain %q, got: %s", result.name, notExpected, result.Output)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AssertExactOutput validates exact output match
|
||||
func (result *CommandTestResult) AssertExactOutput(expected string) *CommandTestResult {
|
||||
result.t.Helper()
|
||||
if result.Output != expected {
|
||||
result.t.Fatalf("%s: expected exact output %q, got %q", result.name, expected, result.Output)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AssertJSONField validates a specific field in JSON output
|
||||
func (result *CommandTestResult) AssertJSONField(fieldPath, expected string) *CommandTestResult {
|
||||
result.t.Helper()
|
||||
|
||||
var data interface{}
|
||||
if err := json.Unmarshal([]byte(result.Output), &data); err != nil {
|
||||
result.t.Fatalf("%s: failed to parse JSON output: %v, output: %s", result.name, err, result.Output)
|
||||
}
|
||||
|
||||
// Simple field path parsing (can be enhanced later)
|
||||
// For now, support simple paths like "$.field", "[0].field" or direct field names
|
||||
fieldName := strings.TrimPrefix(fieldPath, "$.")
|
||||
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
if val, ok := v[fieldName]; ok {
|
||||
if fmt.Sprintf("%v", val) != expected {
|
||||
result.t.Fatalf("%s: expected JSON field %q to be %q, got %v", result.name, fieldName, expected, val)
|
||||
}
|
||||
} else {
|
||||
result.t.Fatalf("%s: JSON field %q not found in output: %s", result.name, fieldName, result.Output)
|
||||
}
|
||||
case []interface{}:
|
||||
// Handle array case - look in first element
|
||||
if len(v) > 0 {
|
||||
if firstItem, ok := v[0].(map[string]interface{}); ok {
|
||||
if val, ok := firstItem[fieldName]; ok {
|
||||
if fmt.Sprintf("%v", val) != expected {
|
||||
result.t.Fatalf("%s: expected JSON field %q to be %q, got %v", result.name, fieldName, expected, val)
|
||||
}
|
||||
} else {
|
||||
result.t.Fatalf("%s: JSON field %q not found in first array element: %s", result.name, fieldName, result.Output)
|
||||
}
|
||||
} else {
|
||||
result.t.Fatalf("%s: first array element is not an object in output: %s", result.name, result.Output)
|
||||
}
|
||||
} else {
|
||||
result.t.Fatalf("%s: JSON array is empty in output: %s", result.name, result.Output)
|
||||
}
|
||||
default:
|
||||
result.t.Fatalf("%s: expected JSON object or array but got %T in output: %s", result.name, data, result.Output)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// AssertEmpty validates that output is empty
|
||||
func (result *CommandTestResult) AssertEmpty() *CommandTestResult {
|
||||
result.t.Helper()
|
||||
if strings.TrimSpace(result.Output) != "" {
|
||||
result.t.Fatalf("%s: expected empty output, got: %s", result.name, result.Output)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AssertNotEmpty validates that output is not empty
|
||||
func (result *CommandTestResult) AssertNotEmpty() *CommandTestResult {
|
||||
result.t.Helper()
|
||||
if strings.TrimSpace(result.Output) == "" {
|
||||
result.t.Fatalf("%s: expected non-empty output", result.name)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MockClientBuilder provides a fluent interface for building complex mock configurations
|
||||
type MockClientBuilder struct {
|
||||
client *fail2ban.MockClient
|
||||
jails []string
|
||||
banRecords []fail2ban.BanRecord
|
||||
logLines []string
|
||||
responses map[string]string
|
||||
errors map[string]error
|
||||
}
|
||||
|
||||
// NewMockClientBuilder creates a new mock client builder
|
||||
func NewMockClientBuilder() *MockClientBuilder {
|
||||
return &MockClientBuilder{
|
||||
client: fail2ban.NewMockClient(),
|
||||
responses: make(map[string]string),
|
||||
errors: make(map[string]error),
|
||||
}
|
||||
}
|
||||
|
||||
// WithJails configures available jails
|
||||
func (b *MockClientBuilder) WithJails(jails ...string) *MockClientBuilder {
|
||||
b.jails = append(b.jails, jails...)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithBannedIP adds a banned IP to specific jail
|
||||
func (b *MockClientBuilder) WithBannedIP(ip, jail string) *MockClientBuilder {
|
||||
if b.client.BanResults == nil {
|
||||
b.client.BanResults = make(map[string]map[string]int)
|
||||
}
|
||||
if b.client.BanResults[ip] == nil {
|
||||
b.client.BanResults[ip] = make(map[string]int)
|
||||
}
|
||||
b.client.BanResults[ip][jail] = 1 // 1 indicates banned
|
||||
return b
|
||||
}
|
||||
|
||||
// WithBanRecord adds a ban record
|
||||
func (b *MockClientBuilder) WithBanRecord(jail, ip, remaining string) *MockClientBuilder {
|
||||
b.banRecords = append(b.banRecords, fail2ban.BanRecord{
|
||||
Jail: jail,
|
||||
IP: ip,
|
||||
Remaining: remaining,
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
// WithLogLine adds a log line
|
||||
func (b *MockClientBuilder) WithLogLine(logLine string) *MockClientBuilder {
|
||||
b.logLines = append(b.logLines, logLine)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithStatusResponse sets status response for specific target
|
||||
func (b *MockClientBuilder) WithStatusResponse(target, response string) *MockClientBuilder {
|
||||
if b.client.StatusJailData == nil {
|
||||
b.client.StatusJailData = make(map[string]string)
|
||||
}
|
||||
if target == "all" {
|
||||
b.client.StatusAllData = response
|
||||
} else {
|
||||
b.client.StatusJailData[target] = response
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// WithBanError sets an error for banning specific IP in jail
|
||||
func (b *MockClientBuilder) WithBanError(jail, ip string, err error) *MockClientBuilder {
|
||||
b.client.SetBanError(jail, ip, err)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithUnbanError sets an error for unbanning specific IP in jail
|
||||
func (b *MockClientBuilder) WithUnbanError(jail, ip string, err error) *MockClientBuilder {
|
||||
b.client.SetUnbanError(jail, ip, err)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithLogError is not supported by MockClient - logs are returned via LogLines field
|
||||
// Use WithLogLine to add log entries or modify LogLines directly
|
||||
|
||||
// Build creates the configured mock client
|
||||
func (b *MockClientBuilder) Build() *fail2ban.MockClient {
|
||||
// Apply jails
|
||||
if len(b.jails) > 0 {
|
||||
setMockJails(b.client, b.jails)
|
||||
}
|
||||
|
||||
// Apply ban records
|
||||
if len(b.banRecords) > 0 {
|
||||
b.client.BanRecords = b.banRecords
|
||||
}
|
||||
|
||||
// Apply log lines
|
||||
if len(b.logLines) > 0 {
|
||||
b.client.LogLines = b.logLines
|
||||
}
|
||||
|
||||
return b.client
|
||||
}
|
||||
|
||||
// WithMockBuilder configures the test with a MockClientBuilder for advanced mock setup
|
||||
func (ctb *CommandTestBuilder) WithMockBuilder(builder *MockClientBuilder) *CommandTestBuilder {
|
||||
ctb.mockClient = builder.Build()
|
||||
return ctb
|
||||
}
|
||||
129
cmd/command_test_framework_demo_test.go
Normal file
129
cmd/command_test_framework_demo_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// TestDemoCommandTestFramework demonstrates the modern testing framework capabilities
|
||||
func TestDemoCommandTestFramework(t *testing.T) {
|
||||
// Simple command test with fluent interface
|
||||
t.Run("basic_command_example", func(t *testing.T) {
|
||||
NewCommandTest(t, "status").
|
||||
WithArgs("all").
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
mock.StatusAllData = "Status for all jails"
|
||||
setMockJails(mock, []string{"sshd", "apache"})
|
||||
}).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("Status for all jails").
|
||||
Run()
|
||||
})
|
||||
|
||||
// Advanced example with environment setup
|
||||
t.Run("advanced_environment_example", func(t *testing.T) {
|
||||
env := NewTestEnvironment().
|
||||
WithPrivileges(true).
|
||||
WithMockRunner()
|
||||
defer env.Cleanup()
|
||||
|
||||
NewCommandTest(t, "ban").
|
||||
WithArgs("192.168.1.100", "sshd").
|
||||
WithEnvironment(env).
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd"})
|
||||
}).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("Banned 192.168.1.100 in sshd").
|
||||
Run()
|
||||
})
|
||||
|
||||
// JSON output validation example
|
||||
t.Run("json_output_example", func(t *testing.T) {
|
||||
NewCommandTest(t, "banned").
|
||||
WithArgs("sshd").
|
||||
WithJSONFormat().
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd"})
|
||||
// Pre-ban an IP for testing
|
||||
_, _ = mock.BanIP("192.168.1.100", "sshd")
|
||||
}).
|
||||
ExpectSuccess().
|
||||
Run().
|
||||
AssertJSONField("Jail", "sshd")
|
||||
})
|
||||
|
||||
// Error handling example
|
||||
t.Run("error_handling_example", func(t *testing.T) {
|
||||
NewCommandTest(t, "ban").
|
||||
WithArgs("192.168.1.100", "nonexistent").
|
||||
ExpectError().
|
||||
Run().
|
||||
AssertContains("not found")
|
||||
})
|
||||
}
|
||||
|
||||
// TestDemoTableDrivenWithFramework shows how to use the framework with table-driven tests
|
||||
func TestDemoTableDrivenWithFramework(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
args []string
|
||||
setup func(*fail2ban.MockClient)
|
||||
expectError bool
|
||||
expectedOut string
|
||||
}{
|
||||
{
|
||||
name: "list jails success",
|
||||
command: "list-jails",
|
||||
setup: func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd", "apache"})
|
||||
},
|
||||
expectError: false,
|
||||
expectedOut: "apache",
|
||||
},
|
||||
{
|
||||
name: "status specific jail",
|
||||
command: "status",
|
||||
args: []string{"sshd"},
|
||||
setup: func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd"})
|
||||
mock.StatusJailData = map[string]string{"sshd": "Status for sshd"}
|
||||
},
|
||||
expectError: false,
|
||||
expectedOut: "Status for sshd",
|
||||
},
|
||||
{
|
||||
name: "ban invalid jail",
|
||||
command: "ban",
|
||||
args: []string{"192.168.1.100", "nonexistent"},
|
||||
expectError: true,
|
||||
expectedOut: "not found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Single line test execution with framework
|
||||
builder := NewCommandTest(t, tt.command).
|
||||
WithArgs(tt.args...)
|
||||
|
||||
if tt.setup != nil {
|
||||
builder = builder.WithSetup(tt.setup)
|
||||
}
|
||||
|
||||
if tt.expectError {
|
||||
builder = builder.ExpectError()
|
||||
} else {
|
||||
builder = builder.ExpectSuccess()
|
||||
}
|
||||
|
||||
if tt.expectedOut != "" {
|
||||
builder = builder.ExpectOutput(tt.expectedOut)
|
||||
}
|
||||
|
||||
builder.Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
232
cmd/comprehensive_framework_test.go
Normal file
232
cmd/comprehensive_framework_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// TestComprehensiveFrameworkCapabilities demonstrates the full power of the test framework
|
||||
func TestComprehensiveFrameworkCapabilities(t *testing.T) {
|
||||
// Test 1: Basic command success testing
|
||||
t.Run("basic_success", func(t *testing.T) {
|
||||
NewCommandTest(t, "list-jails").
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd", "apache"})
|
||||
}).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("sshd").
|
||||
Run()
|
||||
})
|
||||
|
||||
// Test 2: Error handling and validation
|
||||
t.Run("error_handling", func(t *testing.T) {
|
||||
NewCommandTest(t, "ban").
|
||||
WithArgs("invalid-ip", "sshd").
|
||||
ExpectError().
|
||||
Run().
|
||||
AssertContains("invalid IP address")
|
||||
})
|
||||
|
||||
// Test 3: JSON output testing with field validation
|
||||
t.Run("json_validation", func(t *testing.T) {
|
||||
NewCommandTest(t, "banned").
|
||||
WithArgs("sshd").
|
||||
WithJSONFormat().
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd"})
|
||||
_, _ = mock.BanIP("192.168.1.100", "sshd")
|
||||
}).
|
||||
ExpectSuccess().
|
||||
Run().
|
||||
AssertJSONField("Jail", "sshd").
|
||||
AssertJSONField("IP", "192.168.1.100")
|
||||
})
|
||||
|
||||
// Test 4: Complex environment setup
|
||||
t.Run("environment_management", func(t *testing.T) {
|
||||
env := NewTestEnvironment().
|
||||
WithPrivileges(true).
|
||||
WithMockRunner().
|
||||
WithStdoutCapture()
|
||||
defer env.Cleanup()
|
||||
|
||||
NewCommandTest(t, "status").
|
||||
WithArgs("all").
|
||||
WithEnvironment(env).
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd", "apache"})
|
||||
mock.StatusAllData = "All systems operational"
|
||||
}).
|
||||
ExpectSuccess().
|
||||
Run().
|
||||
AssertContains("operational")
|
||||
})
|
||||
|
||||
// Test 5: Chained assertions
|
||||
t.Run("chained_assertions", func(t *testing.T) {
|
||||
result := NewCommandTest(t, "status").
|
||||
WithArgs("sshd").
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd"})
|
||||
mock.StatusJailData = map[string]string{
|
||||
"sshd": "Jail: sshd\nStatus: Active\nFiltered: 10 lines",
|
||||
}
|
||||
}).
|
||||
ExpectSuccess().
|
||||
Run()
|
||||
|
||||
// Multiple validations on the same result
|
||||
result.AssertContains("Jail: sshd").
|
||||
AssertContains("Status: Active").
|
||||
AssertContains("Filtered").
|
||||
AssertNotContains("Error").
|
||||
AssertNotEmpty()
|
||||
})
|
||||
|
||||
// Test 6: Table-driven testing with framework
|
||||
t.Run("table_driven", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
args []string
|
||||
expected string
|
||||
isError bool
|
||||
}{
|
||||
{"ban_success", "ban", []string{"192.168.1.100", "sshd"}, "Banned", false},
|
||||
{"unban_success", "unban", []string{"192.168.1.100", "sshd"}, "Unbanned", false},
|
||||
{"test_banned", "test", []string{"192.168.1.100"}, "is banned", false},
|
||||
{"invalid_jail", "ban", []string{"192.168.1.100", "invalid"}, "not found", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
builder := NewCommandTest(t, tt.command).
|
||||
WithArgs(tt.args...).
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd", "apache"})
|
||||
if tt.command == "unban" || tt.command == "test" {
|
||||
// Pre-ban IP for unban/test scenarios
|
||||
_, _ = mock.BanIP("192.168.1.100", "sshd")
|
||||
}
|
||||
})
|
||||
|
||||
if tt.isError {
|
||||
builder = builder.ExpectError()
|
||||
} else {
|
||||
builder = builder.ExpectSuccess()
|
||||
}
|
||||
|
||||
builder.ExpectOutput(tt.expected).Run()
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestFrameworkPerformance demonstrates framework efficiency
|
||||
func TestFrameworkPerformance(t *testing.T) {
|
||||
// Measure how concise framework tests can be
|
||||
tests := map[string][]string{
|
||||
"ban": {"192.168.1.100", "sshd"},
|
||||
"unban": {"192.168.1.100", "sshd"},
|
||||
"status": {"sshd"},
|
||||
"list-jails": {},
|
||||
"banned": {"sshd"},
|
||||
"test": {"192.168.1.100"},
|
||||
}
|
||||
|
||||
for cmd, args := range tests {
|
||||
t.Run("performance_"+cmd, func(t *testing.T) {
|
||||
// Single line test execution
|
||||
NewCommandTest(t, cmd).
|
||||
WithArgs(args...).
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd"})
|
||||
if cmd != "ban" && cmd != "list-jails" {
|
||||
_, _ = mock.BanIP("192.168.1.100", "sshd")
|
||||
}
|
||||
if cmd == "status" && len(args) > 0 {
|
||||
mock.StatusJailData = map[string]string{"sshd": "Status for sshd"}
|
||||
}
|
||||
}).
|
||||
ExpectSuccess().
|
||||
Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFrameworkEdgeCases tests framework robustness
|
||||
func TestFrameworkEdgeCases(t *testing.T) {
|
||||
// Test empty output validation
|
||||
t.Run("empty_output", func(t *testing.T) {
|
||||
// Test framework with empty jail list setup
|
||||
NewCommandTest(t, "list-jails").
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{}) // No jails
|
||||
}).
|
||||
ExpectSuccess().
|
||||
Run()
|
||||
})
|
||||
|
||||
// Test JSON parsing edge cases
|
||||
t.Run("json_array_handling", func(t *testing.T) {
|
||||
NewCommandTest(t, "banned").
|
||||
WithArgs("sshd").
|
||||
WithJSONFormat().
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd"})
|
||||
_, _ = mock.BanIP("192.168.1.100", "sshd")
|
||||
_, _ = mock.BanIP("192.168.1.101", "sshd")
|
||||
}).
|
||||
ExpectSuccess().
|
||||
Run().
|
||||
AssertJSONField("Jail", "sshd") // Should handle array and check first element
|
||||
})
|
||||
|
||||
// Test exact output matching
|
||||
t.Run("exact_matching", func(t *testing.T) {
|
||||
NewCommandTest(t, "status").
|
||||
WithArgs("sshd").
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd"})
|
||||
mock.StatusJailData = map[string]string{"sshd": "Exact status message"}
|
||||
}).
|
||||
ExpectSuccess().
|
||||
Run().
|
||||
AssertContains("Exact status message") // Use contains instead of exact for robustness
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkFrameworkOverhead measures performance impact
|
||||
func BenchmarkFrameworkOverhead(b *testing.B) {
|
||||
// Create a mock client once outside the loop
|
||||
mock := fail2ban.NewMockClient()
|
||||
setMockJails(mock, []string{"sshd"})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Benchmark just the core client operation without cobra command overhead
|
||||
_, err := mock.ListJails()
|
||||
if err != nil {
|
||||
b.Fatalf("Client operation failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestFrameworkCompatibility ensures framework works with existing helpers
|
||||
func TestFrameworkCompatibility(t *testing.T) {
|
||||
// Test that framework can work alongside existing test helpers
|
||||
t.Run("mixed_approach", func(t *testing.T) {
|
||||
// Manual mock setup when needed
|
||||
mock := NewMockClient()
|
||||
setMockJails(mock, []string{"sshd", "apache"})
|
||||
|
||||
// Use framework for execution and validation
|
||||
NewCommandTest(t, "list-jails").
|
||||
WithMockClient(mock).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("sshd").
|
||||
Run().
|
||||
AssertContains("apache")
|
||||
})
|
||||
}
|
||||
316
cmd/config_utils.go
Normal file
316
cmd/config_utils.go
Normal file
@@ -0,0 +1,316 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultCommandTimeout is the default timeout for individual fail2ban commands
|
||||
DefaultCommandTimeout = 30 * time.Second
|
||||
// DefaultFileTimeout is the default timeout for file operations
|
||||
DefaultFileTimeout = 10 * time.Second
|
||||
// DefaultParallelTimeout is the default timeout for parallel operations
|
||||
DefaultParallelTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
// containsPathTraversal performs comprehensive path traversal detection
|
||||
// including various encoding techniques and bypass attempts
|
||||
func containsPathTraversal(path string) bool {
|
||||
if path == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
variations := createPathVariations(path)
|
||||
return checkPathVariationsForTraversal(variations)
|
||||
}
|
||||
|
||||
// createPathVariations generates different encoded variations of the path to check
|
||||
func createPathVariations(path string) []string {
|
||||
variations := []string{path}
|
||||
|
||||
// URL decode the path (handle single and double encoding)
|
||||
if decoded, err := url.QueryUnescape(path); err == nil && decoded != path {
|
||||
variations = append(variations, decoded)
|
||||
// Check for double encoding
|
||||
if doubleDecoded, err := url.QueryUnescape(decoded); err == nil && doubleDecoded != decoded {
|
||||
variations = append(variations, doubleDecoded)
|
||||
}
|
||||
}
|
||||
|
||||
return variations
|
||||
}
|
||||
|
||||
// checkPathVariationsForTraversal checks all path variations against dangerous patterns
|
||||
func checkPathVariationsForTraversal(variations []string) bool {
|
||||
allPatterns := getAllDangerousPatterns()
|
||||
overlongRegex := regexp.MustCompile(
|
||||
`\xc0[\x80-\xbf]|\xe0[\x80-\x9f][\x80-\xbf]|\xf0[\x80-\x8f][\x80-\xbf][\x80-\xbf]`,
|
||||
)
|
||||
|
||||
for _, variant := range variations {
|
||||
if checkSingleVariantForTraversal(variant, allPatterns, overlongRegex) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getAllDangerousPatterns returns all dangerous path traversal patterns
|
||||
func getAllDangerousPatterns() map[string][]string {
|
||||
return map[string][]string{
|
||||
"basic": {
|
||||
"..", "../", "..\\", "..%2f", "..%2F", "..%5c", "..%5C",
|
||||
},
|
||||
"urlEncoded": {
|
||||
"%2e%2e", "%2E%2E", "%2e%2E", "%2E%2e",
|
||||
"%252e%252e", "%252E%252E", "%25252e%25252e",
|
||||
},
|
||||
"unicode": {
|
||||
"\\u002e\\u002e", "\\u00002e\\u00002e", "..",
|
||||
},
|
||||
"mixed": {
|
||||
"..%00", ".%2e", "%2e.", "...//", "..;/", "..%3b",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// checkSingleVariantForTraversal checks a single path variant against all patterns
|
||||
func checkSingleVariantForTraversal(variant string, patterns map[string][]string, overlongRegex *regexp.Regexp) bool {
|
||||
lowerVariant := strings.ToLower(variant)
|
||||
|
||||
// Check all pattern categories
|
||||
for _, patternList := range patterns {
|
||||
for _, pattern := range patternList {
|
||||
if containsPattern(variant, lowerVariant, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for UTF-8 overlong encodings
|
||||
if overlongRegex.MatchString(variant) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for null byte injection combined with path traversal
|
||||
if containsNullByteInjection(variant, lowerVariant) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for invalid UTF-8 sequences
|
||||
if !utf8.ValidString(variant) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// containsPattern checks if a variant contains a dangerous pattern
|
||||
func containsPattern(variant, lowerVariant, pattern string) bool {
|
||||
// For Unicode patterns, check both original and lowercase
|
||||
if strings.Contains(pattern, "\\u") || strings.Contains(pattern, "\\x") {
|
||||
return strings.Contains(variant, pattern) || strings.Contains(lowerVariant, strings.ToLower(pattern))
|
||||
}
|
||||
// For other patterns, use case-insensitive check
|
||||
return strings.Contains(lowerVariant, strings.ToLower(pattern))
|
||||
}
|
||||
|
||||
// containsNullByteInjection checks for null byte injection with path traversal
|
||||
func containsNullByteInjection(variant, lowerVariant string) bool {
|
||||
return strings.Contains(variant, "\x00") &&
|
||||
(strings.Contains(variant, "..") || strings.Contains(lowerVariant, "%2e"))
|
||||
}
|
||||
|
||||
// validateConfigPath validates directory paths from configuration
|
||||
func validateConfigPath(path, pathType string) (string, error) {
|
||||
if path == "" {
|
||||
return "", fmt.Errorf("%s path cannot be empty", pathType)
|
||||
}
|
||||
|
||||
// Comprehensive path traversal detection
|
||||
if containsPathTraversal(path) {
|
||||
return "", fmt.Errorf("%s path contains path traversal: %s", pathType, path)
|
||||
}
|
||||
|
||||
// Check for null bytes
|
||||
if strings.Contains(path, "\x00") {
|
||||
return "", fmt.Errorf("%s path contains null byte: %s", pathType, path)
|
||||
}
|
||||
|
||||
// Resolve to absolute path
|
||||
absPath, err := filepath.Abs(filepath.Clean(path))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid %s path: %w", pathType, err)
|
||||
}
|
||||
|
||||
// Check path length (reasonable limit)
|
||||
if len(absPath) > 4096 {
|
||||
return "", fmt.Errorf("%s path too long: %d characters", pathType, len(absPath))
|
||||
}
|
||||
|
||||
// Validate that it's a reasonable system path
|
||||
if !isReasonableSystemPath(absPath, pathType) {
|
||||
return "", fmt.Errorf("%s path not in expected system location: %s", pathType, absPath)
|
||||
}
|
||||
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
// isReasonableSystemPath checks if a path is in a reasonable system location
|
||||
func isReasonableSystemPath(path, pathType string) bool {
|
||||
// Allow common system directories based on path type
|
||||
var allowedPrefixes []string
|
||||
switch pathType {
|
||||
case "log":
|
||||
allowedPrefixes = fail2ban.GetLogAllowedPaths()
|
||||
case "filter":
|
||||
allowedPrefixes = fail2ban.GetFilterAllowedPaths()
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
for _, prefix := range allowedPrefixes {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// NewConfigFromEnv builds Config from environment variables with defaults and validation.
|
||||
func NewConfigFromEnv() Config {
|
||||
cfg := Config{}
|
||||
|
||||
// Get and validate log directory
|
||||
logDir := os.Getenv("F2B_LOG_DIR")
|
||||
if logDir == "" {
|
||||
logDir = "/var/log"
|
||||
}
|
||||
|
||||
validatedLogDir, err := validateConfigPath(logDir, "log")
|
||||
if err != nil {
|
||||
Logger.WithError(err).WithField("path", logDir).Error("Invalid log directory from environment")
|
||||
validatedLogDir = "/var/log" // Fallback to safe default
|
||||
}
|
||||
cfg.LogDir = validatedLogDir
|
||||
|
||||
// Get and validate filter directory
|
||||
filterDir := os.Getenv("F2B_FILTER_DIR")
|
||||
if filterDir == "" {
|
||||
filterDir = "/etc/fail2ban/filter.d"
|
||||
}
|
||||
|
||||
validatedFilterDir, err := validateConfigPath(filterDir, "filter")
|
||||
if err != nil {
|
||||
Logger.WithError(err).WithField("path", filterDir).Error("Invalid filter directory from environment")
|
||||
validatedFilterDir = "/etc/fail2ban/filter.d" // Fallback to safe default
|
||||
}
|
||||
cfg.FilterDir = validatedFilterDir
|
||||
|
||||
// Configure timeouts from environment variables
|
||||
cfg.CommandTimeout = parseTimeoutFromEnv("F2B_COMMAND_TIMEOUT", DefaultCommandTimeout)
|
||||
cfg.FileTimeout = parseTimeoutFromEnv("F2B_FILE_TIMEOUT", DefaultFileTimeout)
|
||||
cfg.ParallelTimeout = parseTimeoutFromEnv("F2B_PARALLEL_TIMEOUT", DefaultParallelTimeout)
|
||||
|
||||
cfg.Format = "plain"
|
||||
return cfg
|
||||
}
|
||||
|
||||
// parseTimeoutFromEnv parses timeout duration from environment variable with fallback
|
||||
func parseTimeoutFromEnv(envVar string, defaultTimeout time.Duration) time.Duration {
|
||||
envValue := os.Getenv(envVar)
|
||||
if envValue == "" {
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
// Try parsing as duration first (e.g., "30s", "1m30s")
|
||||
if duration, err := time.ParseDuration(envValue); err == nil {
|
||||
if duration <= 0 {
|
||||
Logger.WithField("env_var", envVar).WithField("value", envValue).
|
||||
Warn("Invalid timeout value, using default")
|
||||
return defaultTimeout
|
||||
}
|
||||
return duration
|
||||
}
|
||||
|
||||
// Try parsing as seconds (for backward compatibility)
|
||||
if seconds, err := strconv.Atoi(envValue); err == nil {
|
||||
if seconds <= 0 {
|
||||
Logger.WithField("env_var", envVar).WithField("value", envValue).
|
||||
Warn("Invalid timeout value, using default")
|
||||
return defaultTimeout
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
Logger.WithField("env_var", envVar).WithField("value", envValue).
|
||||
Warn("Failed to parse timeout value, using default")
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
// ValidateConfig performs comprehensive validation of the Config struct
|
||||
func (c *Config) ValidateConfig() error {
|
||||
var errors []string
|
||||
|
||||
// Validate LogDir
|
||||
if c.LogDir == "" {
|
||||
errors = append(errors, "log directory cannot be empty")
|
||||
} else if _, err := validateConfigPath(c.LogDir, "log"); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("invalid log directory: %v", err))
|
||||
}
|
||||
|
||||
// Validate FilterDir
|
||||
if c.FilterDir == "" {
|
||||
errors = append(errors, "filter directory cannot be empty")
|
||||
} else if _, err := validateConfigPath(c.FilterDir, "filter"); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("invalid filter directory: %v", err))
|
||||
}
|
||||
|
||||
// Validate Format
|
||||
validFormats := map[string]bool{"plain": true, "json": true}
|
||||
if !validFormats[c.Format] {
|
||||
errors = append(errors, fmt.Sprintf("invalid format '%s', must be 'plain' or 'json'", c.Format))
|
||||
}
|
||||
|
||||
// Validate Timeouts
|
||||
if c.CommandTimeout <= 0 {
|
||||
errors = append(errors, "command timeout must be positive")
|
||||
} else if c.CommandTimeout > fail2ban.MaxCommandTimeout {
|
||||
errors = append(errors, "command timeout too large (max 10 minutes)")
|
||||
}
|
||||
|
||||
if c.FileTimeout <= 0 {
|
||||
errors = append(errors, "file timeout must be positive")
|
||||
} else if c.FileTimeout > fail2ban.MaxFileTimeout {
|
||||
errors = append(errors, "file timeout too large (max 5 minutes)")
|
||||
}
|
||||
|
||||
if c.ParallelTimeout <= 0 {
|
||||
errors = append(errors, "parallel timeout must be positive")
|
||||
} else if c.ParallelTimeout > fail2ban.MaxParallelTimeout {
|
||||
errors = append(errors, "parallel timeout too large (max 30 minutes)")
|
||||
}
|
||||
|
||||
// Check timeout relationships
|
||||
if c.ParallelTimeout < c.CommandTimeout {
|
||||
errors = append(errors, "parallel timeout should be >= command timeout")
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("configuration validation failed: %s", strings.Join(errors, "; "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
45
cmd/filter.go
Normal file
45
cmd/filter.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// TestFilterCmd returns the test-filter command with injected client and config
|
||||
func TestFilterCmd(client fail2ban.Client, config *Config) *cobra.Command {
|
||||
return NewCommand(
|
||||
"test-filter <filter>",
|
||||
"Test a Fail2Ban filter",
|
||||
nil,
|
||||
func(cmd *cobra.Command, args []string) error {
|
||||
// Create timeout context for filter testing (use file timeout as it involves file operations)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), config.FileTimeout)
|
||||
defer cancel()
|
||||
|
||||
if len(args) < 1 {
|
||||
filters, err := client.ListFiltersWithContext(ctx)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
PrintOutputTo(GetCmdOutput(cmd), "Available filters: "+strings.Join(filters, ", "), config.Format)
|
||||
return HandleClientError(fail2ban.ErrFilterRequiredError)
|
||||
}
|
||||
|
||||
filterName := args[0]
|
||||
if err := RequireNonEmptyArgument(filterName, "filter name"); err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
out, err := client.TestFilterWithContext(ctx, filterName)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
PrintOutputTo(GetCmdOutput(cmd), out, config.Format)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
365
cmd/helpers.go
Normal file
365
cmd/helpers.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultPollingInterval is the default interval for polling operations
|
||||
DefaultPollingInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
// Command creation helpers
|
||||
|
||||
// NewCommand creates a new cobra command with standard setup
|
||||
func NewCommand(use, short string, aliases []string, runE func(*cobra.Command, []string) error) *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: use,
|
||||
Short: short,
|
||||
Aliases: aliases,
|
||||
RunE: runE,
|
||||
}
|
||||
}
|
||||
|
||||
// AddLogFlags adds common log-related flags to a command
|
||||
func AddLogFlags(cmd *cobra.Command) {
|
||||
cmd.Flags().IntP("limit", "n", 0, "Show only the last N log lines")
|
||||
}
|
||||
|
||||
// IsSkipCommand returns true if the command doesn't require a fail2ban client
|
||||
func IsSkipCommand(command string) bool {
|
||||
skipCommands := []string{
|
||||
"service",
|
||||
"version",
|
||||
"test-filter",
|
||||
"completion",
|
||||
"help",
|
||||
}
|
||||
|
||||
for _, skip := range skipCommands {
|
||||
if command == skip {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AddWatchFlags adds common watch-related flags to a command
|
||||
func AddWatchFlags(cmd *cobra.Command, interval *time.Duration) {
|
||||
cmd.Flags().DurationVarP(interval, "interval", "i", DefaultPollingInterval, "Polling interval")
|
||||
}
|
||||
|
||||
// Validation helpers
|
||||
|
||||
// ValidateIPArgument validates that an IP address is provided in args
|
||||
func ValidateIPArgument(args []string) (string, error) {
|
||||
if len(args) < 1 {
|
||||
return "", fmt.Errorf("IP address required")
|
||||
}
|
||||
ip := args[0]
|
||||
// Validate the IP address
|
||||
if err := fail2ban.CachedValidateIP(ip); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
// ValidateServiceAction validates that a service action is valid
|
||||
func ValidateServiceAction(action string) error {
|
||||
validActions := map[string]bool{
|
||||
"start": true,
|
||||
"stop": true,
|
||||
"restart": true,
|
||||
"status": true,
|
||||
"reload": true,
|
||||
"enable": true,
|
||||
"disable": true,
|
||||
}
|
||||
|
||||
if !validActions[action] {
|
||||
return fmt.Errorf(
|
||||
"invalid service action: %s. Valid actions: start, stop, restart, status, reload, enable, disable",
|
||||
action,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetJailsFromArgs gets jail list from arguments or client
|
||||
func GetJailsFromArgs(client fail2ban.Client, args []string, startIndex int) ([]string, error) {
|
||||
if len(args) > startIndex {
|
||||
return []string{strings.ToLower(args[startIndex])}, nil
|
||||
}
|
||||
|
||||
jails, err := client.ListJails()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jails, nil
|
||||
}
|
||||
|
||||
// GetJailsFromArgsWithContext gets jail list from arguments or client with timeout context
|
||||
func GetJailsFromArgsWithContext(
|
||||
ctx context.Context,
|
||||
client fail2ban.Client,
|
||||
args []string,
|
||||
startIndex int,
|
||||
) ([]string, error) {
|
||||
if len(args) > startIndex {
|
||||
return []string{strings.ToLower(args[startIndex])}, nil
|
||||
}
|
||||
|
||||
jails, err := client.ListJailsWithContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jails, nil
|
||||
}
|
||||
|
||||
// ParseOptionalArgs parses optional arguments up to a given count
|
||||
func ParseOptionalArgs(args []string, count int) []string {
|
||||
result := make([]string, count)
|
||||
for i := 0; i < count && i < len(args); i++ {
|
||||
result[i] = args[i]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Error handling helpers
|
||||
|
||||
// HandleClientError handles client errors with consistent formatting
|
||||
func HandleClientError(err error) error {
|
||||
if err != nil {
|
||||
PrintError(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Output helpers
|
||||
|
||||
// OutputResults outputs results in the specified format
|
||||
func OutputResults(cmd *cobra.Command, results interface{}, config *Config) {
|
||||
if config != nil && config.Format == JSONFormat {
|
||||
PrintOutputTo(GetCmdOutput(cmd), results, JSONFormat)
|
||||
} else {
|
||||
PrintOutputTo(GetCmdOutput(cmd), results, "plain")
|
||||
}
|
||||
}
|
||||
|
||||
// InterpretBanStatus interprets ban operation status codes
|
||||
func InterpretBanStatus(code int, operation string) string {
|
||||
switch operation {
|
||||
case "ban":
|
||||
if code == 1 {
|
||||
return "Already banned"
|
||||
}
|
||||
return "Banned"
|
||||
case "unban":
|
||||
if code == 1 {
|
||||
return "Already unbanned"
|
||||
}
|
||||
return "Unbanned"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Operation result types
|
||||
|
||||
// OperationResult represents the result of a jail operation
|
||||
type OperationResult struct {
|
||||
IP string `json:"ip"`
|
||||
Jail string `json:"jail"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// ProcessBanOperation processes ban operations across multiple jails
|
||||
func ProcessBanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) {
|
||||
results := make([]OperationResult, 0, len(jails))
|
||||
|
||||
for _, jail := range jails {
|
||||
code, err := client.BanIP(ip, jail)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
status := InterpretBanStatus(code, "ban")
|
||||
Logger.WithFields(map[string]interface{}{
|
||||
"ip": ip,
|
||||
"jail": jail,
|
||||
"status": status,
|
||||
}).Info("Ban result")
|
||||
|
||||
results = append(results, OperationResult{
|
||||
IP: ip,
|
||||
Jail: jail,
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// ProcessBanOperationWithContext processes ban operations across multiple jails with timeout context
|
||||
func ProcessBanOperationWithContext(
|
||||
ctx context.Context,
|
||||
client fail2ban.Client,
|
||||
ip string,
|
||||
jails []string,
|
||||
) ([]OperationResult, error) {
|
||||
logger := GetContextualLogger()
|
||||
results := make([]OperationResult, 0, len(jails))
|
||||
|
||||
for _, jail := range jails {
|
||||
// Add jail to context for this operation
|
||||
jailCtx := WithJail(ctx, jail)
|
||||
|
||||
// Time the ban operation
|
||||
start := time.Now()
|
||||
code, err := client.BanIPWithContext(jailCtx, ip, jail)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
// Log the failed operation with timing
|
||||
logger.LogBanOperation(jailCtx, "ban", ip, jail, false, duration)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
status := InterpretBanStatus(code, "ban")
|
||||
|
||||
// Log the successful operation with timing
|
||||
logger.LogBanOperation(jailCtx, "ban", ip, jail, true, duration)
|
||||
|
||||
Logger.WithFields(map[string]interface{}{
|
||||
"ip": ip,
|
||||
"jail": jail,
|
||||
"status": status,
|
||||
}).Info("Ban result")
|
||||
|
||||
results = append(results, OperationResult{
|
||||
IP: ip,
|
||||
Jail: jail,
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// ProcessUnbanOperation processes unban operations across multiple jails
|
||||
func ProcessUnbanOperation(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) {
|
||||
results := make([]OperationResult, 0, len(jails))
|
||||
|
||||
for _, jail := range jails {
|
||||
code, err := client.UnbanIP(ip, jail)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
status := InterpretBanStatus(code, "unban")
|
||||
Logger.WithFields(map[string]interface{}{
|
||||
"ip": ip,
|
||||
"jail": jail,
|
||||
"status": status,
|
||||
}).Info("Unban result")
|
||||
|
||||
results = append(results, OperationResult{
|
||||
IP: ip,
|
||||
Jail: jail,
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// ProcessUnbanOperationWithContext processes unban operations across multiple jails with timeout context
|
||||
func ProcessUnbanOperationWithContext(
|
||||
ctx context.Context,
|
||||
client fail2ban.Client,
|
||||
ip string,
|
||||
jails []string,
|
||||
) ([]OperationResult, error) {
|
||||
logger := GetContextualLogger()
|
||||
results := make([]OperationResult, 0, len(jails))
|
||||
|
||||
for _, jail := range jails {
|
||||
// Add jail to context for this operation
|
||||
jailCtx := WithJail(ctx, jail)
|
||||
|
||||
// Time the unban operation
|
||||
start := time.Now()
|
||||
code, err := client.UnbanIPWithContext(jailCtx, ip, jail)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
// Log the failed operation with timing
|
||||
logger.LogBanOperation(jailCtx, "unban", ip, jail, false, duration)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
status := InterpretBanStatus(code, "unban")
|
||||
|
||||
// Log the successful operation with timing
|
||||
logger.LogBanOperation(jailCtx, "unban", ip, jail, true, duration)
|
||||
|
||||
Logger.WithFields(map[string]interface{}{
|
||||
"ip": ip,
|
||||
"jail": jail,
|
||||
"status": status,
|
||||
}).Info("Unban result")
|
||||
|
||||
results = append(results, OperationResult{
|
||||
IP: ip,
|
||||
Jail: jail,
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Argument validation helpers
|
||||
|
||||
// RequireArguments checks that at least n arguments are provided
|
||||
func RequireArguments(args []string, n int, errorMsg string) error {
|
||||
if len(args) < n {
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequireNonEmptyArgument checks that an argument is not empty
|
||||
func RequireNonEmptyArgument(arg, name string) error {
|
||||
if strings.TrimSpace(arg) == "" {
|
||||
return fmt.Errorf("%s cannot be empty", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Status output helpers
|
||||
|
||||
// FormatBannedResult formats banned IP results for output
|
||||
func FormatBannedResult(ip string, jails []string) string {
|
||||
if len(jails) == 0 {
|
||||
return fmt.Sprintf("IP %s is not banned", ip)
|
||||
}
|
||||
return fmt.Sprintf("IP %s is banned in: %v", ip, jails)
|
||||
}
|
||||
|
||||
// FormatStatusResult formats status results for output
|
||||
func FormatStatusResult(jail, status string) string {
|
||||
if jail == "" {
|
||||
return status
|
||||
}
|
||||
return fmt.Sprintf("Status for %s:\n%s", jail, status)
|
||||
}
|
||||
34
cmd/listjails.go
Normal file
34
cmd/listjails.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// ListJailsCmd returns the list-jails command with injected client and config
|
||||
func ListJailsCmd(client fail2ban.Client, config *Config) *cobra.Command {
|
||||
return NewCommand(
|
||||
"list-jails",
|
||||
"List all jails",
|
||||
[]string{"ls-jails", "jails"},
|
||||
func(cmd *cobra.Command, _ []string) error {
|
||||
// Create timeout context for listing jails
|
||||
ctx, cancel := context.WithTimeout(context.Background(), config.CommandTimeout)
|
||||
defer cancel()
|
||||
|
||||
jails, err := client.ListJailsWithContext(ctx)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintln(GetCmdOutput(cmd), strings.Join(jails, " ")); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
215
cmd/logging.go
Normal file
215
cmd/logging.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ContextKey represents keys for context values
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
// RequestIDKey is the key for request ID in context
|
||||
RequestIDKey ContextKey = "request_id"
|
||||
// OperationKey is the key for operation name in context
|
||||
OperationKey ContextKey = "operation"
|
||||
// IPKey is the key for IP address in context
|
||||
IPKey ContextKey = "ip"
|
||||
// JailKey is the key for jail name in context
|
||||
JailKey ContextKey = "jail"
|
||||
// CommandKey is the key for command name in context
|
||||
CommandKey ContextKey = "command"
|
||||
)
|
||||
|
||||
// ContextualLogger provides structured logging with context propagation
|
||||
type ContextualLogger struct {
|
||||
*logrus.Logger
|
||||
defaultFields logrus.Fields
|
||||
}
|
||||
|
||||
// NewContextualLogger creates a new contextual logger using the centralized cmd.Logger
|
||||
func NewContextualLogger() *ContextualLogger {
|
||||
// Use cmd.Logger as the backend, but with JSON formatter for structured logging
|
||||
contextLogger := logrus.New()
|
||||
contextLogger.SetOutput(Logger.Out)
|
||||
contextLogger.SetLevel(Logger.GetLevel())
|
||||
contextLogger.SetFormatter(&logrus.JSONFormatter{
|
||||
TimestampFormat: time.RFC3339Nano,
|
||||
FieldMap: logrus.FieldMap{
|
||||
logrus.FieldKeyTime: "timestamp",
|
||||
logrus.FieldKeyLevel: "level",
|
||||
logrus.FieldKeyMsg: "message",
|
||||
},
|
||||
})
|
||||
|
||||
return &ContextualLogger{
|
||||
Logger: contextLogger,
|
||||
defaultFields: logrus.Fields{
|
||||
"service": "f2b",
|
||||
"version": getVersion(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Build-time variables set via ldflags
|
||||
var (
|
||||
version = "dev"
|
||||
// Additional build variables that may be used in the future
|
||||
_ = "unknown" // commit placeholder
|
||||
_ = "unknown" // date placeholder
|
||||
_ = "unknown" // builtBy placeholder
|
||||
)
|
||||
|
||||
// getVersion returns the version from build variables or default
|
||||
func getVersion() string {
|
||||
return version
|
||||
}
|
||||
|
||||
// WithContext creates a logger entry with context values
|
||||
func (cl *ContextualLogger) WithContext(ctx context.Context) *logrus.Entry {
|
||||
entry := cl.WithFields(cl.defaultFields)
|
||||
|
||||
// Extract context values and add as fields
|
||||
if requestID := ctx.Value(RequestIDKey); requestID != nil {
|
||||
entry = entry.WithField("request_id", requestID)
|
||||
}
|
||||
|
||||
if operation := ctx.Value(OperationKey); operation != nil {
|
||||
entry = entry.WithField("operation", operation)
|
||||
}
|
||||
|
||||
if ip := ctx.Value(IPKey); ip != nil {
|
||||
entry = entry.WithField("ip", ip)
|
||||
}
|
||||
|
||||
if jail := ctx.Value(JailKey); jail != nil {
|
||||
entry = entry.WithField("jail", jail)
|
||||
}
|
||||
|
||||
if command := ctx.Value(CommandKey); command != nil {
|
||||
entry = entry.WithField("command", command)
|
||||
}
|
||||
|
||||
return entry
|
||||
}
|
||||
|
||||
// WithOperation adds operation context and returns a new context
|
||||
func WithOperation(ctx context.Context, operation string) context.Context {
|
||||
return context.WithValue(ctx, OperationKey, operation)
|
||||
}
|
||||
|
||||
// WithIP adds IP context and returns a new context
|
||||
func WithIP(ctx context.Context, ip string) context.Context {
|
||||
return context.WithValue(ctx, IPKey, ip)
|
||||
}
|
||||
|
||||
// WithJail adds jail context and returns a new context
|
||||
func WithJail(ctx context.Context, jail string) context.Context {
|
||||
return context.WithValue(ctx, JailKey, jail)
|
||||
}
|
||||
|
||||
// WithCommand adds command context and returns a new context
|
||||
func WithCommand(ctx context.Context, command string) context.Context {
|
||||
return context.WithValue(ctx, CommandKey, command)
|
||||
}
|
||||
|
||||
// WithRequestID adds request ID context and returns a new context
|
||||
func WithRequestID(ctx context.Context, requestID string) context.Context {
|
||||
return context.WithValue(ctx, RequestIDKey, requestID)
|
||||
}
|
||||
|
||||
// LogOperation logs the start and end of an operation with timing and metrics
|
||||
func (cl *ContextualLogger) LogOperation(ctx context.Context, operation string, fn func() error) error {
|
||||
start := time.Now()
|
||||
ctx = WithOperation(ctx, operation)
|
||||
|
||||
// Get metrics instance
|
||||
metrics := GetGlobalMetrics()
|
||||
|
||||
cl.WithContext(ctx).WithField("duration", "start").Info("Operation started")
|
||||
|
||||
err := fn()
|
||||
duration := time.Since(start)
|
||||
|
||||
entry := cl.WithContext(ctx).WithField("duration_ms", duration.Milliseconds())
|
||||
|
||||
// Record metrics based on operation type
|
||||
success := err == nil
|
||||
if command := ctx.Value(CommandKey); command != nil {
|
||||
if cmdStr, ok := command.(string); ok {
|
||||
metrics.RecordCommandExecution(cmdStr, duration, success)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
entry.WithError(err).Error("Operation failed")
|
||||
} else {
|
||||
entry.Info("Operation completed")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// LogBanOperation logs ban/unban operations with structured context and metrics
|
||||
func (cl *ContextualLogger) LogBanOperation(
|
||||
ctx context.Context,
|
||||
operation, ip, jail string,
|
||||
success bool,
|
||||
duration time.Duration,
|
||||
) {
|
||||
ctx = WithOperation(ctx, operation)
|
||||
ctx = WithIP(ctx, ip)
|
||||
ctx = WithJail(ctx, jail)
|
||||
|
||||
// Record metrics
|
||||
metrics := GetGlobalMetrics()
|
||||
metrics.RecordBanOperation(operation, duration, success)
|
||||
|
||||
entry := cl.WithContext(ctx).WithFields(logrus.Fields{
|
||||
"success": success,
|
||||
"duration_ms": duration.Milliseconds(),
|
||||
})
|
||||
|
||||
if success {
|
||||
entry.Info("Ban operation completed")
|
||||
} else {
|
||||
entry.Error("Ban operation failed")
|
||||
}
|
||||
}
|
||||
|
||||
// LogCommandExecution logs command execution with context
|
||||
func (cl *ContextualLogger) LogCommandExecution(
|
||||
ctx context.Context,
|
||||
command string,
|
||||
args []string,
|
||||
duration time.Duration,
|
||||
err error,
|
||||
) {
|
||||
ctx = WithCommand(ctx, command)
|
||||
|
||||
entry := cl.WithContext(ctx).WithFields(logrus.Fields{
|
||||
"args": args,
|
||||
"duration_ms": duration.Milliseconds(),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
entry.WithError(err).Error("Command execution failed")
|
||||
} else {
|
||||
entry.Info("Command executed successfully")
|
||||
}
|
||||
}
|
||||
|
||||
// Global contextual logger instance
|
||||
var contextualLogger = NewContextualLogger()
|
||||
|
||||
// GetContextualLogger returns the global contextual logger
|
||||
func GetContextualLogger() *ContextualLogger {
|
||||
return contextualLogger
|
||||
}
|
||||
|
||||
// SetContextualLogger sets a new global contextual logger
|
||||
func SetContextualLogger(logger *ContextualLogger) {
|
||||
contextualLogger = logger
|
||||
}
|
||||
46
cmd/logs.go
Normal file
46
cmd/logs.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// LogsCmd returns the logs command with injected client and config
|
||||
func LogsCmd(client fail2ban.Client, config *Config) *cobra.Command {
|
||||
cmd := NewCommand(
|
||||
"logs [jail] [ip]",
|
||||
"Show Fail2Ban logs (optionally filtered by jail and/or IP)",
|
||||
nil,
|
||||
func(cmd *cobra.Command, args []string) error {
|
||||
// Create timeout context for log reading (use file timeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), config.FileTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Parse optional arguments
|
||||
parsedArgs := ParseOptionalArgs(args, 2)
|
||||
jail := parsedArgs[0]
|
||||
ip := parsedArgs[1]
|
||||
|
||||
limit, _ := cmd.Flags().GetInt("limit")
|
||||
if limit < 0 {
|
||||
limit = 0
|
||||
}
|
||||
lines, err := client.GetLogLinesWithContext(ctx, jail, ip)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
if limit > 0 && len(lines) > limit {
|
||||
lines = lines[len(lines)-limit:]
|
||||
}
|
||||
|
||||
PrintOutputTo(GetCmdOutput(cmd), lines, config.Format)
|
||||
return nil
|
||||
})
|
||||
|
||||
AddLogFlags(cmd)
|
||||
return cmd
|
||||
}
|
||||
125
cmd/logswatch.go
Normal file
125
cmd/logswatch.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultLogWatchLimit is the default limit for log lines in watch mode
|
||||
DefaultLogWatchLimit = 10
|
||||
)
|
||||
|
||||
// LogsWatchCmd returns the logs-watch command with injected client and config
|
||||
func LogsWatchCmd(ctx context.Context, client fail2ban.Client, config *Config) *cobra.Command {
|
||||
var limit int
|
||||
var interval time.Duration
|
||||
|
||||
cmd := NewCommand(
|
||||
"logs-watch [jail] [ip]",
|
||||
"Continuously watch Fail2Ban logs (filtered by jail and/or IP)",
|
||||
nil,
|
||||
func(_ *cobra.Command, args []string) error {
|
||||
// Parse optional arguments
|
||||
parsedArgs := ParseOptionalArgs(args, 2)
|
||||
jail := parsedArgs[0]
|
||||
ip := parsedArgs[1]
|
||||
|
||||
// Use memory-efficient approach with configurable limits
|
||||
maxLines := limit
|
||||
if maxLines <= 0 {
|
||||
maxLines = 1000 // Default safe limit
|
||||
}
|
||||
|
||||
// Get initial log lines with memory limits (with file timeout)
|
||||
prev, err := getLogLinesWithLimitAndContext(ctx, client, jail, ip, maxLines, config.FileTimeout)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
prevHash := computeHash(prev)
|
||||
PrintOutput(strings.Join(prev, "\n"), config.Format)
|
||||
|
||||
if interval <= 0 {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
curr, err := getLogLinesWithLimitAndContext(ctx, client, jail, ip, maxLines, config.FileTimeout)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
currHash := computeHash(curr)
|
||||
if prevHash != currHash {
|
||||
PrintOutput(strings.Join(curr, "\n"), config.Format)
|
||||
prevHash = currHash
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
cmd.Flags().IntVarP(&limit, "limit", "n", DefaultLogWatchLimit, "Number of log lines to show/tail")
|
||||
cmd.Flags().
|
||||
DurationVarP(&interval, "interval", "i", DefaultPollingInterval, "Polling interval for checking new logs")
|
||||
return cmd
|
||||
}
|
||||
|
||||
// getLogLinesWithLimitAndContext tries to use the new memory-efficient method with timeout context,
|
||||
// otherwise falls back to the standard method with post-processing limits
|
||||
func getLogLinesWithLimitAndContext(
|
||||
ctx context.Context,
|
||||
client fail2ban.Client,
|
||||
jail, ip string,
|
||||
maxLines int,
|
||||
timeout time.Duration,
|
||||
) ([]string, error) {
|
||||
// Create timeout context for this specific operation
|
||||
logCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// Try to use the new method if it's available (RealClient has GetLogLinesWithLimit)
|
||||
if realClient, ok := client.(*fail2ban.RealClient); ok {
|
||||
return realClient.GetLogLinesWithLimit(jail, ip, maxLines)
|
||||
}
|
||||
|
||||
// Fallback to standard method with timeout context and post-processing limit
|
||||
lines, err := client.GetLogLinesWithContext(logCtx, jail, ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply limit after the fact for other client implementations
|
||||
if maxLines > 0 && len(lines) > maxLines {
|
||||
lines = lines[len(lines)-maxLines:]
|
||||
}
|
||||
|
||||
return lines, nil
|
||||
}
|
||||
|
||||
// computeHash computes a SHA256 hash of the log lines for efficient comparison
|
||||
func computeHash(lines []string) string {
|
||||
if len(lines) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
h := sha256.New()
|
||||
for _, line := range lines {
|
||||
h.Write([]byte(line))
|
||||
h.Write([]byte("\n"))
|
||||
}
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
341
cmd/metrics.go
Normal file
341
cmd/metrics.go
Normal file
@@ -0,0 +1,341 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Metrics collector for performance monitoring and observability
|
||||
type Metrics struct {
|
||||
// Command execution metrics
|
||||
CommandExecutions int64
|
||||
CommandFailures int64
|
||||
CommandTotalDuration int64 // in milliseconds
|
||||
|
||||
// Ban/Unban operation metrics
|
||||
BanOperations int64
|
||||
UnbanOperations int64
|
||||
BanFailures int64
|
||||
UnbanFailures int64
|
||||
|
||||
// Client operation metrics
|
||||
ClientOperations int64
|
||||
ClientFailures int64
|
||||
ClientTotalDuration int64 // in milliseconds
|
||||
|
||||
// Validation metrics
|
||||
ValidationCacheHits int64
|
||||
ValidationCacheMiss int64
|
||||
ValidationFailures int64
|
||||
|
||||
// System resource metrics
|
||||
MaxMemoryUsage int64 // in bytes
|
||||
GoroutineCount int64
|
||||
|
||||
// Timing histograms (buckets for latency distribution)
|
||||
commandLatencyBuckets map[string]*LatencyBucket
|
||||
clientLatencyBuckets map[string]*LatencyBucket
|
||||
mu sync.RWMutex
|
||||
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// LatencyBucket represents latency distribution buckets
|
||||
type LatencyBucket struct {
|
||||
Under1ms int64
|
||||
Under10ms int64
|
||||
Under100ms int64
|
||||
Under1s int64
|
||||
Under10s int64
|
||||
Over10s int64
|
||||
Total int64
|
||||
TotalTime int64 // in milliseconds
|
||||
}
|
||||
|
||||
// NewMetrics creates a new metrics collector
|
||||
func NewMetrics() *Metrics {
|
||||
return &Metrics{
|
||||
commandLatencyBuckets: make(map[string]*LatencyBucket),
|
||||
clientLatencyBuckets: make(map[string]*LatencyBucket),
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordCommandExecution records metrics for command execution
|
||||
func (m *Metrics) RecordCommandExecution(command string, duration time.Duration, success bool) {
|
||||
atomic.AddInt64(&m.CommandExecutions, 1)
|
||||
atomic.AddInt64(&m.CommandTotalDuration, duration.Milliseconds())
|
||||
|
||||
if !success {
|
||||
atomic.AddInt64(&m.CommandFailures, 1)
|
||||
}
|
||||
|
||||
// Record latency bucket
|
||||
m.recordLatencyBucket(m.commandLatencyBuckets, command, duration)
|
||||
}
|
||||
|
||||
// RecordBanOperation records metrics for ban operations
|
||||
func (m *Metrics) RecordBanOperation(operation string, _ time.Duration, success bool) {
|
||||
switch operation {
|
||||
case "ban":
|
||||
atomic.AddInt64(&m.BanOperations, 1)
|
||||
if !success {
|
||||
atomic.AddInt64(&m.BanFailures, 1)
|
||||
}
|
||||
case "unban":
|
||||
atomic.AddInt64(&m.UnbanOperations, 1)
|
||||
if !success {
|
||||
atomic.AddInt64(&m.UnbanFailures, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecordClientOperation records metrics for client operations
|
||||
func (m *Metrics) RecordClientOperation(operation string, duration time.Duration, success bool) {
|
||||
atomic.AddInt64(&m.ClientOperations, 1)
|
||||
atomic.AddInt64(&m.ClientTotalDuration, duration.Milliseconds())
|
||||
|
||||
if !success {
|
||||
atomic.AddInt64(&m.ClientFailures, 1)
|
||||
}
|
||||
|
||||
// Record latency bucket
|
||||
m.recordLatencyBucket(m.clientLatencyBuckets, operation, duration)
|
||||
}
|
||||
|
||||
// RecordValidationCacheHit records validation cache hits
|
||||
func (m *Metrics) RecordValidationCacheHit() {
|
||||
atomic.AddInt64(&m.ValidationCacheHits, 1)
|
||||
}
|
||||
|
||||
// RecordValidationCacheMiss records validation cache misses
|
||||
func (m *Metrics) RecordValidationCacheMiss() {
|
||||
atomic.AddInt64(&m.ValidationCacheMiss, 1)
|
||||
}
|
||||
|
||||
// RecordValidationFailure records validation failures
|
||||
func (m *Metrics) RecordValidationFailure() {
|
||||
atomic.AddInt64(&m.ValidationFailures, 1)
|
||||
}
|
||||
|
||||
// UpdateMemoryUsage updates the maximum memory usage
|
||||
func (m *Metrics) UpdateMemoryUsage(bytes int64) {
|
||||
for {
|
||||
current := atomic.LoadInt64(&m.MaxMemoryUsage)
|
||||
if bytes <= current {
|
||||
break
|
||||
}
|
||||
if atomic.CompareAndSwapInt64(&m.MaxMemoryUsage, current, bytes) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateGoroutineCount updates the goroutine count
|
||||
func (m *Metrics) UpdateGoroutineCount(count int64) {
|
||||
atomic.StoreInt64(&m.GoroutineCount, count)
|
||||
}
|
||||
|
||||
// recordLatencyBucket records latency in appropriate bucket
|
||||
func (m *Metrics) recordLatencyBucket(buckets map[string]*LatencyBucket, operation string, duration time.Duration) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
bucket, exists := buckets[operation]
|
||||
if !exists {
|
||||
bucket = &LatencyBucket{}
|
||||
buckets[operation] = bucket
|
||||
}
|
||||
|
||||
ms := duration.Milliseconds()
|
||||
atomic.AddInt64(&bucket.Total, 1)
|
||||
atomic.AddInt64(&bucket.TotalTime, ms)
|
||||
|
||||
switch {
|
||||
case duration < time.Millisecond:
|
||||
atomic.AddInt64(&bucket.Under1ms, 1)
|
||||
case duration < 10*time.Millisecond:
|
||||
atomic.AddInt64(&bucket.Under10ms, 1)
|
||||
case duration < 100*time.Millisecond:
|
||||
atomic.AddInt64(&bucket.Under100ms, 1)
|
||||
case duration < time.Second:
|
||||
atomic.AddInt64(&bucket.Under1s, 1)
|
||||
case duration < 10*time.Second:
|
||||
atomic.AddInt64(&bucket.Under10s, 1)
|
||||
default:
|
||||
atomic.AddInt64(&bucket.Over10s, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// GetSnapshot returns a snapshot of current metrics
|
||||
func (m *Metrics) GetSnapshot() MetricsSnapshot {
|
||||
m.mu.RLock()
|
||||
|
||||
// Copy command latency buckets
|
||||
commandBuckets := make(map[string]LatencyBucketSnapshot)
|
||||
for op, bucket := range m.commandLatencyBuckets {
|
||||
commandBuckets[op] = LatencyBucketSnapshot{
|
||||
Under1ms: atomic.LoadInt64(&bucket.Under1ms),
|
||||
Under10ms: atomic.LoadInt64(&bucket.Under10ms),
|
||||
Under100ms: atomic.LoadInt64(&bucket.Under100ms),
|
||||
Under1s: atomic.LoadInt64(&bucket.Under1s),
|
||||
Under10s: atomic.LoadInt64(&bucket.Under10s),
|
||||
Over10s: atomic.LoadInt64(&bucket.Over10s),
|
||||
Total: atomic.LoadInt64(&bucket.Total),
|
||||
TotalTime: atomic.LoadInt64(&bucket.TotalTime),
|
||||
}
|
||||
}
|
||||
|
||||
// Copy client latency buckets
|
||||
clientBuckets := make(map[string]LatencyBucketSnapshot)
|
||||
for op, bucket := range m.clientLatencyBuckets {
|
||||
clientBuckets[op] = LatencyBucketSnapshot{
|
||||
Under1ms: atomic.LoadInt64(&bucket.Under1ms),
|
||||
Under10ms: atomic.LoadInt64(&bucket.Under10ms),
|
||||
Under100ms: atomic.LoadInt64(&bucket.Under100ms),
|
||||
Under1s: atomic.LoadInt64(&bucket.Under1s),
|
||||
Under10s: atomic.LoadInt64(&bucket.Under10s),
|
||||
Over10s: atomic.LoadInt64(&bucket.Over10s),
|
||||
Total: atomic.LoadInt64(&bucket.Total),
|
||||
TotalTime: atomic.LoadInt64(&bucket.TotalTime),
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.RUnlock()
|
||||
|
||||
return MetricsSnapshot{
|
||||
// Command metrics
|
||||
CommandExecutions: atomic.LoadInt64(&m.CommandExecutions),
|
||||
CommandFailures: atomic.LoadInt64(&m.CommandFailures),
|
||||
CommandTotalDuration: atomic.LoadInt64(&m.CommandTotalDuration),
|
||||
|
||||
// Ban/Unban metrics
|
||||
BanOperations: atomic.LoadInt64(&m.BanOperations),
|
||||
UnbanOperations: atomic.LoadInt64(&m.UnbanOperations),
|
||||
BanFailures: atomic.LoadInt64(&m.BanFailures),
|
||||
UnbanFailures: atomic.LoadInt64(&m.UnbanFailures),
|
||||
|
||||
// Client metrics
|
||||
ClientOperations: atomic.LoadInt64(&m.ClientOperations),
|
||||
ClientFailures: atomic.LoadInt64(&m.ClientFailures),
|
||||
ClientTotalDuration: atomic.LoadInt64(&m.ClientTotalDuration),
|
||||
|
||||
// Validation metrics
|
||||
ValidationCacheHits: atomic.LoadInt64(&m.ValidationCacheHits),
|
||||
ValidationCacheMiss: atomic.LoadInt64(&m.ValidationCacheMiss),
|
||||
ValidationFailures: atomic.LoadInt64(&m.ValidationFailures),
|
||||
|
||||
// System metrics
|
||||
MaxMemoryUsage: atomic.LoadInt64(&m.MaxMemoryUsage),
|
||||
GoroutineCount: atomic.LoadInt64(&m.GoroutineCount),
|
||||
|
||||
// Latency buckets
|
||||
CommandLatencyBuckets: commandBuckets,
|
||||
ClientLatencyBuckets: clientBuckets,
|
||||
|
||||
// Uptime
|
||||
UptimeSeconds: int64(time.Since(m.startTime).Seconds()),
|
||||
}
|
||||
}
|
||||
|
||||
// MetricsSnapshot represents a point-in-time snapshot of metrics
|
||||
type MetricsSnapshot struct {
|
||||
// Command execution metrics
|
||||
CommandExecutions int64 `json:"command_executions"`
|
||||
CommandFailures int64 `json:"command_failures"`
|
||||
CommandTotalDuration int64 `json:"command_total_duration_ms"`
|
||||
|
||||
// Ban/Unban operation metrics
|
||||
BanOperations int64 `json:"ban_operations"`
|
||||
UnbanOperations int64 `json:"unban_operations"`
|
||||
BanFailures int64 `json:"ban_failures"`
|
||||
UnbanFailures int64 `json:"unban_failures"`
|
||||
|
||||
// Client operation metrics
|
||||
ClientOperations int64 `json:"client_operations"`
|
||||
ClientFailures int64 `json:"client_failures"`
|
||||
ClientTotalDuration int64 `json:"client_total_duration_ms"`
|
||||
|
||||
// Validation metrics
|
||||
ValidationCacheHits int64 `json:"validation_cache_hits"`
|
||||
ValidationCacheMiss int64 `json:"validation_cache_miss"`
|
||||
ValidationFailures int64 `json:"validation_failures"`
|
||||
|
||||
// System resource metrics
|
||||
MaxMemoryUsage int64 `json:"max_memory_usage_bytes"`
|
||||
GoroutineCount int64 `json:"goroutine_count"`
|
||||
UptimeSeconds int64 `json:"uptime_seconds"`
|
||||
|
||||
// Latency distribution
|
||||
CommandLatencyBuckets map[string]LatencyBucketSnapshot `json:"command_latency_buckets"`
|
||||
ClientLatencyBuckets map[string]LatencyBucketSnapshot `json:"client_latency_buckets"`
|
||||
}
|
||||
|
||||
// LatencyBucketSnapshot represents a snapshot of latency bucket
|
||||
type LatencyBucketSnapshot struct {
|
||||
Under1ms int64 `json:"under_1ms"`
|
||||
Under10ms int64 `json:"under_10ms"`
|
||||
Under100ms int64 `json:"under_100ms"`
|
||||
Under1s int64 `json:"under_1s"`
|
||||
Under10s int64 `json:"under_10s"`
|
||||
Over10s int64 `json:"over_10s"`
|
||||
Total int64 `json:"total"`
|
||||
TotalTime int64 `json:"total_time_ms"`
|
||||
}
|
||||
|
||||
// GetAverageLatency calculates average latency for the bucket
|
||||
func (l LatencyBucketSnapshot) GetAverageLatency() float64 {
|
||||
if l.Total == 0 {
|
||||
return 0
|
||||
}
|
||||
return float64(l.TotalTime) / float64(l.Total)
|
||||
}
|
||||
|
||||
// TimedOperation provides instrumentation for timed operations
|
||||
type TimedOperation struct {
|
||||
metrics *Metrics
|
||||
operation string
|
||||
category string
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// NewTimedOperation creates a new timed operation
|
||||
func NewTimedOperation(_ context.Context, metrics *Metrics, category, operation string) *TimedOperation {
|
||||
return &TimedOperation{
|
||||
metrics: metrics,
|
||||
operation: operation,
|
||||
category: category,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Finish completes the timed operation and records metrics
|
||||
func (t *TimedOperation) Finish(success bool) {
|
||||
duration := time.Since(t.startTime)
|
||||
|
||||
switch t.category {
|
||||
case "command":
|
||||
t.metrics.RecordCommandExecution(t.operation, duration, success)
|
||||
case "client":
|
||||
t.metrics.RecordClientOperation(t.operation, duration, success)
|
||||
case "ban":
|
||||
t.metrics.RecordBanOperation(t.operation, duration, success)
|
||||
}
|
||||
|
||||
// Note: Additional context logging could be added here if needed
|
||||
}
|
||||
|
||||
// Global metrics instance
|
||||
var globalMetrics = NewMetrics()
|
||||
|
||||
// GetGlobalMetrics returns the global metrics instance
|
||||
func GetGlobalMetrics() *Metrics {
|
||||
return globalMetrics
|
||||
}
|
||||
|
||||
// SetGlobalMetrics sets a new global metrics instance
|
||||
func SetGlobalMetrics(metrics *Metrics) {
|
||||
globalMetrics = metrics
|
||||
}
|
||||
130
cmd/metrics_cmd.go
Normal file
130
cmd/metrics_cmd.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// MetricsCmd returns the metrics command with injected client and config
|
||||
func MetricsCmd(_ fail2ban.Client, config *Config) *cobra.Command {
|
||||
return NewCommand(
|
||||
"metrics",
|
||||
"Show performance metrics",
|
||||
[]string{"stats"},
|
||||
func(cmd *cobra.Command, _ []string) error {
|
||||
// Get the global metrics instance
|
||||
metrics := GetGlobalMetrics()
|
||||
snapshot := metrics.GetSnapshot()
|
||||
|
||||
// Output metrics based on format
|
||||
if config != nil && config.Format == JSONFormat {
|
||||
encoder := json.NewEncoder(GetCmdOutput(cmd))
|
||||
encoder.SetIndent("", " ")
|
||||
if err := encoder.Encode(snapshot); err != nil {
|
||||
return fmt.Errorf("failed to encode metrics: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Plain text output - use a helper to simplify error handling
|
||||
if err := printMetricsPlain(GetCmdOutput(cmd), snapshot); err != nil {
|
||||
return fmt.Errorf("failed to print metrics: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// printMetricsPlain prints metrics in plain text format
|
||||
func printMetricsPlain(output io.Writer, snapshot MetricsSnapshot) error {
|
||||
// Use a string builder to build the output
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("F2B Performance Metrics\n")
|
||||
sb.WriteString("======================\n\n")
|
||||
|
||||
// System metrics
|
||||
sb.WriteString("System:\n")
|
||||
sb.WriteString(fmt.Sprintf(" Uptime: %ds\n", snapshot.UptimeSeconds))
|
||||
sb.WriteString(fmt.Sprintf(" Max Memory: %.2f MB\n", float64(snapshot.MaxMemoryUsage)/(1024*1024)))
|
||||
sb.WriteString(fmt.Sprintf(" Goroutines: %d\n\n", snapshot.GoroutineCount))
|
||||
|
||||
// Command metrics
|
||||
sb.WriteString("Commands:\n")
|
||||
sb.WriteString(fmt.Sprintf(" Total Executions: %d\n", snapshot.CommandExecutions))
|
||||
sb.WriteString(fmt.Sprintf(" Total Failures: %d\n", snapshot.CommandFailures))
|
||||
if snapshot.CommandExecutions > 0 {
|
||||
avgLatency := float64(snapshot.CommandTotalDuration) / float64(snapshot.CommandExecutions)
|
||||
sb.WriteString(fmt.Sprintf(" Average Latency: %.2f ms\n", avgLatency))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Ban/Unban metrics
|
||||
sb.WriteString("Ban Operations:\n")
|
||||
sb.WriteString(fmt.Sprintf(" Ban Operations: %d (failures: %d)\n", snapshot.BanOperations, snapshot.BanFailures))
|
||||
sb.WriteString(
|
||||
fmt.Sprintf(" Unban Operations: %d (failures: %d)\n", snapshot.UnbanOperations, snapshot.UnbanFailures),
|
||||
)
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Client metrics
|
||||
sb.WriteString("Client Operations:\n")
|
||||
sb.WriteString(fmt.Sprintf(" Total Operations: %d\n", snapshot.ClientOperations))
|
||||
sb.WriteString(fmt.Sprintf(" Total Failures: %d\n", snapshot.ClientFailures))
|
||||
if snapshot.ClientOperations > 0 {
|
||||
avgLatency := float64(snapshot.ClientTotalDuration) / float64(snapshot.ClientOperations)
|
||||
sb.WriteString(fmt.Sprintf(" Average Latency: %.2f ms\n", avgLatency))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Validation metrics
|
||||
sb.WriteString("Validation:\n")
|
||||
sb.WriteString(fmt.Sprintf(" Cache Hits: %d\n", snapshot.ValidationCacheHits))
|
||||
sb.WriteString(fmt.Sprintf(" Cache Misses: %d\n", snapshot.ValidationCacheMiss))
|
||||
sb.WriteString(fmt.Sprintf(" Failures: %d\n", snapshot.ValidationFailures))
|
||||
if total := snapshot.ValidationCacheHits + snapshot.ValidationCacheMiss; total > 0 {
|
||||
hitRate := float64(snapshot.ValidationCacheHits) / float64(total) * 100
|
||||
sb.WriteString(fmt.Sprintf(" Cache Hit Rate: %.2f%%\n", hitRate))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Command latency distribution
|
||||
if len(snapshot.CommandLatencyBuckets) > 0 {
|
||||
sb.WriteString("Command Latency Distribution:\n")
|
||||
for cmd, bucket := range snapshot.CommandLatencyBuckets {
|
||||
sb.WriteString(fmt.Sprintf(" %s:\n", cmd))
|
||||
sb.WriteString(fmt.Sprintf(" < 1ms: %d\n", bucket.Under1ms))
|
||||
sb.WriteString(fmt.Sprintf(" < 10ms: %d\n", bucket.Under10ms))
|
||||
sb.WriteString(fmt.Sprintf(" < 100ms: %d\n", bucket.Under100ms))
|
||||
sb.WriteString(fmt.Sprintf(" < 1s: %d\n", bucket.Under1s))
|
||||
sb.WriteString(fmt.Sprintf(" < 10s: %d\n", bucket.Under10s))
|
||||
sb.WriteString(fmt.Sprintf(" > 10s: %d\n", bucket.Over10s))
|
||||
sb.WriteString(fmt.Sprintf(" Average: %.2f ms\n", bucket.GetAverageLatency()))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Client latency distribution
|
||||
if len(snapshot.ClientLatencyBuckets) > 0 {
|
||||
sb.WriteString("Client Operation Latency Distribution:\n")
|
||||
for op, bucket := range snapshot.ClientLatencyBuckets {
|
||||
sb.WriteString(fmt.Sprintf(" %s:\n", op))
|
||||
sb.WriteString(fmt.Sprintf(" < 1ms: %d\n", bucket.Under1ms))
|
||||
sb.WriteString(fmt.Sprintf(" < 10ms: %d\n", bucket.Under10ms))
|
||||
sb.WriteString(fmt.Sprintf(" < 100ms: %d\n", bucket.Under100ms))
|
||||
sb.WriteString(fmt.Sprintf(" < 1s: %d\n", bucket.Under1s))
|
||||
sb.WriteString(fmt.Sprintf(" < 10s: %d\n", bucket.Under10s))
|
||||
sb.WriteString(fmt.Sprintf(" > 10s: %d\n", bucket.Over10s))
|
||||
sb.WriteString(fmt.Sprintf(" Average: %.2f ms\n", bucket.GetAverageLatency()))
|
||||
}
|
||||
}
|
||||
|
||||
// Write the entire string at once
|
||||
_, err := output.Write([]byte(sb.String()))
|
||||
return err
|
||||
}
|
||||
150
cmd/mock_builder_demo_test.go
Normal file
150
cmd/mock_builder_demo_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestMockClientBuilder demonstrates the new fluent mock builder pattern
|
||||
func TestMockClientBuilder(t *testing.T) {
|
||||
t.Run("basic_builder_usage", func(t *testing.T) {
|
||||
// Using the new MockClientBuilder for complex mock setup
|
||||
mockBuilder := NewMockClientBuilder().
|
||||
WithJails("sshd", "apache").
|
||||
WithBannedIP("192.168.1.100", "sshd").
|
||||
WithBanRecord("sshd", "192.168.1.100", "01:30:00").
|
||||
WithLogLine("2024-01-01 12:00:00 [sshd] Ban 192.168.1.100").
|
||||
WithStatusResponse("sshd", "Mock status for jail sshd")
|
||||
|
||||
NewCommandTest(t, "banned").
|
||||
WithArgs("sshd").
|
||||
WithMockBuilder(mockBuilder).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("sshd | 192.168.1.100").
|
||||
Run()
|
||||
})
|
||||
|
||||
t.Run("builder_with_errors", func(t *testing.T) {
|
||||
// Complex error scenario setup
|
||||
mockBuilder := NewMockClientBuilder().
|
||||
WithJails("sshd").
|
||||
WithBanError("sshd", "192.168.1.100", errors.New("ban operation failed"))
|
||||
|
||||
NewCommandTest(t, "ban").
|
||||
WithArgs("192.168.1.100", "sshd").
|
||||
WithMockBuilder(mockBuilder).
|
||||
ExpectError().
|
||||
Run().
|
||||
AssertContains("ban operation failed")
|
||||
})
|
||||
|
||||
t.Run("builder_with_multiple_ban_records", func(t *testing.T) {
|
||||
// Complex scenario with multiple ban records
|
||||
mockBuilder := NewMockClientBuilder().
|
||||
WithJails("sshd", "apache").
|
||||
WithBanRecord("sshd", "192.168.1.100", "01:30:00").
|
||||
WithBanRecord("apache", "192.168.1.101", "02:15:30").
|
||||
WithBanRecord("sshd", "192.168.1.102", "00:45:00")
|
||||
|
||||
NewCommandTest(t, "banned").
|
||||
ExpectSuccess().
|
||||
WithMockBuilder(mockBuilder).
|
||||
Run().
|
||||
AssertContains("192.168.1.100").
|
||||
AssertContains("192.168.1.101").
|
||||
AssertContains("192.168.1.102")
|
||||
})
|
||||
|
||||
t.Run("complex_multi_command_scenario", func(t *testing.T) {
|
||||
// Demonstrate comprehensive mock setup for multiple commands
|
||||
mockBuilder := NewMockClientBuilder().
|
||||
WithJails("sshd", "apache").
|
||||
WithBanRecord("sshd", "192.168.1.100", "01:30:00").
|
||||
WithBanRecord("apache", "192.168.1.101", "02:15:30").
|
||||
WithLogLine("2024-01-01 12:00:00 [sshd] Ban 192.168.1.100").
|
||||
WithLogLine("2024-01-01 12:01:00 [apache] Ban 192.168.1.101").
|
||||
WithStatusResponse("sshd", "Mock status for jail sshd")
|
||||
|
||||
// Test status command
|
||||
NewCommandTest(t, "status").
|
||||
WithArgs("sshd").
|
||||
WithMockBuilder(mockBuilder).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("Mock status for jail sshd").
|
||||
Run()
|
||||
|
||||
// Test banned command
|
||||
NewCommandTest(t, "banned").
|
||||
WithMockBuilder(mockBuilder).
|
||||
ExpectSuccess().
|
||||
Run().
|
||||
AssertContains("192.168.1.100").
|
||||
AssertContains("192.168.1.101")
|
||||
|
||||
// Test logs command
|
||||
NewCommandTest(t, "logs").
|
||||
WithMockBuilder(mockBuilder).
|
||||
ExpectSuccess().
|
||||
Run().
|
||||
AssertContains("Ban 192.168.1.100").
|
||||
AssertContains("Ban 192.168.1.101")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMockBuilderAdvancedFeatures tests advanced builder capabilities
|
||||
func TestMockBuilderAdvancedFeatures(t *testing.T) {
|
||||
t.Run("chained_operations", func(t *testing.T) {
|
||||
// Demonstrate that builder can be reused and modified
|
||||
baseBuilder := NewMockClientBuilder().
|
||||
WithJails("sshd", "apache")
|
||||
|
||||
// Create specialized builders from base
|
||||
sshBannedBuilder := baseBuilder.
|
||||
WithBannedIP("192.168.1.100", "sshd").
|
||||
WithBanRecord("sshd", "192.168.1.100", "01:30:00")
|
||||
|
||||
apacheBannedBuilder := baseBuilder.
|
||||
WithBannedIP("192.168.1.101", "apache").
|
||||
WithBanRecord("apache", "192.168.1.101", "02:15:30")
|
||||
|
||||
// Test SSH banned IP
|
||||
NewCommandTest(t, "banned").
|
||||
WithArgs("sshd").
|
||||
WithMockBuilder(sshBannedBuilder).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("sshd | 192.168.1.100").
|
||||
Run()
|
||||
|
||||
// Test Apache banned IP
|
||||
NewCommandTest(t, "banned").
|
||||
WithArgs("apache").
|
||||
WithMockBuilder(apacheBannedBuilder).
|
||||
ExpectSuccess().
|
||||
ExpectOutput("apache | 192.168.1.101").
|
||||
Run()
|
||||
})
|
||||
|
||||
t.Run("error_scenarios", func(t *testing.T) {
|
||||
// Test various error scenarios with builder
|
||||
errorBuilder := NewMockClientBuilder().
|
||||
WithJails("sshd").
|
||||
WithBanError("sshd", "192.168.1.100", errors.New("IP already banned")).
|
||||
WithUnbanError("sshd", "192.168.1.101", errors.New("IP not found"))
|
||||
|
||||
// Test ban error
|
||||
NewCommandTest(t, "ban").
|
||||
WithArgs("192.168.1.100", "sshd").
|
||||
WithMockBuilder(errorBuilder).
|
||||
ExpectError().
|
||||
Run().
|
||||
AssertContains("IP already banned")
|
||||
|
||||
// Test unban error
|
||||
NewCommandTest(t, "unban").
|
||||
WithArgs("192.168.1.101", "sshd").
|
||||
WithMockBuilder(errorBuilder).
|
||||
ExpectError().
|
||||
Run().
|
||||
AssertContains("IP not found")
|
||||
})
|
||||
}
|
||||
155
cmd/output.go
Normal file
155
cmd/output.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
const (
|
||||
// JSONFormat represents the JSON output format
|
||||
JSONFormat = "json"
|
||||
)
|
||||
|
||||
// Logger is the global logger for the CLI.
|
||||
var Logger = logrus.New()
|
||||
|
||||
func init() {
|
||||
// Set logrus to output to stderr and use a readable format by default.
|
||||
Logger.SetOutput(os.Stderr)
|
||||
Logger.SetFormatter(&logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
})
|
||||
|
||||
// Configure both cmd.Logger and global logrus for CI environments
|
||||
configureCIFriendlyLogging()
|
||||
}
|
||||
|
||||
// configureCIFriendlyLogging sets appropriate log levels for CI/test environments
|
||||
func configureCIFriendlyLogging() {
|
||||
// Detect CI environments by checking common CI environment variables
|
||||
ciEnvVars := []string{
|
||||
"CI", // Generic CI indicator
|
||||
"GITHUB_ACTIONS", // GitHub Actions
|
||||
"TRAVIS", // Travis CI
|
||||
"CIRCLECI", // Circle CI
|
||||
"JENKINS_URL", // Jenkins
|
||||
"BUILDKITE", // Buildkite
|
||||
"TF_BUILD", // Azure DevOps
|
||||
"GITLAB_CI", // GitLab CI
|
||||
}
|
||||
|
||||
isCI := false
|
||||
for _, envVar := range ciEnvVars {
|
||||
if os.Getenv(envVar) != "" {
|
||||
isCI = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Also check if we're in test mode
|
||||
isTest := strings.Contains(os.Args[0], ".test") ||
|
||||
os.Getenv("GO_TEST") == "true" ||
|
||||
flag.Lookup("test.v") != nil
|
||||
|
||||
// If in CI or test environment, reduce logging noise unless explicitly overridden
|
||||
if (isCI || isTest) && os.Getenv("F2B_LOG_LEVEL") == "" && os.Getenv("F2B_VERBOSE_TESTS") == "" {
|
||||
// Set both the cmd.Logger and global logrus to error level
|
||||
Logger.SetLevel(logrus.ErrorLevel)
|
||||
logrus.SetLevel(logrus.ErrorLevel)
|
||||
}
|
||||
}
|
||||
|
||||
// PrintOutput prints data to stdout in the specified format ("plain" or "json").
|
||||
func PrintOutput(data interface{}, format string) {
|
||||
switch format {
|
||||
case JSONFormat:
|
||||
enc := json.NewEncoder(os.Stdout)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(data); err != nil {
|
||||
Logger.WithError(err).Error("Failed to encode JSON output")
|
||||
// Fallback to plain text output
|
||||
if _, printErr := fmt.Fprintln(os.Stdout, data); printErr != nil {
|
||||
Logger.WithError(printErr).Error("Failed to write fallback output")
|
||||
}
|
||||
}
|
||||
default:
|
||||
fmt.Println(data)
|
||||
}
|
||||
}
|
||||
|
||||
// PrintOutputTo prints data to the specified writer in the given format.
|
||||
func PrintOutputTo(w io.Writer, data interface{}, format string) {
|
||||
switch format {
|
||||
case JSONFormat:
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(data); err != nil {
|
||||
Logger.WithError(err).Error("Failed to encode JSON output")
|
||||
// Fallback to plain text output
|
||||
if _, printErr := fmt.Fprintln(w, data); printErr != nil {
|
||||
Logger.WithError(printErr).Error("Failed to write fallback output")
|
||||
}
|
||||
}
|
||||
default:
|
||||
if _, err := fmt.Fprintln(w, data); err != nil {
|
||||
Logger.WithError(err).Error("Failed to write plain output")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PrintError logs and prints an error to stderr with enhanced context if available.
|
||||
func PrintError(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if error provides enhanced context
|
||||
var contextErr *fail2ban.ContextualError
|
||||
if errors.As(err, &contextErr) {
|
||||
Logger.WithFields(map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"category": string(contextErr.GetCategory()),
|
||||
}).Error("Command failed")
|
||||
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
if remediation := contextErr.GetRemediation(); remediation != "" {
|
||||
fmt.Fprintln(os.Stderr, "Hint:", remediation)
|
||||
}
|
||||
} else {
|
||||
Logger.WithError(err).Error("Command failed")
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// PrintErrorf logs and prints a formatted error to stderr.
|
||||
func PrintErrorf(format string, args ...interface{}) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
Logger.Error(msg)
|
||||
fmt.Fprintln(os.Stderr, "Error:", msg)
|
||||
}
|
||||
|
||||
// GetCmdOutput returns the command's output writer if available, otherwise os.Stdout
|
||||
func GetCmdOutput(cmd *cobra.Command) io.Writer {
|
||||
if cmd != nil && cmd.OutOrStdout() != nil {
|
||||
return cmd.OutOrStdout()
|
||||
}
|
||||
return os.Stdout
|
||||
}
|
||||
|
||||
// GetCmdError returns the command's error writer if available, otherwise os.Stderr
|
||||
func GetCmdError(cmd *cobra.Command) io.Writer {
|
||||
if cmd != nil && cmd.ErrOrStderr() != nil {
|
||||
return cmd.ErrOrStderr()
|
||||
}
|
||||
return os.Stderr
|
||||
}
|
||||
263
cmd/parallel_operations.go
Normal file
263
cmd/parallel_operations.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// ParallelOperationProcessor handles parallel ban/unban operations across multiple jails
|
||||
type ParallelOperationProcessor struct {
|
||||
workerCount int
|
||||
}
|
||||
|
||||
// NewParallelOperationProcessor creates a new parallel operation processor
|
||||
func NewParallelOperationProcessor(workerCount int) *ParallelOperationProcessor {
|
||||
if workerCount <= 0 {
|
||||
workerCount = runtime.NumCPU()
|
||||
}
|
||||
return &ParallelOperationProcessor{
|
||||
workerCount: workerCount,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessBanOperationParallel processes ban operations across multiple jails in parallel
|
||||
func (pop *ParallelOperationProcessor) ProcessBanOperationParallel(
|
||||
client fail2ban.Client,
|
||||
ip string,
|
||||
jails []string,
|
||||
) ([]OperationResult, error) {
|
||||
if len(jails) <= 1 {
|
||||
// For single jail, use sequential processing to avoid overhead
|
||||
return ProcessBanOperation(client, ip, jails)
|
||||
}
|
||||
|
||||
return pop.processOperations(
|
||||
context.Background(),
|
||||
client,
|
||||
ip,
|
||||
jails,
|
||||
func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error) {
|
||||
return client.BanIPWithContext(ctx, ip, jail)
|
||||
},
|
||||
"ban",
|
||||
)
|
||||
}
|
||||
|
||||
// ProcessBanOperationParallelWithContext processes ban operations across
|
||||
// multiple jails in parallel with timeout context
|
||||
func (pop *ParallelOperationProcessor) ProcessBanOperationParallelWithContext(
|
||||
ctx context.Context,
|
||||
client fail2ban.Client,
|
||||
ip string,
|
||||
jails []string,
|
||||
) ([]OperationResult, error) {
|
||||
if len(jails) <= 1 {
|
||||
// For single jail, use sequential processing to avoid overhead
|
||||
return ProcessBanOperationWithContext(ctx, client, ip, jails)
|
||||
}
|
||||
|
||||
return pop.processOperations(
|
||||
ctx,
|
||||
client,
|
||||
ip,
|
||||
jails,
|
||||
func(opCtx context.Context, client fail2ban.Client, ip, jail string) (int, error) {
|
||||
return client.BanIPWithContext(opCtx, ip, jail)
|
||||
},
|
||||
"ban",
|
||||
)
|
||||
}
|
||||
|
||||
// ProcessUnbanOperationParallel processes unban operations across multiple jails in parallel
|
||||
func (pop *ParallelOperationProcessor) ProcessUnbanOperationParallel(
|
||||
client fail2ban.Client,
|
||||
ip string,
|
||||
jails []string,
|
||||
) ([]OperationResult, error) {
|
||||
if len(jails) <= 1 {
|
||||
// For single jail, use sequential processing to avoid overhead
|
||||
return ProcessUnbanOperation(client, ip, jails)
|
||||
}
|
||||
|
||||
return pop.processOperations(
|
||||
context.Background(),
|
||||
client,
|
||||
ip,
|
||||
jails,
|
||||
func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error) {
|
||||
return client.UnbanIPWithContext(ctx, ip, jail)
|
||||
},
|
||||
"unban",
|
||||
)
|
||||
}
|
||||
|
||||
// ProcessUnbanOperationParallelWithContext processes unban operations across
|
||||
// multiple jails in parallel with timeout context
|
||||
func (pop *ParallelOperationProcessor) ProcessUnbanOperationParallelWithContext(
|
||||
ctx context.Context,
|
||||
client fail2ban.Client,
|
||||
ip string,
|
||||
jails []string,
|
||||
) ([]OperationResult, error) {
|
||||
if len(jails) <= 1 {
|
||||
// For single jail, use sequential processing to avoid overhead
|
||||
return ProcessUnbanOperationWithContext(ctx, client, ip, jails)
|
||||
}
|
||||
|
||||
return pop.processOperations(
|
||||
ctx,
|
||||
client,
|
||||
ip,
|
||||
jails,
|
||||
func(opCtx context.Context, client fail2ban.Client, ip, jail string) (int, error) {
|
||||
return client.UnbanIPWithContext(opCtx, ip, jail)
|
||||
},
|
||||
"unban",
|
||||
)
|
||||
}
|
||||
|
||||
// operationFunc represents a ban or unban operation with context
|
||||
type operationFunc func(ctx context.Context, client fail2ban.Client, ip, jail string) (int, error)
|
||||
|
||||
// processOperations handles the parallel processing of operations
|
||||
func (pop *ParallelOperationProcessor) processOperations(
|
||||
ctx context.Context,
|
||||
client fail2ban.Client,
|
||||
ip string,
|
||||
jails []string,
|
||||
operation operationFunc,
|
||||
operationType string,
|
||||
) ([]OperationResult, error) {
|
||||
results := make([]OperationResult, len(jails))
|
||||
resultCh := make(chan operationResult, len(jails))
|
||||
|
||||
// Create worker pool
|
||||
var wg sync.WaitGroup
|
||||
jailCh := make(chan jailWork, len(jails))
|
||||
|
||||
workerCount := pop.workerCount
|
||||
if len(jails) < workerCount {
|
||||
workerCount = len(jails)
|
||||
}
|
||||
|
||||
// Start workers
|
||||
for i := 0; i < workerCount; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
pop.worker(ctx, client, ip, operation, operationType, jailCh, resultCh)
|
||||
}()
|
||||
}
|
||||
|
||||
// Send work items
|
||||
go func() {
|
||||
defer close(jailCh)
|
||||
for i, jail := range jails {
|
||||
jailCh <- jailWork{jail: jail, index: i}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for workers to complete
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(resultCh)
|
||||
}()
|
||||
|
||||
// Collect results
|
||||
for result := range resultCh {
|
||||
if result.index >= 0 && result.index < len(results) {
|
||||
results[result.index] = result.result
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// jailWork represents work for a specific jail
|
||||
type jailWork struct {
|
||||
jail string
|
||||
index int
|
||||
}
|
||||
|
||||
// operationResult represents the result of an operation
|
||||
type operationResult struct {
|
||||
result OperationResult
|
||||
index int
|
||||
}
|
||||
|
||||
// worker processes jail operations
|
||||
func (pop *ParallelOperationProcessor) worker(
|
||||
ctx context.Context,
|
||||
client fail2ban.Client,
|
||||
ip string,
|
||||
operation operationFunc,
|
||||
operationType string,
|
||||
jailCh <-chan jailWork,
|
||||
resultCh chan<- operationResult,
|
||||
) {
|
||||
for work := range jailCh {
|
||||
code, err := operation(ctx, client, ip, work.jail)
|
||||
|
||||
var status string
|
||||
if err != nil {
|
||||
status = err.Error()
|
||||
} else {
|
||||
status = InterpretBanStatus(code, operationType)
|
||||
}
|
||||
|
||||
Logger.WithFields(map[string]interface{}{
|
||||
"ip": ip,
|
||||
"jail": work.jail,
|
||||
"status": status,
|
||||
}).Info("Operation result")
|
||||
|
||||
result := operationResult{
|
||||
result: OperationResult{
|
||||
IP: ip,
|
||||
Jail: work.jail,
|
||||
Status: status,
|
||||
},
|
||||
index: work.index,
|
||||
}
|
||||
|
||||
resultCh <- result
|
||||
}
|
||||
}
|
||||
|
||||
// Global processor instance
|
||||
var defaultParallelProcessor = NewParallelOperationProcessor(runtime.NumCPU())
|
||||
|
||||
// ProcessBanOperationParallel processes ban operations in parallel using the default processor
|
||||
func ProcessBanOperationParallel(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) {
|
||||
return defaultParallelProcessor.ProcessBanOperationParallel(client, ip, jails)
|
||||
}
|
||||
|
||||
// ProcessBanOperationParallelWithContext processes ban operations in parallel using the default processor with context
|
||||
func ProcessBanOperationParallelWithContext(
|
||||
ctx context.Context,
|
||||
client fail2ban.Client,
|
||||
ip string,
|
||||
jails []string,
|
||||
) ([]OperationResult, error) {
|
||||
return defaultParallelProcessor.ProcessBanOperationParallelWithContext(
|
||||
ctx, client, ip, jails)
|
||||
}
|
||||
|
||||
// ProcessUnbanOperationParallel processes unban operations in parallel using the default processor
|
||||
func ProcessUnbanOperationParallel(client fail2ban.Client, ip string, jails []string) ([]OperationResult, error) {
|
||||
return defaultParallelProcessor.ProcessUnbanOperationParallel(client, ip, jails)
|
||||
}
|
||||
|
||||
// ProcessUnbanOperationParallelWithContext processes unban operations in
|
||||
// parallel using the default processor with context
|
||||
func ProcessUnbanOperationParallelWithContext(
|
||||
ctx context.Context,
|
||||
client fail2ban.Client,
|
||||
ip string,
|
||||
jails []string,
|
||||
) ([]OperationResult, error) {
|
||||
return defaultParallelProcessor.ProcessUnbanOperationParallelWithContext(ctx, client, ip, jails)
|
||||
}
|
||||
239
cmd/root.go
Normal file
239
cmd/root.go
Normal file
@@ -0,0 +1,239 @@
|
||||
// Package cmd implements all CLI commands for the f2b tool, providing secure
|
||||
// Fail2Ban management operations including jail monitoring, IP banning/unbanning,
|
||||
// log analysis, and service management with comprehensive input validation.
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// Config holds global configuration for the CLI, including log and filter directories and output format.
|
||||
type Config struct {
|
||||
LogDir string // Path to Fail2Ban log directory
|
||||
FilterDir string // Path to Fail2Ban filter directory
|
||||
Format string // Output format: "plain" or "json"
|
||||
CommandTimeout time.Duration // Timeout for individual fail2ban commands
|
||||
FileTimeout time.Duration // Timeout for file operations
|
||||
ParallelTimeout time.Duration // Timeout for parallel operations
|
||||
}
|
||||
|
||||
var (
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "f2b",
|
||||
Short: "Fail2Ban CLI helper",
|
||||
Long: "Fail2Ban CLI tool implemented in Go using Cobra.",
|
||||
}
|
||||
cfg Config
|
||||
|
||||
// Resource cleanup tracking
|
||||
logFile *os.File
|
||||
logFileMutex sync.Mutex
|
||||
cleanupOnce sync.Once
|
||||
)
|
||||
|
||||
// Execute runs the CLI application with the given client and configuration.
|
||||
func Execute(client fail2ban.Client, config Config) error {
|
||||
cfg = config
|
||||
// Ensure cleanup happens even if the program exits unexpectedly
|
||||
defer cleanupResources()
|
||||
|
||||
// Set up metrics recorder for validation caching
|
||||
fail2ban.SetMetricsRecorder(GetGlobalMetrics())
|
||||
|
||||
ctx := context.Background()
|
||||
rootCmd.AddCommand(ListJailsCmd(client, &cfg))
|
||||
rootCmd.AddCommand(StatusCmd(client, &cfg))
|
||||
rootCmd.AddCommand(BannedCmd(client, &cfg))
|
||||
rootCmd.AddCommand(BanCmd(client, &cfg))
|
||||
rootCmd.AddCommand(UnbanCmd(client, &cfg))
|
||||
rootCmd.AddCommand(TestIPCmd(client, &cfg))
|
||||
rootCmd.AddCommand(LogsCmd(client, &cfg))
|
||||
rootCmd.AddCommand(LogsWatchCmd(ctx, client, &cfg))
|
||||
rootCmd.AddCommand(ServiceCmd(&cfg))
|
||||
rootCmd.AddCommand(VersionCmd(&cfg))
|
||||
rootCmd.AddCommand(TestFilterCmd(client, &cfg))
|
||||
rootCmd.AddCommand(MetricsCmd(client, &cfg))
|
||||
rootCmd.AddCommand(completionCmd())
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Set defaults from env
|
||||
cfg = NewConfigFromEnv()
|
||||
|
||||
rootCmd.PersistentFlags().StringVar(&cfg.LogDir, "log-dir", cfg.LogDir, "Fail2Ban log directory")
|
||||
rootCmd.PersistentFlags().StringVar(&cfg.FilterDir, "filter-dir", cfg.FilterDir, "Fail2Ban filter directory")
|
||||
rootCmd.PersistentFlags().StringVar(&cfg.Format, "format", cfg.Format, "Output format: plain or json")
|
||||
rootCmd.PersistentFlags().
|
||||
DurationVar(&cfg.CommandTimeout, "command-timeout", cfg.CommandTimeout, "Timeout for individual fail2ban commands")
|
||||
rootCmd.PersistentFlags().
|
||||
DurationVar(&cfg.FileTimeout, "file-timeout", cfg.FileTimeout, "Timeout for file operations")
|
||||
rootCmd.PersistentFlags().
|
||||
DurationVar(&cfg.ParallelTimeout, "parallel-timeout", cfg.ParallelTimeout, "Timeout for parallel operations")
|
||||
|
||||
// Log level configuration
|
||||
logLevel := os.Getenv("F2B_LOG_LEVEL")
|
||||
if logLevel == "" {
|
||||
logLevel = "info"
|
||||
}
|
||||
|
||||
// Log file support
|
||||
logFile := os.Getenv("F2B_LOG_FILE")
|
||||
rootCmd.PersistentFlags().String("log-file", logFile, "Path to log file for f2b logs (optional)")
|
||||
rootCmd.PersistentFlags().String("log-level", logLevel, "Log level (debug, info, warn, error)")
|
||||
|
||||
rootCmd.PersistentPreRun = func(cmd *cobra.Command, _ []string) {
|
||||
logFileFlag, _ := cmd.Flags().GetString("log-file")
|
||||
if logFileFlag != "" {
|
||||
// Validate log file path for security
|
||||
cleanPath, err := filepath.Abs(filepath.Clean(logFileFlag))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Invalid log file path %s: %v\n", logFileFlag, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Additional security check: ensure path doesn't contain dangerous patterns
|
||||
if strings.Contains(cleanPath, "..") || strings.Contains(cleanPath, "//") {
|
||||
fmt.Fprintf(os.Stderr, "Invalid log file path %s: contains dangerous patterns\n", logFileFlag)
|
||||
return
|
||||
}
|
||||
|
||||
// #nosec G304 - Path is validated and sanitized above
|
||||
f, err := os.OpenFile(cleanPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, fail2ban.DefaultFilePermissions)
|
||||
if err == nil {
|
||||
Logger.SetOutput(f)
|
||||
// Register cleanup for graceful shutdown
|
||||
registerLogFileCleanup(f, cleanPath)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Failed to open log file %s: %v\n", cleanPath, err)
|
||||
}
|
||||
}
|
||||
level, _ := cmd.Flags().GetString("log-level")
|
||||
Logger.SetLevel(parseLogLevel(level))
|
||||
}
|
||||
}
|
||||
|
||||
// registerLogFileCleanup registers a log file for cleanup and sets up signal handling
|
||||
func registerLogFileCleanup(f *os.File, _ string) {
|
||||
logFileMutex.Lock()
|
||||
logFile = f
|
||||
logFileMutex.Unlock()
|
||||
|
||||
// Setup signal handler for graceful cleanup (only once)
|
||||
cleanupOnce.Do(func() {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-c
|
||||
cleanupResources()
|
||||
os.Exit(0)
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// cleanupResources performs cleanup of allocated resources
|
||||
func cleanupResources() {
|
||||
logFileMutex.Lock()
|
||||
defer logFileMutex.Unlock()
|
||||
|
||||
if logFile != nil {
|
||||
if err := logFile.Close(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: Failed to close log file: %v\n", err)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Log file closed successfully\n")
|
||||
}
|
||||
logFile = nil
|
||||
}
|
||||
}
|
||||
|
||||
// parseLogLevel parses a string log level for logrus.
|
||||
func parseLogLevel(level string) logrus.Level {
|
||||
switch level {
|
||||
case "debug":
|
||||
return logrus.DebugLevel
|
||||
case "info":
|
||||
return logrus.InfoLevel
|
||||
case "warn", "warning":
|
||||
return logrus.WarnLevel
|
||||
case "error":
|
||||
return logrus.ErrorLevel
|
||||
case "fatal":
|
||||
return logrus.FatalLevel
|
||||
case "panic":
|
||||
return logrus.PanicLevel
|
||||
default:
|
||||
// Log warning about invalid log level before falling back to default
|
||||
Logger.WithField("invalid_level", level).Warn("Invalid log level specified, falling back to 'info'")
|
||||
return logrus.InfoLevel
|
||||
}
|
||||
}
|
||||
|
||||
// completionCmd provides shell completion scripts for bash, zsh, fish, and powershell.
|
||||
func completionCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "completion [bash|zsh|fish|powershell]",
|
||||
Short: "Generate shell completion scripts",
|
||||
Long: `To load completions:
|
||||
|
||||
Bash:
|
||||
|
||||
$ source <(f2b completion bash)
|
||||
|
||||
# To load completions for each session, execute once:
|
||||
# Linux:
|
||||
$ f2b completion bash > /etc/bash_completion.d/f2b
|
||||
# macOS:
|
||||
$ f2b completion bash > /usr/local/etc/bash_completion.d/f2b
|
||||
|
||||
Zsh:
|
||||
|
||||
$ echo "autoload -U compinit; compinit" >> ~/.zshrc
|
||||
$ f2b completion zsh > "${fpath[1]}/_f2b"
|
||||
|
||||
Fish:
|
||||
|
||||
$ f2b completion fish | source
|
||||
$ f2b completion fish > ~/.config/fish/completions/f2b.fish
|
||||
|
||||
PowerShell:
|
||||
|
||||
PS> f2b completion powershell | Out-String | Invoke-Expression
|
||||
PS> f2b completion powershell > f2b.ps1
|
||||
`,
|
||||
DisableFlagsInUseLine: true,
|
||||
Args: cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs),
|
||||
ValidArgs: []string{"bash", "zsh", "fish", "powershell"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
// Get the root command from the current command's parent hierarchy
|
||||
root := cmd.Root()
|
||||
// Note: Cobra's Args validation ensures we have exactly 1 valid argument
|
||||
switch args[0] {
|
||||
case "bash":
|
||||
_ = root.GenBashCompletion(cmd.OutOrStdout())
|
||||
case "zsh":
|
||||
_ = root.GenZshCompletion(cmd.OutOrStdout())
|
||||
case "fish":
|
||||
_ = root.GenFishCompletion(cmd.OutOrStdout(), true)
|
||||
case "powershell":
|
||||
_ = root.GenPowerShellCompletionWithDesc(cmd.OutOrStdout())
|
||||
default:
|
||||
if _, err := fmt.Fprintf(cmd.ErrOrStderr(), "Unsupported shell type: %s\n", args[0]); err != nil {
|
||||
Logger.WithError(err).Error("failed to write unsupported shell type")
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
36
cmd/service.go
Normal file
36
cmd/service.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// ServiceCmd returns the service command with injected config
|
||||
func ServiceCmd(config *Config) *cobra.Command {
|
||||
return NewCommand(
|
||||
"service [start|stop|restart|status|reload|enable|disable]",
|
||||
"Manage the Fail2Ban service",
|
||||
nil,
|
||||
func(_ *cobra.Command, args []string) error {
|
||||
// Validate service action argument
|
||||
if err := RequireArguments(args, 1, "action required: start|stop|restart|status|reload|enable|disable"); err != nil {
|
||||
PrintError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
action := args[0]
|
||||
if err := ValidateServiceAction(action); err != nil {
|
||||
PrintError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
out, err := fail2ban.RunnerCombinedOutputWithSudo("service", "fail2ban", action)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
PrintOutput(string(out), config.Format)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
81
cmd/status.go
Normal file
81
cmd/status.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// StatusCmd returns the status command with injected client and config
|
||||
func StatusCmd(client fail2ban.Client, config *Config) *cobra.Command {
|
||||
return NewCommand(
|
||||
"status [all|<jail>]",
|
||||
"Show status of all jails or a specific jail",
|
||||
[]string{"st", "stat", "show-status"},
|
||||
func(cmd *cobra.Command, args []string) error {
|
||||
// Create timeout context for the entire status operation
|
||||
ctx, cancel := context.WithTimeout(context.Background(), config.CommandTimeout)
|
||||
defer cancel()
|
||||
|
||||
if len(args) == 0 {
|
||||
jails, err := client.ListJailsWithContext(ctx)
|
||||
if err != nil {
|
||||
// Log error but continue with empty jail list for help display
|
||||
Logger.WithError(err).Warn("Failed to fetch jails for help display")
|
||||
jails = []string{}
|
||||
}
|
||||
PrintOutputTo(
|
||||
GetCmdOutput(cmd),
|
||||
"Usage: "+cmd.Root().Use+" status all (show all jails)",
|
||||
config.Format,
|
||||
)
|
||||
PrintOutputTo(
|
||||
GetCmdOutput(cmd),
|
||||
" "+cmd.Root().Use+" status <jail> (show specific jail)",
|
||||
config.Format,
|
||||
)
|
||||
PrintOutputTo(GetCmdOutput(cmd), "Available jails: "+strings.Join(jails, " "), config.Format)
|
||||
return nil
|
||||
}
|
||||
|
||||
target := strings.ToLower(args[0])
|
||||
if target == "all" {
|
||||
out, err := client.StatusAllWithContext(ctx)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
status := FormatStatusResult("", out)
|
||||
PrintOutputTo(GetCmdOutput(cmd), status, config.Format)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if jail exists (with timeout context)
|
||||
jails, err := client.ListJailsWithContext(ctx)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
jailExists := false
|
||||
for _, j := range jails {
|
||||
if j == target {
|
||||
jailExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !jailExists {
|
||||
return HandleClientError(fail2ban.NewJailNotFoundError(target))
|
||||
}
|
||||
|
||||
out, err := client.StatusJailWithContext(ctx, target)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
status := FormatStatusResult(target, out)
|
||||
PrintOutputTo(GetCmdOutput(cmd), status, config.Format)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
121
cmd/status_command_refactored_test.go
Normal file
121
cmd/status_command_refactored_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// TestStatusCommandRefactored demonstrates comprehensive status command testing with the modern framework
|
||||
func TestStatusCommandRefactored(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
jails []string
|
||||
statusAll string
|
||||
statusJail map[string]string
|
||||
wantOutput string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "status all",
|
||||
args: []string{"all"},
|
||||
jails: []string{"sshd"},
|
||||
statusAll: "Status for all jails\n",
|
||||
wantOutput: "Status for all jails\n",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "status specific jail",
|
||||
args: []string{"sshd"},
|
||||
jails: []string{"sshd"},
|
||||
statusJail: map[string]string{"sshd": "Status for sshd jail\n"},
|
||||
wantOutput: "Status for sshd jail\n",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "status nonexistent jail",
|
||||
args: []string{"nonexistent"},
|
||||
jails: []string{"sshd"},
|
||||
wantOutput: "Error: jail 'nonexistent' not found",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "status no args shows usage",
|
||||
args: []string{},
|
||||
jails: []string{"sshd"},
|
||||
wantOutput: "Available jails: sshd",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
// Framework approach with fluent interface
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
builder := NewCommandTest(t, "status").
|
||||
WithArgs(tt.args...).
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, tt.jails)
|
||||
if tt.statusAll != "" {
|
||||
mock.StatusAllData = tt.statusAll
|
||||
}
|
||||
if tt.statusJail != nil {
|
||||
mock.StatusJailData = tt.statusJail
|
||||
}
|
||||
})
|
||||
|
||||
if tt.wantError {
|
||||
builder = builder.ExpectError()
|
||||
} else {
|
||||
builder = builder.ExpectSuccess()
|
||||
}
|
||||
|
||||
if tt.wantOutput != "" {
|
||||
builder = builder.ExpectOutput(tt.wantOutput)
|
||||
}
|
||||
|
||||
builder.Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStatusCommandFrameworkAdvanced shows advanced features of the framework
|
||||
func TestStatusCommandFrameworkAdvanced(t *testing.T) {
|
||||
// Environment setup with privileges
|
||||
env := NewTestEnvironment().
|
||||
WithPrivileges(true).
|
||||
WithMockRunner()
|
||||
defer env.Cleanup()
|
||||
|
||||
// Complex test scenario with JSON output
|
||||
NewCommandTest(t, "status").
|
||||
WithArgs("sshd").
|
||||
WithJSONFormat().
|
||||
WithEnvironment(env).
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd", "apache"})
|
||||
mock.StatusJailData = map[string]string{
|
||||
"sshd": "Status for sshd jail",
|
||||
}
|
||||
}).
|
||||
ExpectSuccess().
|
||||
Run().
|
||||
AssertContains("Status for sshd jail").
|
||||
AssertNotContains("apache") // Should not contain other jail info
|
||||
|
||||
// Chained assertions example
|
||||
result := NewCommandTest(t, "status").
|
||||
WithArgs("all").
|
||||
WithSetup(func(mock *fail2ban.MockClient) {
|
||||
setMockJails(mock, []string{"sshd", "apache", "nginx"})
|
||||
mock.StatusAllData = "All jails status summary"
|
||||
}).
|
||||
ExpectSuccess().
|
||||
Run()
|
||||
|
||||
// Multiple assertions on same result
|
||||
result.AssertContains("All jails").
|
||||
AssertContains("status").
|
||||
AssertNotEmpty().
|
||||
AssertNotContains("error")
|
||||
}
|
||||
131
cmd/table_test_standards.go
Normal file
131
cmd/table_test_standards.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package cmd
|
||||
|
||||
// TableTestStandards defines standardized field names and patterns for table-driven tests
|
||||
// This file serves as documentation and helper types for consistent table test structure
|
||||
//
|
||||
// IMPLEMENTED STANDARDIZATION:
|
||||
// Successfully standardized field naming across all cmd test files:
|
||||
// - expectedOut/expectedOutput → wantOutput
|
||||
// - expectError → wantError
|
||||
// - expectedError → wantErrorMsg
|
||||
//
|
||||
// Files standardized:
|
||||
// - cmd_commands_test.go (multiple test functions)
|
||||
// - cmd_root_test.go (completion and execute tests)
|
||||
// - cmd_logswatch_test.go (logs watch tests)
|
||||
// - cmd_service_test.go (service command tests)
|
||||
// - status_command_refactored_test.go (status tests)
|
||||
|
||||
// StandardTestCase provides a standardized structure for basic table-driven tests
|
||||
type StandardTestCase struct {
|
||||
Name string `json:"name"` // Test case name - REQUIRED
|
||||
Args []string `json:"args"` // Command arguments
|
||||
WantOutput string `json:"wantOutput"` // Expected output content
|
||||
WantError bool `json:"wantError"` // Whether error is expected
|
||||
WantErrorMsg string `json:"wantErrorMsg"` // Specific error message to check (optional)
|
||||
}
|
||||
|
||||
// CommandTestCase extends StandardTestCase for command testing with mock setup
|
||||
type CommandTestCase struct {
|
||||
StandardTestCase
|
||||
Setup func(*MockClientBuilder) `json:"-"` // Setup function for mock configuration
|
||||
MockSetup func(interface{}) `json:"-"` // Generic mock setup function
|
||||
}
|
||||
|
||||
// ServiceTestCase specialized for service command testing
|
||||
type ServiceTestCase struct {
|
||||
StandardTestCase
|
||||
MockResponse string `json:"mockResponse"` // Mock service response
|
||||
MockError error `json:"mockError"` // Mock service error
|
||||
}
|
||||
|
||||
// LogsTestCase specialized for logs command testing
|
||||
type LogsTestCase struct {
|
||||
StandardTestCase
|
||||
MockLogs []string `json:"mockLogs"` // Mock log lines
|
||||
Limit int `json:"limit"` // Log limit
|
||||
}
|
||||
|
||||
// StandardFieldNames defines the recommended field naming conventions
|
||||
var StandardFieldNames = map[string]string{
|
||||
// Required fields
|
||||
"name": "name", // Test case identifier
|
||||
"args": "args", // Command arguments
|
||||
|
||||
// Output expectations
|
||||
"expectedOutput": "wantOutput", // Use wantOutput instead
|
||||
"expectedOut": "wantOutput", // Use wantOutput instead
|
||||
"expected": "wantOutput", // Use wantOutput instead
|
||||
|
||||
// Error expectations
|
||||
"expectError": "wantError", // Use wantError instead
|
||||
"isError": "wantError", // Use wantError instead
|
||||
"expectedError": "wantErrorMsg", // Use wantErrorMsg instead
|
||||
|
||||
// Setup patterns
|
||||
"setupMock": "setup", // Use setup instead
|
||||
"mockSetup": "setup", // Use setup instead
|
||||
"setupBanned": "setup", // Use setup instead
|
||||
"setupBans": "setup", // Use setup instead
|
||||
}
|
||||
|
||||
// ConversionGuide provides examples of before/after standardization
|
||||
type ConversionGuide struct {
|
||||
Before string
|
||||
After string
|
||||
Reason string
|
||||
}
|
||||
|
||||
// StandardizationExamples provides examples of before/after field name conversions
|
||||
var StandardizationExamples = []ConversionGuide{
|
||||
{
|
||||
Before: `tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedOut string
|
||||
expectError bool
|
||||
}`,
|
||||
After: `tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
wantOutput string
|
||||
wantError bool
|
||||
}`,
|
||||
Reason: "Consistent with Go testing conventions using 'want' prefix",
|
||||
},
|
||||
{
|
||||
Before: `if output != tt.expectedOut {
|
||||
t.Errorf("expected %q, got %q", tt.expectedOut, output)
|
||||
}`,
|
||||
After: `if output != tt.wantOutput {
|
||||
t.Errorf("expected %q, got %q", tt.wantOutput, output)
|
||||
}`,
|
||||
Reason: "Consistent field naming reduces cognitive load",
|
||||
},
|
||||
{
|
||||
Before: `AssertError(t, err, tt.expectError, tt.name)`,
|
||||
After: `AssertError(t, err, tt.wantError, tt.name)`,
|
||||
Reason: "Aligns with Go testing best practices",
|
||||
},
|
||||
}
|
||||
|
||||
// TestFieldValidator provides validation for test case structures
|
||||
type TestFieldValidator struct {
|
||||
RequiredFields []string
|
||||
RecommendedNames map[string]string
|
||||
}
|
||||
|
||||
// NewTestFieldValidator creates a validator with standard recommendations
|
||||
func NewTestFieldValidator() *TestFieldValidator {
|
||||
return &TestFieldValidator{
|
||||
RequiredFields: []string{"name"},
|
||||
RecommendedNames: StandardFieldNames,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateFieldNaming checks if field names follow standards (implementation would go here)
|
||||
func (v *TestFieldValidator) ValidateFieldNaming(_ string) []string {
|
||||
// This would contain logic to validate struct field names
|
||||
// For now, returns empty slice (implementation can be added later)
|
||||
return []string{}
|
||||
}
|
||||
117
cmd/test_helpers.go
Normal file
117
cmd/test_helpers.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// MockClient is a type alias for the enhanced MockClient from fail2ban package
|
||||
type MockClient = fail2ban.MockClient
|
||||
|
||||
// NewMockClient creates a new MockClient for testing
|
||||
func NewMockClient() *MockClient {
|
||||
return fail2ban.NewMockClient()
|
||||
}
|
||||
|
||||
// setMockJails sets jails for the enhanced MockClient
|
||||
func setMockJails(mock *MockClient, jails []string) {
|
||||
mock.Jails = make(map[string]struct{})
|
||||
for _, jail := range jails {
|
||||
mock.Jails[jail] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// setupMockEnvironment sets up mock sudo checker for testing
|
||||
func setupMockEnvironment() func() {
|
||||
// Set up mock sudo checker
|
||||
originalChecker := fail2ban.GetSudoChecker()
|
||||
mockChecker := &fail2ban.MockSudoChecker{
|
||||
MockHasPrivileges: true,
|
||||
ExplicitPrivilegesSet: true,
|
||||
}
|
||||
fail2ban.SetSudoChecker(mockChecker)
|
||||
|
||||
// Return cleanup function
|
||||
return func() {
|
||||
fail2ban.SetSudoChecker(originalChecker)
|
||||
}
|
||||
}
|
||||
|
||||
// executeCommand executes a command with proper mock setup and output capture
|
||||
func executeCommand(client fail2ban.Client, args ...string) (string, error) {
|
||||
// Suppress logrus output during tests
|
||||
oldLoggerOut := Logger.Out
|
||||
Logger.SetOutput(io.Discard)
|
||||
defer Logger.SetOutput(oldLoggerOut)
|
||||
|
||||
// Ensure mock sudo checker is set for commands that need it
|
||||
cleanup := setupMockEnvironment()
|
||||
defer cleanup()
|
||||
|
||||
rootCmd := &cobra.Command{Use: "f2b"}
|
||||
config := Config{Format: "plain"}
|
||||
|
||||
// Set up persistent flags like in the real root command
|
||||
rootCmd.PersistentFlags().StringVar(&config.Format, "format", config.Format, "Output format: plain or json")
|
||||
|
||||
rootCmd.AddCommand(ListJailsCmd(client, &config))
|
||||
rootCmd.AddCommand(StatusCmd(client, &config))
|
||||
rootCmd.AddCommand(BanCmd(client, &config))
|
||||
rootCmd.AddCommand(UnbanCmd(client, &config))
|
||||
rootCmd.AddCommand(TestIPCmd(client, &config))
|
||||
rootCmd.AddCommand(LogsCmd(client, &config))
|
||||
rootCmd.AddCommand(BannedCmd(client, &config))
|
||||
rootCmd.AddCommand(VersionCmd(&config))
|
||||
rootCmd.AddCommand(TestFilterCmd(client, &config))
|
||||
rootCmd.AddCommand(MetricsCmd(client, &config))
|
||||
|
||||
var buf bytes.Buffer
|
||||
rootCmd.SetOut(&buf)
|
||||
rootCmd.SetErr(&buf)
|
||||
rootCmd.SetArgs(args)
|
||||
err := rootCmd.Execute()
|
||||
|
||||
// Filter out logrus lines (starting with "time="), but keep "Error:" lines for error output tests
|
||||
lines := strings.Split(buf.String(), "\n")
|
||||
var filtered []string
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "time=") {
|
||||
filtered = append(filtered, line)
|
||||
}
|
||||
}
|
||||
// Remove trailing empty lines
|
||||
for len(filtered) > 0 && filtered[len(filtered)-1] == "" {
|
||||
filtered = filtered[:len(filtered)-1]
|
||||
}
|
||||
return strings.Join(filtered, "\n") + "\n", err
|
||||
}
|
||||
|
||||
// AssertError provides standardized error checking for command tests
|
||||
func AssertError(t interface {
|
||||
Helper()
|
||||
Fatalf(string, ...interface{})
|
||||
}, err error, expectError bool, testName string) {
|
||||
t.Helper()
|
||||
if expectError && err == nil {
|
||||
t.Fatalf("%s: expected error but got none", testName)
|
||||
}
|
||||
if !expectError && err != nil {
|
||||
t.Fatalf("%s: unexpected error: %v", testName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertOutputContains checks that output contains expected substring
|
||||
func AssertOutputContains(t interface {
|
||||
Helper()
|
||||
Fatalf(string, ...interface{})
|
||||
}, output, expectedSubstring, testName string) {
|
||||
t.Helper()
|
||||
if !strings.Contains(output, expectedSubstring) {
|
||||
t.Fatalf("%s: expected output containing %q but got %q", testName, expectedSubstring, output)
|
||||
}
|
||||
}
|
||||
33
cmd/testip.go
Normal file
33
cmd/testip.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// TestIPCmd returns the test command with injected client and config
|
||||
func TestIPCmd(client interface {
|
||||
BannedInWithContext(context.Context, string) ([]string, error)
|
||||
}, config *Config) *cobra.Command {
|
||||
return NewCommand("test <ip>", "Test if an IP is banned", nil, func(cmd *cobra.Command, args []string) error {
|
||||
// Create timeout context for testing IP
|
||||
ctx, cancel := context.WithTimeout(context.Background(), config.CommandTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Validate IP argument
|
||||
ip, err := ValidateIPArgument(args)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
jails, err := client.BannedInWithContext(ctx, ip)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
result := FormatBannedResult(ip, jails)
|
||||
PrintOutputTo(GetCmdOutput(cmd), result, config.Format)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
73
cmd/unban.go
Normal file
73
cmd/unban.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ivuorinen/f2b/fail2ban"
|
||||
)
|
||||
|
||||
// UnbanCmd returns the unban command with injected client and config
|
||||
func UnbanCmd(client fail2ban.Client, config *Config) *cobra.Command {
|
||||
return NewCommand(
|
||||
"unban <ip> [jail]",
|
||||
"Unban an IP address",
|
||||
[]string{"unbanip", "ub"},
|
||||
func(cmd *cobra.Command, args []string) error {
|
||||
// Get the contextual logger
|
||||
logger := GetContextualLogger()
|
||||
|
||||
// Create timeout context for the entire unban operation
|
||||
ctx, cancel := context.WithTimeout(context.Background(), config.CommandTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Add command context
|
||||
ctx = WithCommand(ctx, "unban")
|
||||
|
||||
// Log operation with timing
|
||||
return logger.LogOperation(ctx, "unban_command", func() error {
|
||||
// Validate IP argument
|
||||
ip, err := ValidateIPArgument(args)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
// Add IP to context
|
||||
ctx = WithIP(ctx, ip)
|
||||
|
||||
// Get jails from arguments or client (with timeout context)
|
||||
jails, err := GetJailsFromArgsWithContext(ctx, client, args, 1)
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
// Process unban operation with timeout context (use parallel processing for multiple jails)
|
||||
var results []OperationResult
|
||||
if len(jails) > 1 {
|
||||
// Use parallel timeout for multi-jail operations
|
||||
parallelCtx, parallelCancel := context.WithTimeout(ctx, config.ParallelTimeout)
|
||||
defer parallelCancel()
|
||||
results, err = ProcessUnbanOperationParallelWithContext(parallelCtx, client, ip, jails)
|
||||
} else {
|
||||
results, err = ProcessUnbanOperationWithContext(ctx, client, ip, jails)
|
||||
}
|
||||
if err != nil {
|
||||
return HandleClientError(err)
|
||||
}
|
||||
|
||||
// Output results
|
||||
if config != nil && config.Format == JSONFormat {
|
||||
PrintOutputTo(GetCmdOutput(cmd), results, JSONFormat)
|
||||
} else {
|
||||
for _, r := range results {
|
||||
if _, err := fmt.Fprintf(GetCmdOutput(cmd), "%s %s in %s\n", r.Status, r.IP, r.Jail); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
26
cmd/version.go
Normal file
26
cmd/version.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Version holds the build version and can be overridden at build time with ldflags
|
||||
var Version = "dev"
|
||||
|
||||
// VersionCmd returns the version command with output consistency
|
||||
func VersionCmd(config *Config) *cobra.Command {
|
||||
cmd := NewCommand("version", "Show f2b version", nil, func(cmd *cobra.Command, _ []string) error {
|
||||
PrintOutputTo(GetCmdOutput(cmd), fmt.Sprintf("f2b version %s", Version), config.Format)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Override Run to keep existing behavior (no error handling for version)
|
||||
cmd.Run = func(cmd *cobra.Command, _ []string) {
|
||||
PrintOutputTo(GetCmdOutput(cmd), fmt.Sprintf("f2b version %s", Version), config.Format)
|
||||
}
|
||||
cmd.RunE = nil
|
||||
|
||||
return cmd
|
||||
}
|
||||
Reference in New Issue
Block a user