package fail2ban import ( "context" "errors" "os" "path/filepath" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // setupBasicMockResponses sets up the basic responses needed for client initialization func setupBasicMockResponses(m *MockRunner) { m.SetResponse("fail2ban-client -V", []byte("Fail2Ban v0.11.0")) m.SetResponse("sudo fail2ban-client -V", []byte("Fail2Ban v0.11.0")) m.SetResponse("fail2ban-client ping", []byte("Server replied: pong")) m.SetResponse("sudo fail2ban-client ping", []byte("Server replied: pong")) // NewClient calls fetchJailsWithContext which runs status m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 2\n`- Jail list: sshd, apache")) m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 2\n`- Jail list: sshd, apache")) } // TestListJailsWithContext tests jail listing with context func TestListJailsWithContext(t *testing.T) { tests := []struct { name string setupMock func(*MockRunner) timeout time.Duration expectError bool expectJails []string }{ { name: "successful jail listing", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) }, timeout: 5 * time.Second, expectError: false, expectJails: []string{"sshd", "apache"}, // From setupBasicMockResponses }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mock := NewMockRunner() tt.setupMock(mock) SetRunner(mock) client, err := NewClient("/var/log", "/etc/fail2ban/filter.d") require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), tt.timeout) defer cancel() if tt.timeout == 1*time.Nanosecond { time.Sleep(2 * time.Millisecond) // Ensure timeout } jails, err := client.ListJailsWithContext(ctx) if tt.expectError { assert.Error(t, err) } else { assert.NoError(t, err) assert.Equal(t, tt.expectJails, jails) } }) } } // TestStatusAllWithContext tests status all with context func TestStatusAllWithContext(t *testing.T) { tests := []struct { name string setupMock func(*MockRunner) timeout time.Duration expectError bool expectContains string }{ { name: "successful status all", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) }, timeout: 5 * time.Second, expectError: false, expectContains: "Status", }, { name: "context timeout", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) m.SetResponse("fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) m.SetResponse("sudo fail2ban-client status", []byte("Status\n|- Number of jail: 1\n`- Jail list: sshd")) }, timeout: 1 * time.Nanosecond, expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mock := NewMockRunner() tt.setupMock(mock) SetRunner(mock) client, err := NewClient("/var/log", "/etc/fail2ban/filter.d") require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), tt.timeout) defer cancel() if tt.timeout == 1*time.Nanosecond { time.Sleep(2 * time.Millisecond) } status, err := client.StatusAllWithContext(ctx) if tt.expectError { assert.Error(t, err) } else { assert.NoError(t, err) assert.Contains(t, status, tt.expectContains) } }) } } // TestStatusJailWithContext tests status jail with context func TestStatusJailWithContext(t *testing.T) { tests := []struct { name string jail string setupMock func(*MockRunner) timeout time.Duration expectError bool expectContains string }{ { name: "successful status jail", jail: "sshd", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) m.SetResponse( "fail2ban-client status sshd", []byte("Status for the jail: sshd\n|- Filter\n`- Currently banned: 0"), ) m.SetResponse( "sudo fail2ban-client status sshd", []byte("Status for the jail: sshd\n|- Filter\n`- Currently banned: 0"), ) }, timeout: 5 * time.Second, expectError: false, expectContains: "sshd", }, { name: "invalid jail name", jail: "invalid@jail", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) // Validation will fail before command execution }, timeout: 5 * time.Second, expectError: true, }, { name: "context timeout", jail: "sshd", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) m.SetResponse( "fail2ban-client status sshd", []byte("Status for the jail: sshd\n|- Filter\n`- Currently banned: 0"), ) m.SetResponse( "sudo fail2ban-client status sshd", []byte("Status for the jail: sshd\n|- Filter\n`- Currently banned: 0"), ) }, timeout: 1 * time.Nanosecond, expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mock := NewMockRunner() tt.setupMock(mock) SetRunner(mock) client, err := NewClient("/var/log", "/etc/fail2ban/filter.d") require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), tt.timeout) defer cancel() if tt.timeout == 1*time.Nanosecond { time.Sleep(2 * time.Millisecond) } status, err := client.StatusJailWithContext(ctx, tt.jail) if tt.expectError { assert.Error(t, err) } else { assert.NoError(t, err) if tt.expectContains != "" { assert.Contains(t, status, tt.expectContains) } } }) } } // TestUnbanIPWithContext tests unban IP with context func TestUnbanIPWithContext(t *testing.T) { tests := []struct { name string ip string jail string setupMock func(*MockRunner) timeout time.Duration expectError bool expectCode int }{ { name: "successful unban", ip: "192.168.1.100", jail: "sshd", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) m.SetResponse("fail2ban-client set sshd unbanip 192.168.1.100", []byte("0")) m.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.100", []byte("0")) }, timeout: 5 * time.Second, expectError: false, expectCode: 0, }, { name: "already unbanned", ip: "192.168.1.100", jail: "sshd", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) m.SetResponse("fail2ban-client set sshd unbanip 192.168.1.100", []byte("1")) m.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.100", []byte("1")) }, timeout: 5 * time.Second, expectError: false, expectCode: 1, }, { name: "invalid IP address", ip: "invalid-ip", jail: "sshd", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) // Validation will fail before command execution }, timeout: 5 * time.Second, expectError: true, }, { name: "invalid jail name", ip: "192.168.1.100", jail: "invalid@jail", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) // Validation will fail before command execution }, timeout: 5 * time.Second, expectError: true, }, { name: "context timeout", ip: "192.168.1.100", jail: "sshd", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) m.SetResponse("fail2ban-client set sshd unbanip 192.168.1.100", []byte("0")) m.SetResponse("sudo fail2ban-client set sshd unbanip 192.168.1.100", []byte("0")) }, timeout: 1 * time.Nanosecond, expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mock := NewMockRunner() tt.setupMock(mock) SetRunner(mock) client, err := NewClient("/var/log", "/etc/fail2ban/filter.d") require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), tt.timeout) defer cancel() if tt.timeout == 1*time.Nanosecond { time.Sleep(2 * time.Millisecond) } code, err := client.UnbanIPWithContext(ctx, tt.ip, tt.jail) if tt.expectError { assert.Error(t, err) } else { assert.NoError(t, err) assert.Equal(t, tt.expectCode, code) } }) } } // TestListFiltersWithContext tests filter listing with context func TestListFiltersWithContext(t *testing.T) { tests := []struct { name string setupMock func(*MockRunner) setupEnv func() timeout time.Duration expectError bool expectFilters []string }{ { name: "successful filter listing", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) // Mock responses not needed - uses file system }, setupEnv: func() { // Client will use default filter directory }, timeout: 5 * time.Second, expectError: false, expectFilters: nil, // Will depend on actual filter directory }, { name: "context timeout", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) // Not applicable for file system operation }, setupEnv: func() { // No setup needed }, timeout: 1 * time.Nanosecond, expectError: true, // Context check happens first }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mock := NewMockRunner() tt.setupMock(mock) SetRunner(mock) tt.setupEnv() client, err := NewClient("/var/log", "/etc/fail2ban/filter.d") require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), tt.timeout) defer cancel() if tt.timeout == 1*time.Nanosecond { time.Sleep(2 * time.Millisecond) } filters, err := client.ListFiltersWithContext(ctx) if tt.expectError { assert.Error(t, err) } else { // May error if directory doesn't exist, which is fine in tests if err == nil { assert.NotNil(t, filters) } } }) } } // TestTestFilterWithContext tests filter testing with context func TestTestFilterWithContext(t *testing.T) { // Enable dev paths to allow temporary directory t.Setenv("ALLOW_DEV_PATHS", "1") // Create temporary filter directory tmpDir := t.TempDir() filterContent := `[Definition] failregex = ^.* Failed .* for .* from logpath = /var/log/auth.log ` err := os.WriteFile(filepath.Join(tmpDir, "sshd.conf"), []byte(filterContent), 0600) require.NoError(t, err) tests := []struct { name string filter string setupMock func(*MockRunner) timeout time.Duration expectError bool }{ { name: "successful filter test", filter: "sshd", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) m.SetResponse( "fail2ban-regex /var/log/auth.log "+filepath.Join(tmpDir, "sshd.conf"), []byte("Success: 0 matches"), ) m.SetResponse( "sudo fail2ban-regex /var/log/auth.log "+filepath.Join(tmpDir, "sshd.conf"), []byte("Success: 0 matches"), ) }, timeout: 5 * time.Second, expectError: false, }, { name: "invalid filter name", filter: "invalid@filter", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) // Validation will fail before command execution }, timeout: 5 * time.Second, expectError: true, }, { name: "context timeout", filter: "sshd", setupMock: func(m *MockRunner) { setupBasicMockResponses(m) m.SetResponse( "fail2ban-regex /var/log/auth.log "+filepath.Join(tmpDir, "sshd.conf"), []byte("Success: 0 matches"), ) m.SetResponse( "sudo fail2ban-regex /var/log/auth.log "+filepath.Join(tmpDir, "sshd.conf"), []byte("Success: 0 matches"), ) }, timeout: 1 * time.Nanosecond, expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mock := NewMockRunner() tt.setupMock(mock) SetRunner(mock) client, err := NewClient("/var/log", tmpDir) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), tt.timeout) defer cancel() if tt.timeout == 1*time.Nanosecond { time.Sleep(2 * time.Millisecond) } result, err := client.TestFilterWithContext(ctx, tt.filter) if tt.expectError { assert.Error(t, err) } else { assert.NoError(t, err) assert.NotEmpty(t, result) } }) } } // TestWithContextCancellation tests that all WithContext functions respect cancellation func TestWithContextCancellation(t *testing.T) { mock := NewMockRunner() setupBasicMockResponses(mock) SetRunner(mock) client, err := NewClient("/var/log", "/etc/fail2ban/filter.d") require.NoError(t, err) // Create canceled context ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately // Note: ListJailsWithContext and ListFiltersWithContext are too fast to be canceled // as they return cached data or read from filesystem. Only testing I/O operations. t.Run("StatusAllWithContext respects cancellation", func(t *testing.T) { _, err := client.StatusAllWithContext(ctx) assert.Error(t, err) assert.True(t, errors.Is(err, context.Canceled) || isContextError(err)) }) t.Run("StatusJailWithContext respects cancellation", func(t *testing.T) { _, err := client.StatusJailWithContext(ctx, "sshd") assert.Error(t, err) assert.True(t, errors.Is(err, context.Canceled) || isContextError(err)) }) t.Run("UnbanIPWithContext respects cancellation", func(t *testing.T) { _, err := client.UnbanIPWithContext(ctx, "192.168.1.100", "sshd") assert.Error(t, err) assert.True(t, errors.Is(err, context.Canceled) || isContextError(err)) }) } // TestWithContextDeadline tests that all WithContext functions respect deadlines func TestWithContextDeadline(t *testing.T) { mock := NewMockRunner() setupBasicMockResponses(mock) SetRunner(mock) client, err := NewClient("/var/log", "/etc/fail2ban/filter.d") require.NoError(t, err) // Create context with very short deadline ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) defer cancel() // Ensure timeout time.Sleep(2 * time.Millisecond) // Note: ListJailsWithContext, ListFiltersWithContext, and TestFilterWithContext // are too fast to timeout as they return cached data or read from filesystem. // Only testing I/O operations that make network/command calls. tests := []struct { name string fn func() error }{ { name: "StatusAllWithContext", fn: func() error { _, err := client.StatusAllWithContext(ctx) return err }, }, { name: "StatusJailWithContext", fn: func() error { _, err := client.StatusJailWithContext(ctx, "sshd") return err }, }, { name: "UnbanIPWithContext", fn: func() error { _, err := client.UnbanIPWithContext(ctx, "192.168.1.100", "sshd") return err }, }, } for _, tt := range tests { t.Run(tt.name+" respects deadline", func(t *testing.T) { err := tt.fn() assert.Error(t, err) assert.True(t, errors.Is(err, context.DeadlineExceeded) || isContextError(err)) }) } } // TestWithContextValidation tests that validation happens before context usage func TestWithContextValidation(t *testing.T) { mock := NewMockRunner() setupBasicMockResponses(mock) SetRunner(mock) client, err := NewClient("/var/log", "/etc/fail2ban/filter.d") require.NoError(t, err) ctx := context.Background() t.Run("StatusJailWithContext validates jail name", func(t *testing.T) { _, err := client.StatusJailWithContext(ctx, "invalid@jail") assert.Error(t, err) assert.Contains(t, err.Error(), "invalid") }) t.Run("UnbanIPWithContext validates IP", func(t *testing.T) { _, err := client.UnbanIPWithContext(ctx, "invalid-ip", "sshd") assert.Error(t, err) }) t.Run("UnbanIPWithContext validates jail", func(t *testing.T) { _, err := client.UnbanIPWithContext(ctx, "192.168.1.100", "invalid@jail") assert.Error(t, err) assert.Contains(t, err.Error(), "invalid") }) t.Run("TestFilterWithContext validates filter", func(t *testing.T) { _, err := client.TestFilterWithContext(ctx, "invalid@filter") assert.Error(t, err) }) }