diff --git a/modules/cassandra/cassandra.go b/modules/cassandra/cassandra.go index 4a847261a3..56d785ed28 100644 --- a/modules/cassandra/cassandra.go +++ b/modules/cassandra/cassandra.go @@ -2,28 +2,38 @@ package cassandra import ( "context" + "crypto/tls" "fmt" "io" "path/filepath" "strings" + "time" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" ) const ( - port = "9042/tcp" + port = nat.Port("9042/tcp") + securePort = nat.Port("9142/tcp") // Common port for SSL/TLS connections ) // CassandraContainer represents the Cassandra container type used in the module type CassandraContainer struct { testcontainers.Container + settings Options } -// ConnectionHost returns the host and port of the cassandra container, using the default, native 9042 port, and +// ConnectionHost returns the host and port of the cassandra container, using the default, native port, // obtaining the host and exposed port from the container func (c *CassandraContainer) ConnectionHost(ctx context.Context) (string, error) { - return c.PortEndpoint(ctx, port, "") + // Use the secure port if TLS is enabled + portToUse := port + if c.settings.tlsConfig != nil { + portToUse = securePort + } + + return c.PortEndpoint(ctx, portToUse, "") } // WithConfigFile sets the YAML config file to be used for the cassandra container @@ -37,7 +47,6 @@ func WithConfigFile(configFile string) testcontainers.CustomizeRequestOption { FileMode: 0o755, } req.Files = append(req.Files, cf) - return nil } } @@ -54,10 +63,8 @@ func WithInitScripts(scripts ...string) testcontainers.CustomizeRequestOption { FileMode: 0o755, } initScripts = append(initScripts, cf) - execs = append(execs, initScript{File: cf.ContainerFilePath}) } - req.Files = append(req.Files, initScripts...) return testcontainers.WithAfterReadyCommand(execs...)(req) } @@ -71,31 +78,54 @@ func RunContainer(ctx context.Context, opts ...testcontainers.ContainerCustomize // Run creates an instance of the Cassandra container type func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustomizer) (*CassandraContainer, error) { - moduleOpts := []testcontainers.ContainerCustomizer{ - testcontainers.WithExposedPorts(port), - testcontainers.WithEnv(map[string]string{ + req := testcontainers.ContainerRequest{ + Image: img, + Env: map[string]string{ "CASSANDRA_SNITCH": "GossipingPropertyFileSnitch", "JVM_OPTS": "-Dcassandra.skip_wait_for_gossip_to_settle=0 -Dcassandra.initial_token=0", "HEAP_NEWSIZE": "128M", "MAX_HEAP_SIZE": "1024M", "CASSANDRA_ENDPOINT_SNITCH": "GossipingPropertyFileSnitch", "CASSANDRA_DC": "datacenter1", - }), - testcontainers.WithWaitStrategy(wait.ForAll( - wait.ForListeningPort(port), - wait.ForExec([]string{"cqlsh", "-e", "SELECT bootstrapped FROM system.local"}).WithResponseMatcher(func(body io.Reader) bool { - data, _ := io.ReadAll(body) - return strings.Contains(string(data), "COMPLETED") - }), - )), + }, + ExposedPorts: []string{string(port)}, } - moduleOpts = append(moduleOpts, opts...) + genericContainerReq := testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + } - ctr, err := testcontainers.Run(ctx, img, moduleOpts...) + var settings Options + for _, opt := range opts { + if err := opt.Customize(&genericContainerReq); err != nil { + return nil, err + } + } + + // Set up wait strategies + waitStrategies := []wait.Strategy{ + wait.ForListeningPort(port), + wait.ForExec([]string{"cqlsh", "-e", "SELECT bootstrapped FROM system.local"}).WithResponseMatcher(func(body io.Reader) bool { + data, _ := io.ReadAll(body) + return strings.Contains(string(data), "COMPLETED") + }).WithStartupTimeout(1 * time.Minute), + } + + // Add TLS wait strategy if TLS config exists + if settings.tlsConfig != nil { + waitStrategies = append(waitStrategies, wait.ForListeningPort(securePort).WithStartupTimeout(1*time.Minute)) + } + + // Apply wait strategy using the correct method + if err := testcontainers.WithWaitStrategy(wait.ForAll(waitStrategies...)).Customize(&genericContainerReq); err != nil { + return nil, err + } + + container, err := testcontainers.GenericContainer(ctx, genericContainerReq) var c *CassandraContainer - if ctr != nil { - c = &CassandraContainer{Container: ctr} + if container != nil { + c = &CassandraContainer{Container: container, settings: settings} } if err != nil { @@ -104,3 +134,11 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom return c, nil } + +// TLSConfig returns the TLS configuration for the Cassandra container, nil if TLS is not enabled. +func (c *CassandraContainer) TLSConfig() *tls.Config { + if c.settings.tlsConfig == nil { + return nil + } + return c.settings.tlsConfig.Config +} diff --git a/modules/cassandra/cassandra_test.go b/modules/cassandra/cassandra_test.go index 4ca7385235..eb3a5448de 100644 --- a/modules/cassandra/cassandra_test.go +++ b/modules/cassandra/cassandra_test.go @@ -2,8 +2,10 @@ package cassandra_test import ( "context" + "fmt" "path/filepath" "testing" + "time" "github.com/gocql/gocql" "github.com/stretchr/testify/require" @@ -12,6 +14,8 @@ import ( "github.com/testcontainers/testcontainers-go/modules/cassandra" ) +const cassandraImage = "cassandra:4.1.3" + type Test struct { ID uint64 Name string @@ -20,7 +24,7 @@ type Test struct { func TestCassandra(t *testing.T) { ctx := context.Background() - ctr, err := cassandra.Run(ctx, "cassandra:4.1.3") + ctr, err := cassandra.Run(ctx, cassandraImage) testcontainers.CleanupContainer(t, ctr) require.NoError(t, err) @@ -52,7 +56,7 @@ func TestCassandra(t *testing.T) { func TestCassandraWithConfigFile(t *testing.T) { ctx := context.Background() - ctr, err := cassandra.Run(ctx, "cassandra:4.1.3", cassandra.WithConfigFile(filepath.Join("testdata", "config.yaml"))) + ctr, err := cassandra.Run(ctx, cassandraImage, cassandra.WithConfigFile(filepath.Join("testdata", "config.yaml"))) testcontainers.CleanupContainer(t, ctr) require.NoError(t, err) @@ -75,7 +79,7 @@ func TestCassandraWithInitScripts(t *testing.T) { ctx := context.Background() // withInitScripts { - ctr, err := cassandra.Run(ctx, "cassandra:4.1.3", cassandra.WithInitScripts(filepath.Join("testdata", "init.cql"))) + ctr, err := cassandra.Run(ctx, cassandraImage, cassandra.WithInitScripts(filepath.Join("testdata", "init.cql"))) // } testcontainers.CleanupContainer(t, ctr) require.NoError(t, err) @@ -99,7 +103,7 @@ func TestCassandraWithInitScripts(t *testing.T) { t.Run("with init bash script", func(t *testing.T) { ctx := context.Background() - ctr, err := cassandra.Run(ctx, "cassandra:4.1.3", cassandra.WithInitScripts(filepath.Join("testdata", "init.sh"))) + ctr, err := cassandra.Run(ctx, cassandraImage, cassandra.WithInitScripts(filepath.Join("testdata", "init.sh"))) testcontainers.CleanupContainer(t, ctr) require.NoError(t, err) @@ -117,3 +121,42 @@ func TestCassandraWithInitScripts(t *testing.T) { require.Equal(t, Test{ID: 1, Name: "NAME"}, test) }) } + +func TestCassandraSSL(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + container, err := cassandra.Run(ctx, cassandraImage, + cassandra.WithConfigFile(filepath.Join("testdata", "cassandra-ssl.yaml")), + cassandra.WithSSL(), + ) + testcontainers.CleanupContainer(t, container) + require.NoError(t, err) + + // Get TLS configurations + tlsConfig := container.TLSConfig() + + host, err := container.Host(ctx) + require.NoError(t, err) + + sslPort, err := container.MappedPort(ctx, "9142/tcp") + require.NoError(t, err) + + cluster := gocql.NewCluster(fmt.Sprintf("%s:%s", host, sslPort.Port())) + cluster.Consistency = gocql.Quorum + cluster.Timeout = 30 * time.Second + cluster.ConnectTimeout = 30 * time.Second + cluster.DisableInitialHostLookup = true + cluster.SslOpts = &gocql.SslOptions{ + Config: tlsConfig, + EnableHostVerification: false, + } + var session *gocql.Session + session, err = cluster.CreateSession() + require.NoError(t, err) + defer session.Close() + var version string + err = session.Query("SELECT release_version FROM system.local").Scan(&version) + require.NoError(t, err) + require.NotEmpty(t, version) +} diff --git a/modules/cassandra/examples_test.go b/modules/cassandra/examples_test.go index 68a80589ea..45e1454a8f 100644 --- a/modules/cassandra/examples_test.go +++ b/modules/cassandra/examples_test.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "path/filepath" + "time" "github.com/gocql/gocql" @@ -67,3 +68,64 @@ func ExampleRun() { // true // 4.1.3 } + +func ExampleRun_withSSL() { + ctx := context.Background() + + cassandraContainer, err := cassandra.Run(ctx, + "cassandra:4.1.3", + cassandra.WithConfigFile(filepath.Join("testdata", "cassandra-ssl.yaml")), + cassandra.WithSSL(), + ) + defer func() { + if err := testcontainers.TerminateContainer(cassandraContainer); err != nil { + log.Printf("failed to terminate container: %s", err) + } + }() + if err != nil { + log.Printf("failed to start container: %s", err) + return + } + + host, err := cassandraContainer.Host(ctx) + if err != nil { + log.Printf("failed to get host: %s", err) + return + } + + sslPort, err := cassandraContainer.MappedPort(ctx, "9142/tcp") + if err != nil { + log.Printf("failed to get SSL port: %s", err) + return + } + + // Get TLS config + tlsConfig := cassandraContainer.TLSConfig() + + cluster := gocql.NewCluster(fmt.Sprintf("%s:%s", host, sslPort.Port())) + cluster.Consistency = gocql.Quorum + cluster.Timeout = 30 * time.Second + cluster.ConnectTimeout = 30 * time.Second + cluster.DisableInitialHostLookup = true + cluster.SslOpts = &gocql.SslOptions{ + Config: tlsConfig, + EnableHostVerification: false, + } + session, err := cluster.CreateSession() + if err != nil { + log.Printf("failed to create session: %s", err) + return + } + defer session.Close() + + var version string + err = session.Query("SELECT release_version FROM system.local").Scan(&version) + if err != nil { + log.Printf("failed to query: %s", err) + return + } + + fmt.Println(version) + // Output: + // 4.1.3 +} diff --git a/modules/cassandra/options.go b/modules/cassandra/options.go new file mode 100644 index 0000000000..29a293bb6f --- /dev/null +++ b/modules/cassandra/options.go @@ -0,0 +1,127 @@ +package cassandra + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + + "github.com/testcontainers/testcontainers-go" +) + +// TLSConfig represents the TLS configuration for Cassandra +type TLSConfig struct { + KeystorePath string + CertificatePath string + Config *tls.Config +} + +// Options represents the configuration options for the Cassandra container +type Options struct { + tlsConfig *TLSConfig +} + +// Option is an option for the Cassandra container. +type Option func(*testcontainers.GenericContainerRequest, *Options) error + +// Customize implements the testcontainers.ContainerCustomizer interface +func (o Option) Customize(req *testcontainers.GenericContainerRequest) error { + return o(req, &Options{}) +} + +// WithSSL enables SSL/TLS support on the Cassandra container +func WithSSL() Option { + return func(req *testcontainers.GenericContainerRequest, settings *Options) error { + req.ExposedPorts = append(req.ExposedPorts, string(securePort)) + + keystorePath, certPath, err := GenerateJKSKeystore() + if err != nil { + return fmt.Errorf("create SSL certs: %w", err) + } + + req.Files = append(req.Files, + testcontainers.ContainerFile{ + HostFilePath: keystorePath, + ContainerFilePath: "/etc/cassandra/conf/keystore.jks", + FileMode: 0o644, + }, + testcontainers.ContainerFile{ + HostFilePath: certPath, + ContainerFilePath: "/etc/cassandra/conf/cassandra.crt", + FileMode: 0o644, + }) + + certPEM, err := os.ReadFile(certPath) + if err != nil { + return fmt.Errorf("error while read certificate: %w", err) + } + + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(certPEM) { + return errors.New("failed to append certificate to pool") + } + + settings.tlsConfig = &TLSConfig{ + KeystorePath: keystorePath, + CertificatePath: certPath, + Config: &tls.Config{ + RootCAs: certPool, + ServerName: "localhost", + MinVersion: tls.VersionTLS12, + }, + } + + return nil + } +} + +// GenerateJKSKeystore generates a JKS keystore with a self-signed cert using keytool, and extracts the public cert for Go client trust. +func GenerateJKSKeystore() (keystorePath, certPath string, err error) { + tmpDir := os.TempDir() + keystorePath = filepath.Join(tmpDir, "keystore.jks") + keystorePassword := "changeit" + certPath = filepath.Join(tmpDir, "cert.pem") + + // Remove existing keystore if it exists + os.Remove(keystorePath) + + // Generate keystore with self-signed certificate + cmd := exec.Command( + "keytool", "-genkeypair", + "-alias", "cassandra", + "-keyalg", "RSA", + "-keysize", "2048", + "-storetype", "JKS", + "-keystore", keystorePath, + "-storepass", keystorePassword, + "-keypass", keystorePassword, + "-dname", "CN=localhost, OU=Test, O=Test, C=US", + "-validity", "365", + ) + if err := cmd.Run(); err != nil { + return "", "", fmt.Errorf("failed to generate keystore: %w", err) + } + + // Export the public certificate for Go client trust + cmd = exec.Command( + "keytool", "-exportcert", + "-alias", "cassandra", + "-keystore", keystorePath, + "-storepass", keystorePassword, + "-rfc", + "-file", certPath, + ) + if err := cmd.Run(); err != nil { + return "", "", fmt.Errorf("failed to export certificate: %w", err) + } + + return keystorePath, certPath, nil +} + +// TLSConfig returns the TLS configuration +func (o *Options) TLSConfig() *TLSConfig { + return o.tlsConfig +} diff --git a/modules/cassandra/options_test.go b/modules/cassandra/options_test.go new file mode 100644 index 0000000000..450dbf59bf --- /dev/null +++ b/modules/cassandra/options_test.go @@ -0,0 +1,93 @@ +package cassandra_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/cassandra" +) + +func TestWithSSL(t *testing.T) { + // Create a test container request + req := &testcontainers.GenericContainerRequest{ + ContainerRequest: testcontainers.ContainerRequest{}, + } + opts := &cassandra.Options{} + + // Test that WithSSL configures TLS + err := cassandra.WithSSL()(req, opts) + require.NoError(t, err) + require.NotNil(t, opts.TLSConfig(), "TLS config should be set") + require.NotEmpty(t, opts.TLSConfig().KeystorePath, "Keystore path should be set") + require.NotEmpty(t, opts.TLSConfig().CertificatePath, "Certificate path should be set") + require.Contains(t, req.ExposedPorts, "9142/tcp", "Secure port should be exposed") + require.Len(t, req.Files, 2, "Should have keystore and certificate files") +} + +func TestGenerateJKSKeystore(t *testing.T) { + // Test keystore generation + keystorePath, certPath, err := cassandra.GenerateJKSKeystore() + require.NoError(t, err) + + // Verify that both files exist + _, err = os.Stat(keystorePath) + require.NoError(t, err, "keystore file should exist") + + _, err = os.Stat(certPath) + require.NoError(t, err, "certificate file should exist") + + // Verify file extensions + require.Equal(t, ".jks", filepath.Ext(keystorePath), "keystore should have .jks extension") + require.Equal(t, ".pem", filepath.Ext(certPath), "certificate should have .pem extension") + + // Clean up + os.Remove(keystorePath) + os.Remove(certPath) +} + +func TestGenerateJKSKeystoreOverwrite(t *testing.T) { + // Test that existing keystore is overwritten + keystorePath, certPath, err := cassandra.GenerateJKSKeystore() + require.NoError(t, err) + + // Get initial file info + initialKeystoreInfo, err := os.Stat(keystorePath) + require.NoError(t, err) + + // Generate new keystore + newKeystorePath, newCertPath, err := cassandra.GenerateJKSKeystore() + require.NoError(t, err) + + // Verify paths are the same + require.Equal(t, keystorePath, newKeystorePath) + require.Equal(t, certPath, newCertPath) + + // Get new file info + newKeystoreInfo, err := os.Stat(keystorePath) + require.NoError(t, err) + + // Verify that the file was modified + require.NotEqual(t, initialKeystoreInfo.ModTime(), newKeystoreInfo.ModTime(), "keystore should be overwritten") + + // Clean up + os.Remove(keystorePath) + os.Remove(certPath) +} + +func TestGenerateJKSKeystoreInvalidKeytool(t *testing.T) { + // Save original PATH + originalPath := os.Getenv("PATH") + defer os.Setenv("PATH", originalPath) + + // Set invalid PATH to make keytool unavailable + os.Setenv("PATH", "/nonexistent") + + // Test that keystore generation fails when keytool is not available + _, _, err := cassandra.GenerateJKSKeystore() + require.Error(t, err) + require.Contains(t, err.Error(), "failed to generate keystore") +} diff --git a/modules/cassandra/testdata/cassandra-ssl.yaml b/modules/cassandra/testdata/cassandra-ssl.yaml new file mode 100644 index 0000000000..2711dc78ef --- /dev/null +++ b/modules/cassandra/testdata/cassandra-ssl.yaml @@ -0,0 +1,131 @@ +# SSL Configuration for client connections +cluster_name: "My Cluster" +num_tokens: 16 +allocate_tokens_for_local_replication_factor: 3 +hinted_handoff_enabled: true +max_hint_window: 3h +hinted_handoff_throttle: 1024KiB +max_hints_delivery_threads: 2 +hints_flush_period: 10000ms +max_hints_file_size: 128MiB +auto_hints_cleanup_enabled: false +batchlog_replay_throttle: 1024KiB +authenticator: AllowAllAuthenticator +authorizer: AllowAllAuthorizer +role_manager: CassandraRoleManager +network_authorizer: AllowAllNetworkAuthorizer +roles_validity: 2000ms +permissions_validity: 2000ms +credentials_validity: 2000ms +partitioner: org.apache.cassandra.dht.Murmur3Partitioner +cdc_enabled: false +disk_failure_policy: stop +commit_failure_policy: stop +prepared_statements_cache_size: +key_cache_size: +key_cache_save_period: 4h +row_cache_size: 0MiB +row_cache_save_period: 0s +counter_cache_size: +counter_cache_save_period: 7200s +commitlog_sync: periodic +commitlog_sync_period: 10000ms +commitlog_segment_size: 32MiB +seed_provider: + - class_name: org.apache.cassandra.locator.SimpleSeedProvider + parameters: + - seeds: "192.168.215.2" +concurrent_reads: 32 +concurrent_writes: 32 +concurrent_counter_writes: 32 +concurrent_materialized_view_writes: 32 +memtable_allocation_type: heap_buffers +index_summary_capacity: +index_summary_resize_interval: 60m +trickle_fsync: false +trickle_fsync_interval: 10240KiB +storage_port: 7000 +ssl_storage_port: 7001 +listen_address: 192.168.215.2 +broadcast_address: 192.168.215.2 +start_native_transport: true +native_transport_port: 9042 +native_transport_allow_older_protocols: true +rpc_address: 0.0.0.0 +broadcast_rpc_address: 192.168.215.2 +rpc_keepalive: true +incremental_backups: false +snapshot_before_compaction: false +auto_snapshot: true +snapshot_links_per_second: 0 +column_index_size: 64KiB +column_index_cache_size: 2KiB +concurrent_materialized_view_builders: 1 +compaction_throughput: 64MiB/s +sstable_preemptive_open_interval: 50MiB +uuid_sstable_identifiers_enabled: false +read_request_timeout: 5000ms +range_request_timeout: 10000ms +write_request_timeout: 2000ms +counter_write_request_timeout: 5000ms +cas_contention_timeout: 1000ms +truncate_request_timeout: 60000ms +request_timeout: 10000ms +slow_query_log_timeout: 500ms +endpoint_snitch: SimpleSnitch +dynamic_snitch_update_interval: 100ms +dynamic_snitch_reset_interval: 600000ms +dynamic_snitch_badness_threshold: 1.0 +internode_compression: dc +inter_dc_tcp_nodelay: false +trace_type_query_ttl: 1d +trace_type_repair_ttl: 7d +user_defined_functions_enabled: false +scripted_user_defined_functions_enabled: false +transparent_data_encryption_options: + enabled: false + chunk_length_kb: 64 + cipher: AES/CBC/PKCS5Padding + key_alias: testing:1 + key_provider: + - class_name: org.apache.cassandra.security.JKSKeyProvider + parameters: + - keystore: conf/.keystore + keystore_password: cassandra + store_type: JCEKS + key_password: cassandra +tombstone_warn_threshold: 1000 +tombstone_failure_threshold: 100000 +replica_filtering_protection: + cached_rows_warn_threshold: 2000 + cached_rows_fail_threshold: 32000 +batch_size_warn_threshold: 5KiB +batch_size_fail_threshold: 50KiB +unlogged_batch_across_partitions_warn_threshold: 10 +compaction_large_partition_warning_threshold: 100MiB +compaction_tombstone_warning_threshold: 100000 +audit_logging_options: + enabled: false + logger: + - class_name: BinAuditLogger +diagnostic_events_enabled: false +repaired_data_tracking_for_range_reads_enabled: false +repaired_data_tracking_for_partition_reads_enabled: false +report_unconfirmed_repaired_data_mismatches: false +materialized_views_enabled: false +sasi_indexes_enabled: false +transient_replication_enabled: false +drop_compact_storage_enabled: false +client_encryption_options: + enabled: true + optional: false + keystore: /etc/cassandra/conf/keystore.jks + keystore_password: changeit + require_client_auth: false + protocol: TLS + algorithm: SunX509 + store_type: JKS + cipher_suites: [TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_256_CBC_SHA] + +# Enable SSL port +native_transport_port_ssl: 9142