diff --git a/.gitignore b/.gitignore index 5d4c8caa..bad788bd 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,7 @@ vendor __debug_bin .vscode /proxy/proxy -/proxy/*.prof \ No newline at end of file +/proxy/*.prof + +# Claude Code +CLAUDE.md \ No newline at end of file diff --git a/README.md b/README.md index 70fa60fd..898ad37d 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ ZDM_PROXY_LISTEN_ADDRESS=127.0.0.1 ZDM_PRIMARY_CLUSTER=ORIGIN ZDM_READ_MODE=PRIMARY_ONLY ZDM_LOG_LEVEL=INFO +# ZDM_TARGET_CONSISTENCY_LEVEL=LOCAL_ONE #optional, overrides CL on target during migration ``` The environment variables (or YAM configuration file) must be set for the proxy to work. diff --git a/docs/assets/zdm-config-reference.yml b/docs/assets/zdm-config-reference.yml index e9d239f0..a2edebb8 100644 --- a/docs/assets/zdm-config-reference.yml +++ b/docs/assets/zdm-config-reference.yml @@ -55,7 +55,10 @@ origin_port: 9042 # Local data center for origin cluster. # origin_local_datacenter: -# Origin cluster username. +# Origin cluster username. Avoid using a superuser account for application workloads. +# Superuser authentication in Cassandra requires QUORUM consistency internally, which +# increases the risk of auth failures during node instability. The proxy will log a +# warning at startup if the configured user is a superuser. origin_username: user1 # Origin cluster password. @@ -89,7 +92,10 @@ target_contact_points: 127.0.0.2 # Port used when connecting to nodes from target cluster. target_port: 9042 -# Target cluster username. +# Target cluster username. Avoid using a superuser account for application workloads. +# Superuser authentication in Cassandra requires QUORUM consistency internally, which +# increases the risk of auth failures during node instability. The proxy will log a +# warning at startup if the configured user is a superuser. target_username: user2 # Target cluster password. 
@@ -166,6 +172,15 @@ proxy_listen_port: 14002 # List of histogram buckets for measuring latency of asynchronous # read requests routed to target cluster. See parameter "read_mode". # metrics_async_read_latency_buckets_ms: 1, 4, 7, 10, 25, 40, 60, 80, 100, 150, 250, 500, 1000, 2500, 5000, 10000, 15000 +# +# Per-table write success metric (automatically populated, no configuration needed): +# The proxy exposes a Prometheus counter "proxy_write_success_total" with labels +# {cluster="origin|target", keyspace="<keyspace>", table="<table>"} that tracks +# successful writes per cluster, keyspace, and table. This counter is incremented +# independently when each cluster responds successfully, providing visibility into +# which tables are being written to and whether both clusters are keeping up. +# During a target cluster outage, origin counters continue to increment while target +# counters flatline, making it easy to identify the scope of any data divergence. # Frequency (in ms) with which heartbeats will be sent on cluster connections # (i.e. all control and request connections to Origin and Target). Heartbeats @@ -180,3 +195,22 @@ proxy_listen_port: 14002 # Control connection failure threshold. If threshold is exceeded, # readiness probe of ZDM will report failure and pod will be recreated. # heartbeat_failure_threshold: 1 + +# Override the consistency level used for all requests forwarded to the target cluster. +# When this property is set, the proxy replaces the client-requested consistency level with the +# specified value on every request sent to the target cluster (reads and writes). The origin cluster +# always receives the original client-requested consistency level, preserving the consistency +# contract on the source of truth. +# +# This is useful during migration when the target cluster is being populated via dual writes. Using +# a weaker consistency level such as LOCAL_ONE on the target reduces the risk of write failures +# caused by target-side instability (e.g. 
node outages, streaming, or compaction pressure). Because +# the target data can be repaired after migration is complete, temporary under-replication is +# acceptable and preferable to failing writes that would otherwise succeed on origin. +# +# When this property is absent, empty, or not set, the proxy forwards requests to the target with the +# original client-requested consistency level (default behavior, no override). +# +# Valid values: ANY, ONE, TWO, THREE, QUORUM, ALL, LOCAL_QUORUM, EACH_QUORUM, LOCAL_ONE +# (case-insensitive). Serial consistency levels (SERIAL, LOCAL_SERIAL) are not valid here. +# target_consistency_level: LOCAL_ONE diff --git a/integration-tests/per_table_write_metrics_ccm_test.go b/integration-tests/per_table_write_metrics_ccm_test.go new file mode 100644 index 00000000..c1a4e2a8 --- /dev/null +++ b/integration-tests/per_table_write_metrics_ccm_test.go @@ -0,0 +1,295 @@ +package integration_tests + +import ( + "fmt" + "net/http" + "strings" + "testing" + + gocql "github.com/apache/cassandra-gocql-driver/v2" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/datastax/zdm-proxy/integration-tests/utils" +) + +const metricsTable = "metrics_test_data" +const countersTable = "metrics_test_counters" +const countersTable2 = "metrics_test_counters2" +const batchTableA = "metrics_test_batch_a" +const batchTableB = "metrics_test_batch_b" + +// TestPerTableWriteMetricsCCM tests per-table write success metrics against real Cassandra clusters via CCM. +// Full permutation matrix: +// - Statement types: INSERT, UPDATE, DELETE, counter UPDATE +// - Execution modes: inline (Query), prepared (Prepare+Execute) +// - Batch modes: batch with inline children, batch with prepared children, batch with mixed children +// +// Each test verifies that the Prometheus metric is tracked independently for origin and target. 
+func TestPerTableWriteMetricsCCM(t *testing.T) { + proxyInstance, err := NewProxyInstanceForGlobalCcmClusters(t) + require.Nil(t, err) + defer proxyInstance.Shutdown() + + // Start a dedicated metrics HTTP server for this test + const metricsAddr = "localhost:14099" + mux := http.NewServeMux() + mux.Handle("/metrics", proxyInstance.GetMetricHandler().GetHttpHandler()) + metricsSrv := &http.Server{Addr: metricsAddr, Handler: mux} + go func() { + if err := metricsSrv.ListenAndServe(); err != http.ErrServerClosed { + log.Warnf("metrics server error: %v", err) + } + }() + defer metricsSrv.Close() + + originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters(t) + require.Nil(t, err) + + targetSession := targetCluster.GetSession() + + // Create test tables on both clusters + createTables := func(s *gocql.Session) { + s.Query(fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", setup.TestKeyspace, metricsTable)).Exec() + s.Query(fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", setup.TestKeyspace, countersTable)).Exec() + s.Query(fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", setup.TestKeyspace, countersTable2)).Exec() + s.Query(fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", setup.TestKeyspace, batchTableA)).Exec() + s.Query(fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", setup.TestKeyspace, batchTableB)).Exec() + + require.Nil(t, s.Query(fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %s.%s (id uuid PRIMARY KEY, name text)", setup.TestKeyspace, metricsTable)).Exec()) + require.Nil(t, s.Query(fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %s.%s (id uuid PRIMARY KEY, count counter)", setup.TestKeyspace, countersTable)).Exec()) + require.Nil(t, s.Query(fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %s.%s (id uuid PRIMARY KEY, count counter)", setup.TestKeyspace, countersTable2)).Exec()) + require.Nil(t, s.Query(fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %s.%s (id uuid PRIMARY KEY, val text)", setup.TestKeyspace, batchTableA)).Exec()) + require.Nil(t, s.Query(fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %s.%s (id 
uuid PRIMARY KEY, val text)", setup.TestKeyspace, batchTableB)).Exec()) + } + createTables(originCluster.GetSession()) + createTables(targetSession) + + // Connect to proxy + proxy, err := utils.ConnectToCluster("127.0.0.1", "", "", 14002) + require.Nil(t, err) + defer proxy.Close() + + ks := setup.TestKeyspace + + // ================================================================ + // INLINE STATEMENTS + // ================================================================ + + t.Run("inline_insert", func(t *testing.T) { + err = proxy.Query(fmt.Sprintf( + "INSERT INTO %s.%s (id, name) VALUES (d1b05da0-8c20-11ea-9fc6-6d2c86545d91, 'alice')", ks, metricsTable)).Exec() + require.Nil(t, err) + assertMetricOnBothClusters(t, ks, metricsTable) + }) + + t.Run("inline_update", func(t *testing.T) { + err = proxy.Query(fmt.Sprintf( + "UPDATE %s.%s SET name = 'updated' WHERE id = d1b05da0-8c20-11ea-9fc6-6d2c86545d91", ks, metricsTable)).Exec() + require.Nil(t, err) + assertMetricOnBothClusters(t, ks, metricsTable) + }) + + t.Run("inline_delete", func(t *testing.T) { + err = proxy.Query(fmt.Sprintf( + "DELETE FROM %s.%s WHERE id = d1b05da0-8c20-11ea-9fc6-6d2c86545d91", ks, metricsTable)).Exec() + require.Nil(t, err) + assertMetricOnBothClusters(t, ks, metricsTable) + }) + + t.Run("inline_counter_update", func(t *testing.T) { + err = proxy.Query(fmt.Sprintf( + "UPDATE %s.%s SET count = count + 1 WHERE id = d1b05da0-8c20-11ea-9fc6-6d2c86545d91", ks, countersTable)).Exec() + require.Nil(t, err) + assertMetricOnBothClusters(t, ks, countersTable) + }) + + // ================================================================ + // PREPARED STATEMENTS + // ================================================================ + + t.Run("prepared_insert", func(t *testing.T) { + q := proxy.Query(fmt.Sprintf("INSERT INTO %s.%s (id, name) VALUES (?, ?)", ks, metricsTable)) + q.Bind("eed574b0-8c20-11ea-9fc6-6d2c86545d91", "prepared_alice") + require.Nil(t, q.Exec()) + assertMetricOnBothClusters(t, 
ks, metricsTable) + }) + + t.Run("prepared_update", func(t *testing.T) { + q := proxy.Query(fmt.Sprintf("UPDATE %s.%s SET name = ? WHERE id = ?", ks, metricsTable)) + q.Bind("prepared_updated", "eed574b0-8c20-11ea-9fc6-6d2c86545d91") + require.Nil(t, q.Exec()) + assertMetricOnBothClusters(t, ks, metricsTable) + }) + + t.Run("prepared_delete", func(t *testing.T) { + q := proxy.Query(fmt.Sprintf("DELETE FROM %s.%s WHERE id = ?", ks, metricsTable)) + q.Bind("eed574b0-8c20-11ea-9fc6-6d2c86545d91") + require.Nil(t, q.Exec()) + assertMetricOnBothClusters(t, ks, metricsTable) + }) + + t.Run("prepared_counter_update", func(t *testing.T) { + q := proxy.Query(fmt.Sprintf("UPDATE %s.%s SET count = count + ? WHERE id = ?", ks, countersTable)) + q.Bind(int64(5), "eed574b0-8c20-11ea-9fc6-6d2c86545d91") + require.Nil(t, q.Exec()) + assertMetricOnBothClusters(t, ks, countersTable) + }) + + // ================================================================ + // BATCH WITH INLINE CHILDREN + // ================================================================ + + t.Run("batch_inline_inserts_multi_table", func(t *testing.T) { + batch := proxy.NewBatch(gocql.LoggedBatch) + batch.Query(fmt.Sprintf("INSERT INTO %s.%s (id, val) VALUES (cf0f4cf0-8c20-11ea-9fc6-6d2c86545d91, 'a')", ks, batchTableA)) + batch.Query(fmt.Sprintf("INSERT INTO %s.%s (id, val) VALUES (cf0f4cf0-8c20-11ea-9fc6-6d2c86545d92, 'b')", ks, batchTableB)) + require.Nil(t, proxy.ExecuteBatch(batch)) + assertMetricOnBothClusters(t, ks, batchTableA) + assertMetricOnBothClusters(t, ks, batchTableB) + }) + + t.Run("batch_inline_update_and_delete", func(t *testing.T) { + batch := proxy.NewBatch(gocql.LoggedBatch) + batch.Query(fmt.Sprintf("UPDATE %s.%s SET val = 'updated' WHERE id = cf0f4cf0-8c20-11ea-9fc6-6d2c86545d91", ks, batchTableA)) + batch.Query(fmt.Sprintf("DELETE FROM %s.%s WHERE id = cf0f4cf0-8c20-11ea-9fc6-6d2c86545d92", ks, batchTableB)) + require.Nil(t, proxy.ExecuteBatch(batch)) + assertMetricOnBothClusters(t, ks, 
batchTableA) + assertMetricOnBothClusters(t, ks, batchTableB) + }) + + // ================================================================ + // BATCH WITH PREPARED CHILDREN + // gocql automatically prepares statements when batch.Query() is + // called with bind parameters — the batch children become prepared + // statement IDs, not inline query strings. + // ================================================================ + + t.Run("batch_prepared_inserts_multi_table", func(t *testing.T) { + batch := proxy.NewBatch(gocql.LoggedBatch) + batch.Query(fmt.Sprintf("INSERT INTO %s.%s (id, val) VALUES (?, ?)", ks, batchTableA), "cf0f4cf0-8c20-11ea-9fc6-6d2c86545da1", "prep_a") + batch.Query(fmt.Sprintf("INSERT INTO %s.%s (id, val) VALUES (?, ?)", ks, batchTableB), "cf0f4cf0-8c20-11ea-9fc6-6d2c86545da2", "prep_b") + require.Nil(t, proxy.ExecuteBatch(batch)) + assertMetricOnBothClusters(t, ks, batchTableA) + assertMetricOnBothClusters(t, ks, batchTableB) + }) + + t.Run("batch_prepared_update_and_delete", func(t *testing.T) { + batch := proxy.NewBatch(gocql.LoggedBatch) + batch.Query(fmt.Sprintf("UPDATE %s.%s SET val = ? 
WHERE id = ?", ks, batchTableA), "batch_updated", "cf0f4cf0-8c20-11ea-9fc6-6d2c86545da1") + batch.Query(fmt.Sprintf("DELETE FROM %s.%s WHERE id = ?", ks, batchTableB), "cf0f4cf0-8c20-11ea-9fc6-6d2c86545da2") + require.Nil(t, proxy.ExecuteBatch(batch)) + assertMetricOnBothClusters(t, ks, batchTableA) + assertMetricOnBothClusters(t, ks, batchTableB) + }) + + // ================================================================ + // BATCH WITH MIXED INLINE AND PREPARED CHILDREN + // ================================================================ + + t.Run("batch_mixed_inline_and_prepared", func(t *testing.T) { + batch := proxy.NewBatch(gocql.LoggedBatch) + // Inline child (no bind params) + batch.Query(fmt.Sprintf("INSERT INTO %s.%s (id, val) VALUES (cf0f4cf0-8c20-11ea-9fc6-6d2c86545db1, 'inline')", ks, batchTableA)) + // Prepared child (with bind params — gocql will prepare this) + batch.Query(fmt.Sprintf("INSERT INTO %s.%s (id, val) VALUES (?, ?)", ks, batchTableB), "cf0f4cf0-8c20-11ea-9fc6-6d2c86545db2", "prepared") + require.Nil(t, proxy.ExecuteBatch(batch)) + assertMetricOnBothClusters(t, ks, batchTableA) + assertMetricOnBothClusters(t, ks, batchTableB) + }) + + // ================================================================ + // COUNTER BATCH (inline) + // ================================================================ + + t.Run("batch_counter_inline", func(t *testing.T) { + batch := proxy.NewBatch(gocql.CounterBatch) + batch.Query(fmt.Sprintf("UPDATE %s.%s SET count = count + 1 WHERE id = d1b05da0-8c20-11ea-9fc6-6d2c86545d91", ks, countersTable)) + batch.Query(fmt.Sprintf("UPDATE %s.%s SET count = count + 1 WHERE id = d1b05da0-8c20-11ea-9fc6-6d2c86545d91", ks, countersTable2)) + require.Nil(t, proxy.ExecuteBatch(batch)) + assertMetricOnBothClusters(t, ks, countersTable) + assertMetricOnBothClusters(t, ks, countersTable2) + }) + + // ================================================================ + // COUNTER BATCH (prepared — gocql prepares when bind 
params are used) + // ================================================================ + + t.Run("batch_counter_prepared", func(t *testing.T) { + batch := proxy.NewBatch(gocql.CounterBatch) + batch.Query(fmt.Sprintf("UPDATE %s.%s SET count = count + ? WHERE id = ?", ks, countersTable), int64(3), "eed574b0-8c20-11ea-9fc6-6d2c86545d91") + batch.Query(fmt.Sprintf("UPDATE %s.%s SET count = count + ? WHERE id = ?", ks, countersTable2), int64(3), "eed574b0-8c20-11ea-9fc6-6d2c86545d91") + require.Nil(t, proxy.ExecuteBatch(batch)) + assertMetricOnBothClusters(t, ks, countersTable) + assertMetricOnBothClusters(t, ks, countersTable2) + }) + + // ================================================================ + // DATA VERIFICATION ON BOTH CLUSTERS + // ================================================================ + + t.Run("verify_counter_on_target", func(t *testing.T) { + var count int64 + err := targetSession.Query(fmt.Sprintf( + "SELECT count FROM %s.%s WHERE id = d1b05da0-8c20-11ea-9fc6-6d2c86545d91", ks, countersTable)).Scan(&count) + require.Nil(t, err) + require.True(t, count >= 1, "counter should be at least 1, got %d", count) + }) + + t.Run("verify_batch_data_on_target", func(t *testing.T) { + var val string + err := targetSession.Query(fmt.Sprintf( + "SELECT val FROM %s.%s WHERE id = cf0f4cf0-8c20-11ea-9fc6-6d2c86545da1", ks, batchTableA)).Scan(&val) + require.Nil(t, err) + require.Equal(t, "batch_updated", val) + }) +} + +// assertMetricOnBothClusters verifies that the write success metric exists for both origin and target. +func assertMetricOnBothClusters(t *testing.T, keyspace string, table string) { + t.Helper() + lines := gatherCCMMetricLines(t) + requireMetricPresent(t, lines, "proxy_write_success_total", "origin", keyspace, table) + requireMetricPresent(t, lines, "proxy_write_success_total", "target", keyspace, table) +} + +// gatherCCMMetricLines scrapes the metrics endpoint used by CCM tests. 
+func gatherCCMMetricLines(t *testing.T) []string { + t.Helper() + statusCode, rspStr, err := utils.GetMetrics("localhost:14099") + require.Nil(t, err) + require.Equal(t, http.StatusOK, statusCode) + + var result []string + for _, line := range strings.Split(rspStr, "\n") { + if !strings.HasPrefix(line, "#") && strings.TrimSpace(line) != "" { + result = append(result, line) + } + } + return result +} + +// requireMetricPresent checks that a write_success metric exists for the given cluster/keyspace/table. +func requireMetricPresent(t *testing.T, lines []string, metricName string, cluster string, keyspace string, table string) { + t.Helper() + prefix := fmt.Sprintf(`zdm_%s{cluster="%s",keyspace="%s",table="%s"}`, metricName, cluster, keyspace, table) + for _, line := range lines { + if strings.HasPrefix(line, prefix) { + return + } + } + + var matching []string + for _, line := range lines { + if strings.Contains(line, "write_success") { + matching = append(matching, line) + } + } + t.Errorf("metric not found with prefix: %q\nAll write_success lines: %v", prefix, matching) +} diff --git a/integration-tests/target_consistency_override_ccm_test.go b/integration-tests/target_consistency_override_ccm_test.go new file mode 100644 index 00000000..0a900fec --- /dev/null +++ b/integration-tests/target_consistency_override_ccm_test.go @@ -0,0 +1,188 @@ +package integration_tests + +import ( + "fmt" + "strings" + "testing" + "time" + + gocql "github.com/apache/cassandra-gocql-driver/v2" + "github.com/stretchr/testify/require" + + "github.com/datastax/zdm-proxy/integration-tests/env" + "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/datastax/zdm-proxy/integration-tests/utils" +) + +// TestTargetConsistencyOverrideCCM verifies that the ZDM_TARGET_CONSISTENCY_LEVEL config +// overrides the consistency level on the target cluster while preserving the original +// client-requested CL on origin. Verified via Cassandra system_traces. 
+// +// Test matrix: +// - Inline INSERT at QUORUM → origin should see QUORUM, target should see LOCAL_ONE +// - Prepared INSERT at QUORUM → same verification via EXECUTE trace +// - Batch INSERT at QUORUM → same verification via BATCH trace +func TestTargetConsistencyOverrideCCM(t *testing.T) { + if env.CompareServerVersion("3.0.0") < 0 { + t.Skip("Skipping consistency override trace test: system_traces.sessions parameters map not available before Cassandra 3.0") + } + + originCluster, targetCluster, err := SetupOrGetGlobalCcmClusters(t) + require.Nil(t, err) + + originSession := originCluster.GetSession() + targetSession := targetCluster.GetSession() + + // Create a proxy with target consistency override set to LOCAL_ONE + conf := setup.NewTestConfig(originCluster.GetInitialContactPoint(), targetCluster.GetInitialContactPoint()) + conf.TargetConsistencyLevel = "LOCAL_ONE" + + proxyInstance, err := setup.NewProxyInstanceWithConfig(conf) + require.Nil(t, err) + defer proxyInstance.Shutdown() + + // Ensure system_traces has RF=1 on both single-node CCM clusters + // and create the test table + for _, s := range []*gocql.Session{originSession, targetSession} { + s.Query("ALTER KEYSPACE system_traces WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}").Exec() + s.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.cl_test (id uuid PRIMARY KEY, val text)", setup.TestKeyspace)).Exec() + } + + // Connect through the proxy + proxy, err := utils.ConnectToCluster("127.0.0.1", "", "", conf.ProxyListenPort) + require.Nil(t, err) + defer proxy.Close() + + // ================================================================ + // TEST 1: Inline INSERT + // Client sends QUORUM, origin should see QUORUM, target should see LOCAL_ONE + // ================================================================ + t.Run("inline_insert", func(t *testing.T) { + clearTraces(originSession, targetSession) + + q := proxy.Query(fmt.Sprintf( + "INSERT INTO %s.cl_test (id, val) VALUES 
(d1b05da0-8c20-11ea-9fc6-6d2c86545d91, 'inline_cl_test')", + setup.TestKeyspace)) + q.Consistency(gocql.Quorum) + q.Trace(noopTracer{}) + err = q.Exec() + require.Nil(t, err, "inline INSERT through proxy failed") + + originCL := findTraceCL(t, originSession, "inline_cl_test") + require.Equal(t, "QUORUM", originCL, "origin should receive client-requested QUORUM") + + targetCL := findTraceCL(t, targetSession, "inline_cl_test") + require.Equal(t, "LOCAL_ONE", targetCL, "target should receive overridden LOCAL_ONE") + }) + + // ================================================================ + // TEST 2: Prepared INSERT + // gocql auto-prepares when bind params are used. The EXECUTE trace + // contains the original query string (e.g. "INSERT INTO ks.cl_test (id, val) VALUES (?, ?)") + // so we search for the table name as the marker. + // ================================================================ + t.Run("prepared_insert", func(t *testing.T) { + clearTraces(originSession, targetSession) + + q := proxy.Query(fmt.Sprintf( + "INSERT INTO %s.cl_test (id, val) VALUES (?, ?)", setup.TestKeyspace)) + q.Bind("eed574b0-8c20-11ea-9fc6-6d2c86545d91", "prepared_cl_test") + q.Consistency(gocql.Quorum) + q.Trace(noopTracer{}) + err = q.Exec() + require.Nil(t, err, "prepared INSERT through proxy failed") + + // For prepared statements, the trace query field contains the CQL with ? markers, + // not the bound values. Search for the table name instead. 
+ originCL := findTraceCL(t, originSession, "cl_test") + require.Equal(t, "QUORUM", originCL, "origin should receive client-requested QUORUM for prepared statement") + + targetCL := findTraceCL(t, targetSession, "cl_test") + require.Equal(t, "LOCAL_ONE", targetCL, "target should receive overridden LOCAL_ONE for prepared statement") + }) + + // ================================================================ + // TEST 3: Batch INSERT + // Batch traces don't include the query text in the parameters map, + // so we check the first trace found after clearing. + // ================================================================ + t.Run("batch_insert", func(t *testing.T) { + clearTraces(originSession, targetSession) + + batch := proxy.NewBatch(gocql.LoggedBatch) + batch.Query(fmt.Sprintf( + "INSERT INTO %s.cl_test (id, val) VALUES (cf0f4cf0-8c20-11ea-9fc6-6d2c86545d91, 'batch_cl_test')", + setup.TestKeyspace)) + batch.SetConsistency(gocql.Quorum) + batch.Trace(noopTracer{}) + err = proxy.ExecuteBatch(batch) + require.Nil(t, err, "batch INSERT through proxy failed") + + originCL := findAnyTraceCL(t, originSession) + require.Equal(t, "QUORUM", originCL, "origin should receive client-requested QUORUM for batch") + + targetCL := findAnyTraceCL(t, targetSession) + require.Equal(t, "LOCAL_ONE", targetCL, "target should receive overridden LOCAL_ONE for batch") + }) +} + +// clearTraces truncates system_traces.sessions on both clusters. +func clearTraces(origin *gocql.Session, target *gocql.Session) { + origin.Query("TRUNCATE system_traces.sessions").Consistency(gocql.One).Exec() + origin.Query("TRUNCATE system_traces.events").Consistency(gocql.One).Exec() + target.Query("TRUNCATE system_traces.sessions").Consistency(gocql.One).Exec() + target.Query("TRUNCATE system_traces.events").Consistency(gocql.One).Exec() +} + +// findTraceCL searches system_traces.sessions for a trace whose query parameter contains +// the given marker, and returns the consistency_level. 
Retries for up to 10 seconds. +func findTraceCL(t *testing.T, session *gocql.Session, marker string) string { + t.Helper() + for attempt := 0; attempt < 20; attempt++ { + q := session.Query("SELECT parameters FROM system_traces.sessions") + q.Consistency(gocql.One) + iter := q.Iter() + var params map[string]string + for iter.Scan(¶ms) { + if query, ok := params["query"]; ok && strings.Contains(query, marker) { + if cl, ok := params["consistency_level"]; ok { + iter.Close() + return cl + } + } + } + iter.Close() + time.Sleep(500 * time.Millisecond) + } + t.Fatalf("no trace found containing marker %q after 10s of retries", marker) + return "" +} + +// findAnyTraceCL returns the consistency_level from the first trace session found. +// Retries for up to 10 seconds. +func findAnyTraceCL(t *testing.T, session *gocql.Session) string { + t.Helper() + for attempt := 0; attempt < 20; attempt++ { + q := session.Query("SELECT parameters FROM system_traces.sessions") + q.Consistency(gocql.One) + iter := q.Iter() + var params map[string]string + for iter.Scan(¶ms) { + if cl, ok := params["consistency_level"]; ok { + iter.Close() + return cl + } + } + iter.Close() + time.Sleep(500 * time.Millisecond) + } + t.Fatalf("no trace sessions found after 10s of retries") + return "" +} + +// noopTracer enables the tracing flag on the CQL protocol frame without +// fetching trace results. This avoids trace-fetch queries going through +// the proxy and interfering with the test. 
+type noopTracer struct{} + +func (noopTracer) Trace(_ []byte) {} diff --git a/integration-tests/target_write_consistency_test.go b/integration-tests/target_write_consistency_test.go new file mode 100644 index 00000000..ecb6e530 --- /dev/null +++ b/integration-tests/target_write_consistency_test.go @@ -0,0 +1,429 @@ +package integration_tests + +import ( + "context" + "encoding/json" + "testing" + + "github.com/datastax/go-cassandra-native-protocol/client" + "github.com/datastax/go-cassandra-native-protocol/frame" + "github.com/datastax/go-cassandra-native-protocol/message" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/datastax/zdm-proxy/integration-tests/env" + "github.com/datastax/zdm-proxy/integration-tests/setup" + "github.com/datastax/zdm-proxy/integration-tests/simulacron" + "github.com/stretchr/testify/require" +) + +// getWriteQueries returns QUERY-type log entries for a given cluster. +func getWriteQueries(t *testing.T, cluster *simulacron.Cluster) []*simulacron.RequestLogEntry { + logs, err := cluster.GetLogsByType(simulacron.QueryTypeQuery) + require.NoError(t, err) + var queries []*simulacron.RequestLogEntry + for _, dc := range logs.Datacenters { + for _, node := range dc.Nodes { + queries = append(queries, node.Queries...) + } + } + return queries +} + +// TestTargetConsistencyOverride_Disabled verifies that when the override config is NOT set, +// both origin and target receive the client-requested consistency level unchanged. +func TestTargetConsistencyOverride_Disabled(t *testing.T) { + // Default config — no override + testSetup, err := setup.NewSimulacronTestSetup(t) + require.NoError(t, err) + defer testSetup.Cleanup() + + queryPrime := + simulacron.WhenQuery( + "INSERT INTO myks.users (name) VALUES ('alice')", + simulacron.NewWhenQueryOptions()). 
+ ThenSuccess() + + err = testSetup.Origin.Prime(queryPrime) + require.NoError(t, err) + err = testSetup.Target.Prime(queryPrime) + require.NoError(t, err) + + // Clear logs before test + err = testSetup.Origin.DeleteLogs() + require.NoError(t, err) + err = testSetup.Target.DeleteLogs() + require.NoError(t, err) + + // Send a write with LOCAL_QUORUM using low-level client + cqlClient := client.NewCqlClient("127.0.0.1:14002", nil) + cqlConn, err := cqlClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 0) + require.NoError(t, err) + defer cqlConn.Close() + + queryMsg := &message.Query{ + Query: "INSERT INTO myks.users (name) VALUES ('alice')", + Options: &message.QueryOptions{ + Consistency: primitive.ConsistencyLevelLocalQuorum, + }, + } + + rsp, err := cqlConn.SendAndReceive(frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, queryMsg)) + require.NoError(t, err) + require.Equal(t, primitive.OpCodeResult, rsp.Header.OpCode) + + // Verify origin received LOCAL_QUORUM + originQueries := getWriteQueries(t, testSetup.Origin) + require.GreaterOrEqual(t, len(originQueries), 1, "expected at least 1 query on origin") + lastOriginQuery := originQueries[len(originQueries)-1] + require.Equal(t, "LOCAL_QUORUM", lastOriginQuery.ConsistencyLevel, + "origin should receive client-requested LOCAL_QUORUM") + + // Verify target also received LOCAL_QUORUM (no override) + targetQueries := getWriteQueries(t, testSetup.Target) + require.GreaterOrEqual(t, len(targetQueries), 1, "expected at least 1 query on target") + lastTargetQuery := targetQueries[len(targetQueries)-1] + require.Equal(t, "LOCAL_QUORUM", lastTargetQuery.ConsistencyLevel, + "target should receive client-requested LOCAL_QUORUM when override is disabled") +} + +// TestTargetConsistencyOverride_Enabled verifies that when the override is set to LOCAL_ONE, +// origin receives the original CL but target receives LOCAL_ONE. 
+func TestTargetConsistencyOverride_Enabled(t *testing.T) { + c := setup.NewTestConfig("", "") + c.TargetConsistencyLevel = "LOCAL_ONE" + + testSetup, err := setup.NewSimulacronTestSetupWithConfig(t, c) + require.NoError(t, err) + defer testSetup.Cleanup() + + queryPrime := + simulacron.WhenQuery( + "INSERT INTO myks.users (name) VALUES ('bob')", + simulacron.NewWhenQueryOptions()). + ThenSuccess() + + err = testSetup.Origin.Prime(queryPrime) + require.NoError(t, err) + err = testSetup.Target.Prime(queryPrime) + require.NoError(t, err) + + // Clear logs before test + err = testSetup.Origin.DeleteLogs() + require.NoError(t, err) + err = testSetup.Target.DeleteLogs() + require.NoError(t, err) + + // Send a write with LOCAL_QUORUM + cqlClient := client.NewCqlClient("127.0.0.1:14002", nil) + cqlConn, err := cqlClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 0) + require.NoError(t, err) + defer cqlConn.Close() + + queryMsg := &message.Query{ + Query: "INSERT INTO myks.users (name) VALUES ('bob')", + Options: &message.QueryOptions{ + Consistency: primitive.ConsistencyLevelLocalQuorum, + }, + } + + rsp, err := cqlConn.SendAndReceive(frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, queryMsg)) + require.NoError(t, err) + require.Equal(t, primitive.OpCodeResult, rsp.Header.OpCode) + + // Verify origin still receives LOCAL_QUORUM (unchanged) + originQueries := getWriteQueries(t, testSetup.Origin) + require.GreaterOrEqual(t, len(originQueries), 1, "expected at least 1 query on origin") + lastOriginQuery := originQueries[len(originQueries)-1] + require.Equal(t, "LOCAL_QUORUM", lastOriginQuery.ConsistencyLevel, + "origin should always receive client-requested LOCAL_QUORUM") + + // Verify target receives LOCAL_ONE (overridden) + targetQueries := getWriteQueries(t, testSetup.Target) + require.GreaterOrEqual(t, len(targetQueries), 1, "expected at least 1 query on target") + lastTargetQuery := targetQueries[len(targetQueries)-1] + 
require.Equal(t, "LOCAL_ONE", lastTargetQuery.ConsistencyLevel, + "target should receive overridden LOCAL_ONE") +} + +// TestTargetConsistencyOverride_ReadAlsoAffected verifies that read requests +// routed to the target cluster are also affected by the consistency override. +// With default config (PrimaryCluster=ORIGIN), reads go to origin only, so +// the target does not receive them. This test uses PrimaryCluster=TARGET to +// route reads to target and verify the override applies. +func TestTargetConsistencyOverride_ReadAlsoAffected(t *testing.T) { + c := setup.NewTestConfig("", "") + c.TargetConsistencyLevel = "LOCAL_ONE" + c.PrimaryCluster = "TARGET" + + testSetup, err := setup.NewSimulacronTestSetupWithConfig(t, c) + require.NoError(t, err) + defer testSetup.Cleanup() + + expectedRows := simulacron.NewRowsResult( + map[string]simulacron.DataType{"name": simulacron.DataTypeText}). + WithRow(map[string]interface{}{"name": "alice"}) + + queryPrime := + simulacron.WhenQuery( + "SELECT name FROM myks.users", + simulacron.NewWhenQueryOptions()). 
+ ThenRowsSuccess(expectedRows) + + err = testSetup.Target.Prime(queryPrime) + require.NoError(t, err) + + // Clear logs before test + err = testSetup.Target.DeleteLogs() + require.NoError(t, err) + + // Send a read with LOCAL_QUORUM + cqlClient := client.NewCqlClient("127.0.0.1:14002", nil) + cqlConn, err := cqlClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 0) + require.NoError(t, err) + defer cqlConn.Close() + + queryMsg := &message.Query{ + Query: "SELECT name FROM myks.users", + Options: &message.QueryOptions{ + Consistency: primitive.ConsistencyLevelLocalQuorum, + }, + } + + rsp, err := cqlConn.SendAndReceive(frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, queryMsg)) + require.NoError(t, err) + require.Equal(t, primitive.OpCodeResult, rsp.Header.OpCode) + + // Verify target received LOCAL_ONE (overridden), not LOCAL_QUORUM + targetQueries := getWriteQueries(t, testSetup.Target) + require.GreaterOrEqual(t, len(targetQueries), 1, "expected at least 1 query on target") + lastTargetQuery := targetQueries[len(targetQueries)-1] + require.Equal(t, "LOCAL_ONE", lastTargetQuery.ConsistencyLevel, + "read queries to target should also be affected by consistency override") +} + +// TestTargetConsistencyOverride_Enabled_ONE verifies override with a different CL value (ONE). +func TestTargetConsistencyOverride_Enabled_ONE(t *testing.T) { + c := setup.NewTestConfig("", "") + c.TargetConsistencyLevel = "ONE" + + testSetup, err := setup.NewSimulacronTestSetupWithConfig(t, c) + require.NoError(t, err) + defer testSetup.Cleanup() + + queryPrime := + simulacron.WhenQuery( + "INSERT INTO myks.users (name) VALUES ('charlie')", + simulacron.NewWhenQueryOptions()). 
+ ThenSuccess() + + err = testSetup.Origin.Prime(queryPrime) + require.NoError(t, err) + err = testSetup.Target.Prime(queryPrime) + require.NoError(t, err) + + err = testSetup.Origin.DeleteLogs() + require.NoError(t, err) + err = testSetup.Target.DeleteLogs() + require.NoError(t, err) + + cqlClient := client.NewCqlClient("127.0.0.1:14002", nil) + cqlConn, err := cqlClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 0) + require.NoError(t, err) + defer cqlConn.Close() + + queryMsg := &message.Query{ + Query: "INSERT INTO myks.users (name) VALUES ('charlie')", + Options: &message.QueryOptions{ + Consistency: primitive.ConsistencyLevelAll, + }, + } + + rsp, err := cqlConn.SendAndReceive(frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, queryMsg)) + require.NoError(t, err) + require.Equal(t, primitive.OpCodeResult, rsp.Header.OpCode) + + // Origin gets ALL (client-requested) + originQueries := getWriteQueries(t, testSetup.Origin) + require.GreaterOrEqual(t, len(originQueries), 1) + require.Equal(t, "ALL", originQueries[len(originQueries)-1].ConsistencyLevel) + + // Target gets ONE (overridden) + targetQueries := getWriteQueries(t, testSetup.Target) + require.GreaterOrEqual(t, len(targetQueries), 1) + require.Equal(t, "ONE", targetQueries[len(targetQueries)-1].ConsistencyLevel) +} + +// TestTargetConsistencyOverride_PreparedStatement verifies that the override applies +// to EXECUTE messages (prepared statement execution), not just inline Query writes. +func TestTargetConsistencyOverride_PreparedStatement(t *testing.T) { + c := setup.NewTestConfig("", "") + c.TargetConsistencyLevel = "LOCAL_ONE" + + testSetup, err := setup.NewSimulacronTestSetupWithConfig(t, c) + require.NoError(t, err) + defer testSetup.Cleanup() + + // Prime the query for both clusters (simulacron needs this for PREPARE + EXECUTE) + queryPrime := + simulacron.WhenQuery( + "INSERT INTO myks.users (name) VALUES (?)", + simulacron.NewWhenQueryOptions(). 
+ WithPositionalParameter(simulacron.DataTypeText, "dave")). + ThenSuccess() + + err = testSetup.Origin.Prime(queryPrime) + require.NoError(t, err) + err = testSetup.Target.Prime(queryPrime) + require.NoError(t, err) + + // Connect with low-level client + cqlClient := client.NewCqlClient("127.0.0.1:14002", nil) + cqlConn, err := cqlClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 0) + require.NoError(t, err) + defer cqlConn.Close() + + // Step 1: PREPARE + prepareMsg := &message.Prepare{ + Query: "INSERT INTO myks.users (name) VALUES (?)", + } + prepareResp, err := cqlConn.SendAndReceive(frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, prepareMsg)) + require.NoError(t, err) + + prepared, ok := prepareResp.Body.Message.(*message.PreparedResult) + require.True(t, ok, "expected PreparedResult but got %T", prepareResp.Body.Message) + + // Clear logs between PREPARE and EXECUTE + err = testSetup.Origin.DeleteLogs() + require.NoError(t, err) + err = testSetup.Target.DeleteLogs() + require.NoError(t, err) + + // Step 2: EXECUTE with LOCAL_QUORUM + executeMsg := &message.Execute{ + QueryId: prepared.PreparedQueryId, + ResultMetadataId: prepared.ResultMetadataId, + Options: &message.QueryOptions{ + Consistency: primitive.ConsistencyLevelLocalQuorum, + PositionalValues: []*primitive.Value{primitive.NewValue([]byte("dave"))}, + }, + } + execResp, err := cqlConn.SendAndReceive(frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, executeMsg)) + require.NoError(t, err) + require.Equal(t, primitive.OpCodeResult, execResp.Header.OpCode) + + // Check origin EXECUTE logs — should have LOCAL_QUORUM + originExecLogs, err := testSetup.Origin.GetLogsByType(simulacron.QueryTypeExecute) + require.NoError(t, err) + originExecQueries := originExecLogs.Datacenters[0].Nodes[0].Queries + require.GreaterOrEqual(t, len(originExecQueries), 1, "expected at least 1 EXECUTE on origin") + + lastOriginExec := originExecQueries[len(originExecQueries)-1] + var 
originExecMsg simulacron.ExecuteMessage + err = json.Unmarshal(lastOriginExec.Frame.Message, &originExecMsg) + require.NoError(t, err) + require.NotNil(t, originExecMsg.Options) + require.Equal(t, "LOCAL_QUORUM", originExecMsg.Options.Consistency, + "origin EXECUTE should retain client-requested LOCAL_QUORUM") + + // Check target EXECUTE logs — should have LOCAL_ONE (overridden) + targetExecLogs, err := testSetup.Target.GetLogsByType(simulacron.QueryTypeExecute) + require.NoError(t, err) + targetExecQueries := targetExecLogs.Datacenters[0].Nodes[0].Queries + require.GreaterOrEqual(t, len(targetExecQueries), 1, "expected at least 1 EXECUTE on target") + + lastTargetExec := targetExecQueries[len(targetExecQueries)-1] + var targetExecMsg simulacron.ExecuteMessage + err = json.Unmarshal(lastTargetExec.Frame.Message, &targetExecMsg) + require.NoError(t, err) + require.NotNil(t, targetExecMsg.Options) + require.Equal(t, "LOCAL_ONE", targetExecMsg.Options.Consistency, + "target EXECUTE should have overridden LOCAL_ONE") +} + +// TestTargetConsistencyOverride_Batch verifies that the override applies to BATCH messages. +func TestTargetConsistencyOverride_Batch(t *testing.T) { + c := setup.NewTestConfig("", "") + c.TargetConsistencyLevel = "LOCAL_ONE" + + testSetup, err := setup.NewSimulacronTestSetupWithConfig(t, c) + require.NoError(t, err) + defer testSetup.Cleanup() + + // Prime individual queries that will be part of the batch + queryPrime1 := + simulacron.WhenQuery( + "INSERT INTO myks.users (name) VALUES ('eve')", + simulacron.NewWhenQueryOptions()). + ThenSuccess() + queryPrime2 := + simulacron.WhenQuery( + "INSERT INTO myks.users (name) VALUES ('frank')", + simulacron.NewWhenQueryOptions()). 
+ ThenSuccess() + + for _, prime := range []simulacron.Then{queryPrime1, queryPrime2} { + err = testSetup.Origin.Prime(prime) + require.NoError(t, err) + err = testSetup.Target.Prime(prime) + require.NoError(t, err) + } + + cqlClient := client.NewCqlClient("127.0.0.1:14002", nil) + cqlConn, err := cqlClient.ConnectAndInit(context.Background(), env.DefaultProtocolVersionSimulacron, 0) + require.NoError(t, err) + defer cqlConn.Close() + + // Clear logs + err = testSetup.Origin.DeleteLogs() + require.NoError(t, err) + err = testSetup.Target.DeleteLogs() + require.NoError(t, err) + + // Send a BATCH with LOCAL_QUORUM + batchMsg := &message.Batch{ + Type: primitive.BatchTypeLogged, + Children: []*message.BatchChild{ + { + Query: "INSERT INTO myks.users (name) VALUES ('eve')", + }, + { + Query: "INSERT INTO myks.users (name) VALUES ('frank')", + }, + }, + Consistency: primitive.ConsistencyLevelLocalQuorum, + } + + batchResp, err := cqlConn.SendAndReceive(frame.NewFrame(env.DefaultProtocolVersionSimulacron, 0, batchMsg)) + require.NoError(t, err) + require.Equal(t, primitive.OpCodeResult, batchResp.Header.OpCode) + + // Helper to extract batch messages from logs + getBatchMessages := func(cluster *simulacron.Cluster) []*simulacron.BatchMessage { + logs, err := cluster.GetLogsByType(simulacron.QueryTypeBatch) + require.NoError(t, err) + var batches []*simulacron.BatchMessage + for _, dc := range logs.Datacenters { + for _, node := range dc.Nodes { + for _, entry := range node.Queries { + var bm simulacron.BatchMessage + err := json.Unmarshal(entry.Frame.Message, &bm) + if err == nil { + batches = append(batches, &bm) + } + } + } + } + return batches + } + + // Check origin BATCH — should have LOCAL_QUORUM + originBatches := getBatchMessages(testSetup.Origin) + require.GreaterOrEqual(t, len(originBatches), 1, "expected at least 1 BATCH on origin") + require.Equal(t, "LOCAL_QUORUM", originBatches[len(originBatches)-1].Consistency, + "origin BATCH should retain 
client-requested LOCAL_QUORUM") + + // Check target BATCH — should have LOCAL_ONE (overridden) + targetBatches := getBatchMessages(testSetup.Target) + require.GreaterOrEqual(t, len(targetBatches), 1, "expected at least 1 BATCH on target") + require.Equal(t, "LOCAL_ONE", targetBatches[len(targetBatches)-1].Consistency, + "target BATCH should have overridden LOCAL_ONE") +} diff --git a/proxy/launch.go b/proxy/launch.go index 9882457b..39f9d31d 100644 --- a/proxy/launch.go +++ b/proxy/launch.go @@ -56,6 +56,13 @@ func launchProxy(profilingSupported bool) { } log.SetLevel(logLevel) + targetCL, _ := conf.ParseTargetConsistencyLevel() + if targetCL != nil { + log.Warnf("Target consistency level override is ENABLED: all requests to the target cluster will use %v instead of the client-requested consistency level", *targetCL) + } else { + log.Infof("Target consistency level override: disabled") + } + if profilingSupported { log.Debugf("Proxy built with profiling support") } else { diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index fe3cebdb..a589ebde 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -105,6 +105,13 @@ type Config struct { HeartbeatRetryBackoffFactor float64 `default:"2" split_words:"true" yaml:"heartbeat_retry_backoff_factor"` HeartbeatFailureThreshold int `default:"1" split_words:"true" yaml:"heartbeat_failure_threshold"` + // Target consistency level override. + // When set, overrides the consistency level for ALL requests (reads and writes) sent to the target cluster. + // The origin/source cluster always uses the client-requested consistency level. + // Valid values: ANY, ONE, TWO, THREE, QUORUM, ALL, LOCAL_QUORUM, EACH_QUORUM, LOCAL_ONE (case-insensitive). + // Empty or unset means disabled (default behavior, no override). 
+ TargetConsistencyLevel string `default:"" split_words:"true" yaml:"target_consistency_level"` + ////////////////////////////////////////////////////////////////////// /// THE SETTINGS BELOW AREN'T SUPPORTED AND MAY CHANGE AT ANY TIME /// ////////////////////////////////////////////////////////////////////// @@ -338,6 +345,11 @@ func (c *Config) Validate() error { return err } + _, err = c.ParseTargetConsistencyLevel() + if err != nil { + return err + } + return nil } @@ -392,6 +404,45 @@ func (c *Config) ParseReadMode() (common.ReadMode, error) { } } +// consistencyLevelMap maps uppercase consistency level names to primitive.ConsistencyLevel values. +// Only non-serial consistency levels are valid for write CL override. +var consistencyLevelMap = map[string]primitive.ConsistencyLevel{ + "ANY": primitive.ConsistencyLevelAny, + "ONE": primitive.ConsistencyLevelOne, + "TWO": primitive.ConsistencyLevelTwo, + "THREE": primitive.ConsistencyLevelThree, + "QUORUM": primitive.ConsistencyLevelQuorum, + "ALL": primitive.ConsistencyLevelAll, + "LOCAL_QUORUM": primitive.ConsistencyLevelLocalQuorum, + "EACH_QUORUM": primitive.ConsistencyLevelEachQuorum, + "LOCAL_ONE": primitive.ConsistencyLevelLocalOne, +} + +// ParseTargetConsistencyLevel parses the target consistency level override. +// Returns nil if the feature is disabled (empty/unset config value). +// Returns a non-nil pointer to the parsed consistency level if valid. +// Returns an error if the value is set but invalid. 
+func (c *Config) ParseTargetConsistencyLevel() (*primitive.ConsistencyLevel, error) { + trimmed := strings.TrimSpace(c.TargetConsistencyLevel) + if trimmed == "" { + return nil, nil + } + + upper := strings.ToUpper(trimmed) + if cl, ok := consistencyLevelMap[upper]; ok { + return &cl, nil + } + + validValues := make([]string, 0, len(consistencyLevelMap)) + for k := range consistencyLevelMap { + validValues = append(validValues, k) + } + slices.Sort(validValues) + return nil, fmt.Errorf( + "invalid value for ZDM_TARGET_CONSISTENCY_LEVEL: %q; valid values are: %v", + trimmed, strings.Join(validValues, ", ")) +} + func (c *Config) ParseControlConnMaxProtocolVersion() (primitive.ProtocolVersion, error) { if strings.EqualFold(c.ControlConnMaxProtocolVersion, "DseV2") { return primitive.ProtocolVersionDse2, nil diff --git a/proxy/pkg/config/config_target_write_cl_test.go b/proxy/pkg/config/config_target_write_cl_test.go new file mode 100644 index 00000000..a32a6bbe --- /dev/null +++ b/proxy/pkg/config/config_target_write_cl_test.go @@ -0,0 +1,157 @@ +package config + +import ( + "testing" + + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/stretchr/testify/require" +) + +func TestParseTargetConsistencyLevel_Empty(t *testing.T) { + c := &Config{TargetConsistencyLevel: ""} + cl, err := c.ParseTargetConsistencyLevel() + require.NoError(t, err) + require.Nil(t, cl, "empty value should return nil (disabled)") +} + +func TestParseTargetConsistencyLevel_Whitespace(t *testing.T) { + c := &Config{TargetConsistencyLevel: " "} + cl, err := c.ParseTargetConsistencyLevel() + require.NoError(t, err) + require.Nil(t, cl, "whitespace-only value should return nil (disabled)") +} + +func TestParseTargetConsistencyLevel_ValidValues(t *testing.T) { + tests := []struct { + input string + expected primitive.ConsistencyLevel + }{ + {"ANY", primitive.ConsistencyLevelAny}, + {"ONE", primitive.ConsistencyLevelOne}, + {"TWO", primitive.ConsistencyLevelTwo}, + {"THREE", 
primitive.ConsistencyLevelThree}, + {"QUORUM", primitive.ConsistencyLevelQuorum}, + {"ALL", primitive.ConsistencyLevelAll}, + {"LOCAL_QUORUM", primitive.ConsistencyLevelLocalQuorum}, + {"EACH_QUORUM", primitive.ConsistencyLevelEachQuorum}, + {"LOCAL_ONE", primitive.ConsistencyLevelLocalOne}, + // case-insensitive + {"local_one", primitive.ConsistencyLevelLocalOne}, + {"Local_Quorum", primitive.ConsistencyLevelLocalQuorum}, + {"one", primitive.ConsistencyLevelOne}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + c := &Config{TargetConsistencyLevel: tt.input} + cl, err := c.ParseTargetConsistencyLevel() + require.NoError(t, err) + require.NotNil(t, cl) + require.Equal(t, tt.expected, *cl) + }) + } +} + +func TestParseTargetConsistencyLevel_InvalidValues(t *testing.T) { + tests := []string{ + "SERIAL", + "LOCAL_SERIAL", + "INVALID", + "local_serial", + "QUOROM", + "12345", + } + + for _, input := range tests { + t.Run(input, func(t *testing.T) { + c := &Config{TargetConsistencyLevel: input} + cl, err := c.ParseTargetConsistencyLevel() + require.Error(t, err) + require.Nil(t, cl) + require.Contains(t, err.Error(), "ZDM_TARGET_CONSISTENCY_LEVEL") + }) + } +} + +func TestParseTargetConsistencyLevel_WithWhitespacePadding(t *testing.T) { + c := &Config{TargetConsistencyLevel: " LOCAL_ONE "} + cl, err := c.ParseTargetConsistencyLevel() + require.NoError(t, err) + require.NotNil(t, cl) + require.Equal(t, primitive.ConsistencyLevelLocalOne, *cl) +} + +func TestValidate_RejectsInvalidTargetConsistencyLevel(t *testing.T) { + defer clearAllEnvVars() + + clearAllEnvVars() + setOriginCredentialsEnvVars() + setTargetCredentialsEnvVars() + setOriginContactPointsAndPortEnvVars() + setTargetContactPointsAndPortEnvVars() + + setEnvVar("ZDM_TARGET_CONSISTENCY_LEVEL", "SERIAL") + + _, err := New().LoadConfig("") + require.Error(t, err) + require.Contains(t, err.Error(), "ZDM_TARGET_CONSISTENCY_LEVEL") +} + +func 
TestValidate_AcceptsEmptyTargetConsistencyLevel(t *testing.T) { + defer clearAllEnvVars() + + clearAllEnvVars() + setOriginCredentialsEnvVars() + setTargetCredentialsEnvVars() + setOriginContactPointsAndPortEnvVars() + setTargetContactPointsAndPortEnvVars() + + // do NOT set ZDM_TARGET_CONSISTENCY_LEVEL + + conf, err := New().LoadConfig("") + require.NoError(t, err) + require.Empty(t, conf.TargetConsistencyLevel) +} + +func TestValidate_AcceptsValidTargetConsistencyLevel(t *testing.T) { + defer clearAllEnvVars() + + clearAllEnvVars() + setOriginCredentialsEnvVars() + setTargetCredentialsEnvVars() + setOriginContactPointsAndPortEnvVars() + setTargetContactPointsAndPortEnvVars() + + setEnvVar("ZDM_TARGET_CONSISTENCY_LEVEL", "LOCAL_ONE") + + conf, err := New().LoadConfig("") + require.NoError(t, err) + require.Equal(t, "LOCAL_ONE", conf.TargetConsistencyLevel) +} + +func TestTargetConsistencyLevel_YamlConfig(t *testing.T) { + yamlContent := ` +origin_contact_points: "origin.hostname.com" +origin_port: 9042 +origin_username: "user" +origin_password: "pass" +target_contact_points: "target.hostname.com" +target_port: 9042 +target_username: "user" +target_password: "pass" +target_consistency_level: "LOCAL_ONE" +` + + f, err := createConfigFile(yamlContent) + require.NoError(t, err) + defer removeConfigFile(f) + + conf, err := New().LoadConfig(f.Name()) + require.NoError(t, err) + require.Equal(t, "LOCAL_ONE", conf.TargetConsistencyLevel) + + cl, err := conf.ParseTargetConsistencyLevel() + require.NoError(t, err) + require.NotNil(t, cl) + require.Equal(t, primitive.ConsistencyLevelLocalOne, *cl) +} diff --git a/proxy/pkg/metrics/metric_handler.go b/proxy/pkg/metrics/metric_handler.go index a7616d55..9ce19311 100644 --- a/proxy/pkg/metrics/metric_handler.go +++ b/proxy/pkg/metrics/metric_handler.go @@ -31,6 +31,10 @@ type MetricHandler struct { originBuckets []float64 targetBuckets []float64 asyncBuckets []float64 + + // Per-table write success counters, keyed by 
"cluster:keyspace.table" + writeSuccessCounters map[string]Counter + writeSuccessRwLock *sync.RWMutex } func NewMetricHandler( @@ -57,6 +61,8 @@ func NewMetricHandler( originBuckets: originBuckets, targetBuckets: targetBuckets, asyncBuckets: asyncBuckets, + writeSuccessCounters: make(map[string]Counter), + writeSuccessRwLock: &sync.RWMutex{}, } } @@ -180,6 +186,49 @@ func (recv *MetricHandler) GetNodeMetrics( return &NodeMetrics{OriginMetrics: originMetrics, TargetMetrics: targetMetrics, AsyncMetrics: asyncMetrics}, nil } +const ( + writeSuccessName = "proxy_write_success_total" + writeSuccessDescription = "Running total of successful writes per cluster, keyspace and table" + writeSuccessCluster = "cluster" + writeSuccessKeyspace = "keyspace" + writeSuccessTable = "table" +) + +// GetOrCreateWriteSuccessCounter returns a Counter for tracking successful writes to a specific +// cluster/keyspace/table combination. Counters are cached and reused for the same combination. +func (recv *MetricHandler) GetOrCreateWriteSuccessCounter(cluster string, keyspace string, table string) (Counter, error) { + key := cluster + ":" + keyspace + "." 
+ table + + recv.writeSuccessRwLock.RLock() + counter, ok := recv.writeSuccessCounters[key] + recv.writeSuccessRwLock.RUnlock() + if ok { + return counter, nil + } + + recv.writeSuccessRwLock.Lock() + counter, ok = recv.writeSuccessCounters[key] + if ok { + recv.writeSuccessRwLock.Unlock() + return counter, nil + } + + mn := NewMetricWithLabels(writeSuccessName, writeSuccessDescription, map[string]string{ + writeSuccessCluster: cluster, + writeSuccessKeyspace: keyspace, + writeSuccessTable: table, + }) + counter, err := recv.metricFactory.GetOrCreateCounter(mn) + if err != nil { + recv.writeSuccessRwLock.Unlock() + return nil, fmt.Errorf("failed to create write success counter for %s: %w", key, err) + } + + recv.writeSuccessCounters[key] = counter + recv.writeSuccessRwLock.Unlock() + return counter, nil +} + func (recv *MetricHandler) UnregisterAllMetrics() error { return recv.metricFactory.UnregisterAllMetrics() } diff --git a/proxy/pkg/metrics/write_success_metric_test.go b/proxy/pkg/metrics/write_success_metric_test.go new file mode 100644 index 00000000..c169935c --- /dev/null +++ b/proxy/pkg/metrics/write_success_metric_test.go @@ -0,0 +1,98 @@ +package metrics_test + +import ( + "testing" + + "github.com/datastax/zdm-proxy/proxy/pkg/metrics" + "github.com/datastax/zdm-proxy/proxy/pkg/metrics/prommetrics" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/require" +) + +func TestGetOrCreateWriteSuccessCounter(t *testing.T) { + registry := prometheus.NewRegistry() + factory := prommetrics.NewPrometheusMetricFactory(registry, "zdm") + handler := metrics.NewMetricHandler(factory, nil, nil, nil, nil, nil, nil, nil) + + // Create a counter for origin/ks1/users + counter1, err := handler.GetOrCreateWriteSuccessCounter("origin", "ks1", "users") + require.NoError(t, err) + require.NotNil(t, counter1) + + // Increment it + counter1.Add(1) + + // Getting the same combination should return the same counter (cached) + counter1Again, err := 
handler.GetOrCreateWriteSuccessCounter("origin", "ks1", "users") + require.NoError(t, err) + require.Equal(t, counter1, counter1Again) + + // Different table should return a different counter + counter2, err := handler.GetOrCreateWriteSuccessCounter("origin", "ks1", "events") + require.NoError(t, err) + require.NotNil(t, counter2) + + // Different cluster same table should return a different counter + counter3, err := handler.GetOrCreateWriteSuccessCounter("target", "ks1", "users") + require.NoError(t, err) + require.NotNil(t, counter3) + + // Verify counters are independent — increment counter2 and counter3 + counter2.Add(3) + counter3.Add(5) + + // Gather metrics from the registry and verify the values + metricFamilies, err := registry.Gather() + require.NoError(t, err) + + found := map[string]float64{} + for _, mf := range metricFamilies { + if mf.GetName() == "zdm_proxy_write_success_total" { + for _, m := range mf.GetMetric() { + labels := map[string]string{} + for _, l := range m.GetLabel() { + labels[l.GetName()] = l.GetValue() + } + key := labels["cluster"] + ":" + labels["keyspace"] + "." 
+ labels["table"] + found[key] = m.GetCounter().GetValue() + } + } + } + + require.Equal(t, float64(1), found["origin:ks1.users"]) + require.Equal(t, float64(3), found["origin:ks1.events"]) + require.Equal(t, float64(5), found["target:ks1.users"]) +} + +func TestGetOrCreateWriteSuccessCounter_ConcurrentAccess(t *testing.T) { + registry := prometheus.NewRegistry() + factory := prommetrics.NewPrometheusMetricFactory(registry, "zdm") + handler := metrics.NewMetricHandler(factory, nil, nil, nil, nil, nil, nil, nil) + + // Simulate concurrent access from multiple goroutines + done := make(chan bool, 100) + for i := 0; i < 100; i++ { + go func() { + counter, err := handler.GetOrCreateWriteSuccessCounter("origin", "ks1", "users") + require.NoError(t, err) + counter.Add(1) + done <- true + }() + } + + for i := 0; i < 100; i++ { + <-done + } + + // Verify total count is 100 + metricFamilies, err := registry.Gather() + require.NoError(t, err) + + for _, mf := range metricFamilies { + if mf.GetName() == "zdm_proxy_write_success_total" { + for _, m := range mf.GetMetric() { + require.Equal(t, float64(100), m.GetCounter().GetValue()) + } + } + } +} diff --git a/proxy/pkg/zdmproxy/clienthandler.go b/proxy/pkg/zdmproxy/clienthandler.go index 26c4a654..37d93568 100644 --- a/proxy/pkg/zdmproxy/clienthandler.go +++ b/proxy/pkg/zdmproxy/clienthandler.go @@ -120,6 +120,10 @@ type ClientHandler struct { timeUuidGenerator TimeUuidGenerator rateLimiters *RateLimiters + // targetConsistencyLevel is the optional override for target-side consistency. + // nil means disabled (default). When non-nil, all requests to the target cluster use this CL. 
+ targetConsistencyLevel *primitive.ConsistencyLevel + // not used atm but should be used when a protocol error occurs after #68 has been addressed clientHandlerShutdownRequestCancelFn context.CancelFunc @@ -278,6 +282,13 @@ func NewClientHandler( forwardAuthToTarget, targetCredsOnClientRequest := forwardAuthToTarget( originControlConn, targetControlConn, conf.ForwardClientCredentialsToOrigin) + // Parse target consistency level override (nil = disabled) + targetWriteCL, err := conf.ParseTargetConsistencyLevel() + if err != nil { + clientHandlerCancelFunc() + return nil, fmt.Errorf("failed to parse target consistency level: %w", err) + } + return &ClientHandler{ clientConnector: NewClientConnector( clientTcpConn, @@ -345,6 +356,7 @@ func NewClientHandler( clientHandlerShutdownRequestCancelFn: clientHandlerShutdownRequestCancelFn, clientHandlerShutdownRequestContext: clientHandlerShutdownRequestContext, compression: compression, + targetConsistencyLevel: targetWriteCL, }, nil } @@ -658,6 +670,10 @@ func (ch *ClientHandler) responseLoop() { if reqCtx.GetRequestInfo().ShouldBeTrackedInMetrics() { trackClusterErrorMetrics(response.responseFrame, ch.getCompression(), response.connectorType, ch.nodeMetrics) } + // Track per-table successful writes immediately when each cluster responds + if isResponseSuccessful(response.responseFrame) && response.connectorType != ClusterConnectorTypeAsync { + ch.trackPerTableWriteSuccess(reqCtx.GetRequestInfo(), response.connectorType) + } } if finished { @@ -1528,6 +1544,16 @@ func (ch *ClientHandler) executeRequest( return err } + // Override target consistency level for requests where origin and target share the same frame. + // EXECUTE and BATCH are handled in their respective handlers where the deep copy already exists. 
+ targetReceivesRequest := fwdDecision == forwardToBoth || fwdDecision == forwardToTarget + if ch.targetConsistencyLevel != nil && targetReceivesRequest && originRequest == targetRequest { + targetRequest, err = ch.overrideTargetConsistency(frameContext) + if err != nil { + return fmt.Errorf("could not override target consistency: %w", err) + } + } + if fwdDecision == forwardToNone { if clientResponse == nil { return fmt.Errorf("forwardDecision is NONE but client response is nil") @@ -1654,6 +1680,35 @@ func (ch *ClientHandler) handleRequestSendFailure(err error, frameContext *frame } } +// overrideTargetConsistency creates a deep copy of the frame with the consistency level overridden +// for target-side requests. This is called for generic Query messages (non-prepared, non-batch) where +// the origin and target requests share the same raw frame pointer. +func (ch *ClientHandler) overrideTargetConsistency(frameContext *frameDecodeContext) (*frame.RawFrame, error) { + decodedFrame, err := frameContext.GetOrDecodeFrame() + if err != nil { + return nil, fmt.Errorf("could not decode frame for target write CL override: %w", err) + } + + targetFrame := decodedFrame.DeepCopy() + + switch typedMsg := targetFrame.Body.Message.(type) { + case *message.Query: + if typedMsg.Options == nil { + typedMsg.Options = &message.QueryOptions{} + } + typedMsg.Options.Consistency = *ch.targetConsistencyLevel + default: + // For messages without a consistency field (e.g., PREPARE, STARTUP), return unchanged. 
+ return frameContext.GetRawFrame(), nil + } + + rawFrame, err := ch.getCodec().ConvertToRawFrame(targetFrame) + if err != nil { + return nil, fmt.Errorf("could not re-encode frame after target write CL override: %w", err) + } + return rawFrame, nil +} + func (ch *ClientHandler) handleInterceptedRequest( requestInfo RequestInfo, frameContext *frameDecodeContext, currentKeyspace string) (*frame.RawFrame, error) { @@ -1842,6 +1897,13 @@ func (ch *ClientHandler) handleExecuteRequest( log.Tracef("Replacing prepared ID %s with %s for target cluster.", hex.EncodeToString(originalQueryId), hex.EncodeToString(newTargetExecuteMsg.QueryId)) + if ch.targetConsistencyLevel != nil { + if newTargetExecuteMsg.Options == nil { + newTargetExecuteMsg.Options = &message.QueryOptions{} + } + newTargetExecuteMsg.Options.Consistency = *ch.targetConsistencyLevel + } + newTargetRequestRaw, err := ch.getCodec().ConvertToRawFrame(newTargetRequest) if err != nil { return nil, nil, nil, fmt.Errorf("could not convert target EXECUTE response to raw frame: %w", err) @@ -1901,6 +1963,10 @@ func (ch *ClientHandler) handleBatchRequest( hex.EncodeToString(originalQueryId), hex.EncodeToString(preparedData.GetTargetPreparedId())) } + if ch.targetConsistencyLevel != nil { + newTargetBatchMsg.Consistency = *ch.targetConsistencyLevel + } + if newOriginRequest != nil { originBatchRequest, err := ch.getCodec().ConvertToRawFrame(newOriginRequest) if err != nil { @@ -2201,6 +2267,35 @@ func decodeErrorResult(frame *frame.RawFrame, compression primitive.Compression) return errorResult, nil } +func (ch *ClientHandler) trackPerTableWriteSuccess(requestInfo RequestInfo, connectorType ClusterConnectorType) { + if requestInfo.GetForwardDecision() != forwardToBoth { + return // only track writes (forwardToBoth) + } + targets := requestInfo.GetWriteTargets() + if len(targets) == 0 { + return + } + + var cluster string + switch connectorType { + case ClusterConnectorTypeOrigin: + cluster = "origin" + case 
ClusterConnectorTypeTarget: + cluster = "target" + default: + return + } + + for _, wt := range targets { + counter, err := ch.metricHandler.GetOrCreateWriteSuccessCounter(cluster, wt.Keyspace, wt.Table) + if err != nil { + log.Warnf("Could not create write success metric for %s.%s on %s: %v", wt.Keyspace, wt.Table, cluster, err) + continue + } + counter.Add(1) + } +} + func isResponseSuccessful(response *frame.RawFrame) bool { return response.Header.OpCode != primitive.OpCodeError } diff --git a/proxy/pkg/zdmproxy/controlconn.go b/proxy/pkg/zdmproxy/controlconn.go index 37edb54b..2d0b2632 100644 --- a/proxy/pkg/zdmproxy/controlconn.go +++ b/proxy/pkg/zdmproxy/controlconn.go @@ -235,6 +235,53 @@ func (cc *ControlConn) IsAuthEnabled() (bool, error) { "the control connection has not been initialized") } +// CheckSuperUserAndWarn queries system_auth.roles to determine if the configured user is a superuser. +// If the user is a superuser, a warning is logged advising against this practice. +// Any errors (e.g. table doesn't exist, permission denied, Astra-specific behavior) are silently +// ignored — the check is best-effort only. 
+func (cc *ControlConn) CheckSuperUserAndWarn() { + conn, _ := cc.GetConnAndContactPoint() + if conn == nil { + return + } + + authEnabled, err := cc.IsAuthEnabled() + if err != nil || !authEnabled { + return + } + + clusterType := cc.connConfig.GetClusterType() + query := fmt.Sprintf("SELECT is_superuser FROM system_auth.roles WHERE role = '%s'", cc.username) + result, err := conn.Query(query, GetDefaultGenericTypeCodec(), cc.context) + if err != nil { + log.Debugf("[%v] Could not query system_auth.roles to check superuser status (this is expected on some platforms): %v", + clusterType, err) + return + } + + if result == nil || len(result.Rows) == 0 { + return + } + + val, exists := result.Rows[0].GetByColumn("is_superuser") + if !exists || val == nil { + return + } + + isSuperUser, ok := val.(bool) + if !ok { + return + } + + if isSuperUser { + log.Warnf("[%v] The configured user '%s' is a superuser. This is not recommended for application "+ + "workloads because superuser authentication requires QUORUM consistency internally in Cassandra, "+ + "which increases the risk of authentication failures during node instability. 
Consider using a "+ + "regular user with only the necessary permissions for the keyspaces being migrated.", + clusterType, cc.username) + } +} + func (cc *ControlConn) IncrementFailureCounter() { cc.counterLock.Lock() defer cc.counterLock.Unlock() diff --git a/proxy/pkg/zdmproxy/cqlparser.go b/proxy/pkg/zdmproxy/cqlparser.go index d1d5d43c..82648389 100644 --- a/proxy/pkg/zdmproxy/cqlparser.go +++ b/proxy/pkg/zdmproxy/cqlparser.go @@ -109,7 +109,8 @@ func buildRequestInfo( } else if len(stmtsReplacedTerms) == 1 { replacedTerms = stmtsReplacedTerms[0].replacedTerms } - return NewPrepareRequestInfo(baseRequestInfo, replacedTerms, stmtQueryData.queryData.hasPositionalBindMarkers(), prepareMsg.Query, prepareMsg.Keyspace), nil + tableName := stmtQueryData.queryData.getTableName() + return NewPrepareRequestInfo(baseRequestInfo, replacedTerms, stmtQueryData.queryData.hasPositionalBindMarkers(), prepareMsg.Query, prepareMsg.Keyspace, tableName), nil case primitive.OpCodeBatch: decodedFrame, err := frameContext.GetOrDecodeFrame() if err != nil { @@ -120,17 +121,35 @@ func buildRequestInfo( return nil, fmt.Errorf("could not convert message with batch op code to batch type, got %v instead", decodedFrame.Body.Message) } preparedDataByStmtIdxMap := make(map[int]PreparedData) + var writeTargets []WriteTarget + seen := map[string]bool{} // deduplicate targets within the batch for childIdx, child := range batchMsg.Children { if child.Id != nil { preparedData, err := getPreparedData(psCache, mh, child.Id, primitive.OpCodeBatch, decodedFrame) if err != nil { return nil, err - } else { - preparedDataByStmtIdxMap[childIdx] = preparedData + } + preparedDataByStmtIdxMap[childIdx] = preparedData + pri := preparedData.GetPrepareRequestInfo() + ks := pri.GetKeyspace() + tbl := pri.GetTableName() + key := ks + "." 
+ tbl + if (ks != "" || tbl != "") && !seen[key] { + writeTargets = append(writeTargets, WriteTarget{Keyspace: ks, Table: tbl}) + seen[key] = true + } + } else if child.Query != "" { + qi := inspectCqlQuery(child.Query, currentKeyspaceName, timeUuidGenerator) + ks := qi.getApplicableKeyspace() + tbl := qi.getTableName() + key := ks + "." + tbl + if (ks != "" || tbl != "") && !seen[key] { + writeTargets = append(writeTargets, WriteTarget{Keyspace: ks, Table: tbl}) + seen[key] = true } } } - return NewBatchRequestInfo(preparedDataByStmtIdxMap), nil + return NewBatchRequestInfo(preparedDataByStmtIdxMap, writeTargets), nil case primitive.OpCodeExecute: decodedFrame, err := frameContext.GetOrDecodeFrame() if err != nil { @@ -233,7 +252,15 @@ func getRequestInfoFromQueryInfo( log.Tracef("Forward decision: %s", forwardDecision) - return NewGenericRequestInfo(forwardDecision, sendAlsoToAsync, trackMetrics) + info := NewGenericRequestInfo(forwardDecision, sendAlsoToAsync, trackMetrics) + if forwardDecision == forwardToBoth { + ks := queryInfo.getApplicableKeyspace() + tbl := queryInfo.getTableName() + if ks != "" || tbl != "" { + info.writeTargets = []WriteTarget{{Keyspace: ks, Table: tbl}} + } + } + return info } func isSystemQuery(info QueryInfo) bool { diff --git a/proxy/pkg/zdmproxy/cqlparser_test.go b/proxy/pkg/zdmproxy/cqlparser_test.go index 1637e7d7..da9bc4f4 100644 --- a/proxy/pkg/zdmproxy/cqlparser_test.go +++ b/proxy/pkg/zdmproxy/cqlparser_test.go @@ -25,37 +25,37 @@ func TestInspectFrame(t *testing.T) { originCacheEntry := &preparedDataImpl{ originPreparedId: []byte("ORIGIN"), targetPreparedId: []byte("ORIGIN_TARGET"), - prepareRequestInfo: NewPrepareRequestInfo(NewGenericRequestInfo(forwardToOrigin, false, false), nil, false, "", ""), + prepareRequestInfo: NewPrepareRequestInfo(NewGenericRequestInfo(forwardToOrigin, false, false), nil, false, "", "", ""), } targetCacheEntry := &preparedDataImpl{ originPreparedId: []byte("TARGET"), targetPreparedId: 
[]byte("TARGET_TARGET"), - prepareRequestInfo: NewPrepareRequestInfo(NewGenericRequestInfo(forwardToTarget, false, false), nil, false, "", ""), + prepareRequestInfo: NewPrepareRequestInfo(NewGenericRequestInfo(forwardToTarget, false, false), nil, false, "", "", ""), } bothCacheEntry := &preparedDataImpl{ originPreparedId: []byte("BOTH"), targetPreparedId: []byte("BOTH_TARGET"), - prepareRequestInfo: NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, false), nil, false, "", ""), + prepareRequestInfo: NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, false), nil, false, "", "", ""), } peersKsCacheEntry := &preparedDataImpl{ originPreparedId: []byte("PEERS_KS"), targetPreparedId: []byte("PEERS_KS"), - prepareRequestInfo: NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV1, newStarSelectClause()), nil, false, "SELECT * FROM peers", "system"), + prepareRequestInfo: NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV1, newStarSelectClause()), nil, false, "SELECT * FROM peers", "system", "peers"), } peersCacheEntry := &preparedDataImpl{ originPreparedId: []byte("PEERS"), targetPreparedId: []byte("PEERS"), - prepareRequestInfo: NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV1, newStarSelectClause()), nil, false, "SELECT * FROM system.peers", ""), + prepareRequestInfo: NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV1, newStarSelectClause()), nil, false, "SELECT * FROM system.peers", "", "peers"), } localKsCacheEntry := &preparedDataImpl{ originPreparedId: []byte("LOCAL_KS"), targetPreparedId: []byte("LOCAL_KS"), - prepareRequestInfo: NewPrepareRequestInfo(NewInterceptedRequestInfo(local, newStarSelectClause()), nil, false, "SELECT * FROM local", "system"), + prepareRequestInfo: NewPrepareRequestInfo(NewInterceptedRequestInfo(local, newStarSelectClause()), nil, false, "SELECT * FROM local", "system", "local"), } localCacheEntry := &preparedDataImpl{ originPreparedId: []byte("LOCAL"), targetPreparedId: []byte("LOCAL"), - 
prepareRequestInfo: NewPrepareRequestInfo(NewInterceptedRequestInfo(local, newStarSelectClause()), nil, false, "SELECT * FROM system.local", ""), + prepareRequestInfo: NewPrepareRequestInfo(NewInterceptedRequestInfo(local, newStarSelectClause()), nil, false, "SELECT * FROM system.local", "", "local"), } psCache := NewPreparedStatementCache() psCache.cache["BOTH"] = bothCacheEntry @@ -90,28 +90,28 @@ func TestInspectFrame(t *testing.T) { {"OpCodeQuery SELECT system.peers_v2", args{mockQueryFrame(t, "SELECT * FROM system.peers_v2"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewInterceptedRequestInfo(peersV2, newStarSelectClause())}, {"OpCodeQuery SELECT system_auth.roles", args{mockQueryFrame(t, "SELECT * FROM system_auth.roles"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewGenericRequestInfo(forwardToOrigin, false, true)}, {"OpCodeQuery SELECT dse_insights.tokens", args{mockQueryFrame(t, "SELECT * FROM dse_insights.tokens"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewGenericRequestInfo(forwardToOrigin, false, true)}, - {"OpCodeQuery INSERT INTO asd (a, b) VALUES (1, 2)", args{mockQueryFrame(t, "INSERT INTO asd (a, b) VALUES (1, 2)"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewGenericRequestInfo(forwardToBoth, false, true)}, - {"OpCodeQuery UPDATE asd SET b = 2 WHERE a = 1", args{mockQueryFrame(t, "UPDATE asd SET b = 2 WHERE a = 1"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewGenericRequestInfo(forwardToBoth, false, true)}, + {"OpCodeQuery INSERT INTO asd (a, b) VALUES (1, 2)", args{mockQueryFrame(t, "INSERT INTO asd (a, b) VALUES (1, 2)"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, &GenericRequestInfo{baseRequestInfo: &baseRequestInfo{forwardDecision: forwardToBoth, shouldAlsoBeSentAsync: false, 
trackMetrics: true, writeTargets: []WriteTarget{{Keyspace: "", Table: "asd"}}}}}, + {"OpCodeQuery UPDATE asd SET b = 2 WHERE a = 1", args{mockQueryFrame(t, "UPDATE asd SET b = 2 WHERE a = 1"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, &GenericRequestInfo{baseRequestInfo: &baseRequestInfo{forwardDecision: forwardToBoth, shouldAlsoBeSentAsync: false, trackMetrics: true, writeTargets: []WriteTarget{{Keyspace: "", Table: "asd"}}}}}, {"OpCodeQuery UNKNOWN", args{mockQueryFrame(t, "UNKNOWN"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewGenericRequestInfo(forwardToBoth, false, true)}, {"OpCodeQuery CALL InsightsRpc.reportInsight(?)", args{mockQueryFrame(t, "CALL InsightsRpc.reportInsight(?)"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewGenericRequestInfo(forwardToNone, false, false)}, {"OpCodeQuery CALL insightsrpc.reportinsight('a', 1, -2.3, true, '2020-01-01')", args{mockQueryFrame(t, "CALL InsightsRpc.reportInsight('a', 1, -2.3, true, '2020-01-01')"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewGenericRequestInfo(forwardToNone, false, false)}, {"OpCodeQuery CALL DseGraphRpc.getSchemaBlob(?)", args{mockQueryFrame(t, "CALL DseGraphRpc.getSchemaBlob(?)"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewGenericRequestInfo(forwardToBoth, false, true)}, // PREPARE - {"OpCodePrepare SELECT", args{mockPrepareFrame(t, "SELECT blah FROM ks1.t1"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewPrepareRequestInfo(NewGenericRequestInfo(forwardToOrigin, true, true), []*term{}, false, "SELECT blah FROM ks1.t1", "")}, - {"OpCodePrepare SELECT system.local forwardSystemQueriesToOrigin", args{mockPrepareFrame(t, "SELECT * FROM system.local"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, 
NewPrepareRequestInfo(NewInterceptedRequestInfo(local, newStarSelectClause()), []*term{}, false, "SELECT * FROM system.local", "")}, - {"OpCodePrepare SELECT system.peers forwardSystemQueriesToOrigin", args{mockPrepareFrame(t, "SELECT * FROM system.peers"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV1, newStarSelectClause()), []*term{}, false, "SELECT * FROM system.peers", "")}, - {"OpCodePrepare SELECT system.local", args{mockPrepareFrame(t, "SELECT * FROM system.local"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(local, newStarSelectClause()), []*term{}, false, "SELECT * FROM system.local", "")}, - {"OpCodePrepare SELECT local", args{mockPrepareFrameWithKeyspace(t, "SELECT * FROM local", "system"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(local, newStarSelectClause()), []*term{}, false, "SELECT * FROM local", "system")}, - {"OpCodePrepare SELECT system.peers", args{mockPrepareFrame(t, "SELECT * FROM system.peers"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV1, newStarSelectClause()), []*term{}, false, "SELECT * FROM system.peers", "")}, - {"OpCodePrepare SELECT peers", args{mockPrepareFrameWithKeyspace(t, "SELECT * FROM peers", "system"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV1, newStarSelectClause()), []*term{}, false, "SELECT * FROM peers", "system")}, - {"OpCodePrepare SELECT system.peers_v2", args{mockPrepareFrame(t, "SELECT * FROM system.peers_v2"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV2, 
newStarSelectClause()), []*term{}, false, "SELECT * FROM system.peers_v2", "")}, - {"OpCodePrepare SELECT system.peers_v2 forwardSystemQueriesToOrigin", args{mockPrepareFrame(t, "SELECT * FROM system.peers_v2"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV2, newStarSelectClause()), []*term{}, false, "SELECT * FROM system.peers_v2", "")}, - {"OpCodePrepare SELECT system_auth.roles", args{mockPrepareFrame(t, "SELECT * FROM system_auth.roles"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewGenericRequestInfo(forwardToTarget, false, true), []*term{}, false, "SELECT * FROM system_auth.roles", "")}, - {"OpCodePrepare SELECT dse_insights.tokens", args{mockPrepareFrame(t, "SELECT * FROM dse_insights.tokens"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewGenericRequestInfo(forwardToTarget, false, true), []*term{}, false, "SELECT * FROM dse_insights.tokens", "")}, - {"OpCodePrepare INSERT INTO asd (a, b) VALUES (1, 2)", args{mockPrepareFrame(t, "INSERT INTO asd (a, b) VALUES (1, 2)"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, true), []*term{}, false, "INSERT INTO asd (a, b) VALUES (1, 2)", "")}, - {"OpCodePrepare UPDATE asd SET b = 2 WHERE a = 1", args{mockPrepareFrame(t, "UPDATE asd SET b = 2 WHERE a = 1"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, true), []*term{}, false, "UPDATE asd SET b = 2 WHERE a = 1", "")}, - {"OpCodePrepare UNKNOWN", args{mockPrepareFrame(t, "UNKNOWN"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, true), 
[]*term{}, false, "UNKNOWN", "")}, + {"OpCodePrepare SELECT", args{mockPrepareFrame(t, "SELECT blah FROM ks1.t1"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewPrepareRequestInfo(NewGenericRequestInfo(forwardToOrigin, true, true), []*term{}, false, "SELECT blah FROM ks1.t1", "", "t1")}, + {"OpCodePrepare SELECT system.local forwardSystemQueriesToOrigin", args{mockPrepareFrame(t, "SELECT * FROM system.local"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(local, newStarSelectClause()), []*term{}, false, "SELECT * FROM system.local", "", "local")}, + {"OpCodePrepare SELECT system.peers forwardSystemQueriesToOrigin", args{mockPrepareFrame(t, "SELECT * FROM system.peers"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV1, newStarSelectClause()), []*term{}, false, "SELECT * FROM system.peers", "", "peers")}, + {"OpCodePrepare SELECT system.local", args{mockPrepareFrame(t, "SELECT * FROM system.local"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(local, newStarSelectClause()), []*term{}, false, "SELECT * FROM system.local", "", "local")}, + {"OpCodePrepare SELECT local", args{mockPrepareFrameWithKeyspace(t, "SELECT * FROM local", "system"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(local, newStarSelectClause()), []*term{}, false, "SELECT * FROM local", "system", "local")}, + {"OpCodePrepare SELECT system.peers", args{mockPrepareFrame(t, "SELECT * FROM system.peers"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV1, newStarSelectClause()), []*term{}, false, "SELECT * FROM system.peers", 
"", "peers")}, + {"OpCodePrepare SELECT peers", args{mockPrepareFrameWithKeyspace(t, "SELECT * FROM peers", "system"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV1, newStarSelectClause()), []*term{}, false, "SELECT * FROM peers", "system", "peers")}, + {"OpCodePrepare SELECT system.peers_v2", args{mockPrepareFrame(t, "SELECT * FROM system.peers_v2"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV2, newStarSelectClause()), []*term{}, false, "SELECT * FROM system.peers_v2", "", "peers_v2")}, + {"OpCodePrepare SELECT system.peers_v2 forwardSystemQueriesToOrigin", args{mockPrepareFrame(t, "SELECT * FROM system.peers_v2"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewPrepareRequestInfo(NewInterceptedRequestInfo(peersV2, newStarSelectClause()), []*term{}, false, "SELECT * FROM system.peers_v2", "", "peers_v2")}, + {"OpCodePrepare SELECT system_auth.roles", args{mockPrepareFrame(t, "SELECT * FROM system_auth.roles"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewGenericRequestInfo(forwardToTarget, false, true), []*term{}, false, "SELECT * FROM system_auth.roles", "", "roles")}, + {"OpCodePrepare SELECT dse_insights.tokens", args{mockPrepareFrame(t, "SELECT * FROM dse_insights.tokens"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToTarget, forwardAuthToOrigin}, NewPrepareRequestInfo(NewGenericRequestInfo(forwardToTarget, false, true), []*term{}, false, "SELECT * FROM dse_insights.tokens", "", "tokens")}, + {"OpCodePrepare INSERT INTO asd (a, b) VALUES (1, 2)", args{mockPrepareFrame(t, "INSERT INTO asd (a, b) VALUES (1, 2)"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewPrepareRequestInfo(&GenericRequestInfo{baseRequestInfo: 
&baseRequestInfo{forwardDecision: forwardToBoth, shouldAlsoBeSentAsync: false, trackMetrics: true, writeTargets: []WriteTarget{{Keyspace: "", Table: "asd"}}}}, []*term{}, false, "INSERT INTO asd (a, b) VALUES (1, 2)", "", "asd")}, + {"OpCodePrepare UPDATE asd SET b = 2 WHERE a = 1", args{mockPrepareFrame(t, "UPDATE asd SET b = 2 WHERE a = 1"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewPrepareRequestInfo(&GenericRequestInfo{baseRequestInfo: &baseRequestInfo{forwardDecision: forwardToBoth, shouldAlsoBeSentAsync: false, trackMetrics: true, writeTargets: []WriteTarget{{Keyspace: "", Table: "asd"}}}}, []*term{}, false, "UPDATE asd SET b = 2 WHERE a = 1", "", "asd")}, + {"OpCodePrepare UNKNOWN", args{mockPrepareFrame(t, "UNKNOWN"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, true), []*term{}, false, "UNKNOWN", "", "")}, // EXECUTE {"OpCodeExecute origin", args{mockExecuteFrame(t, "ORIGIN"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewExecuteRequestInfo(originCacheEntry)}, @@ -125,8 +125,8 @@ func TestInspectFrame(t *testing.T) { // REGISTER {"OpCodeRegister", args{mockFrame(t, &message.Register{EventTypes: []primitive.EventType{primitive.EventTypeSchemaChange}}, primitive.ProtocolVersion4), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewGenericRequestInfo(forwardToBoth, false, false)}, // BATCH - {"OpCodeBatch simple", args{mockBatch(t, "simple query"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewBatchRequestInfo(map[int]PreparedData{})}, - {"OpCodeBatch prepared", args{mockBatch(t, []byte("BOTH")), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewBatchRequestInfo(map[int]PreparedData{0: bothCacheEntry})}, + {"OpCodeBatch simple", args{mockBatch(t, "simple 
query"), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewBatchRequestInfo(map[int]PreparedData{}, nil)}, + {"OpCodeBatch prepared", args{mockBatch(t, []byte("BOTH")), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewBatchRequestInfo(map[int]PreparedData{0: bothCacheEntry}, nil)}, // AUTH_RESPONSE {"OpCodeAuthResponse ForwardAuthToTarget", args{mockAuthResponse(t), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToTarget}, NewGenericRequestInfo(forwardToTarget, false, false)}, {"OpCodeAuthResponse ForwardAuthToOrigin", args{mockAuthResponse(t), []*term{}, primaryClusterOrigin, forwardSystemQueriesToOrigin, forwardAuthToOrigin}, NewGenericRequestInfo(forwardToOrigin, false, false)}, diff --git a/proxy/pkg/zdmproxy/parametermodifier_test.go b/proxy/pkg/zdmproxy/parametermodifier_test.go index 0cb85ade..49d08905 100644 --- a/proxy/pkg/zdmproxy/parametermodifier_test.go +++ b/proxy/pkg/zdmproxy/parametermodifier_test.go @@ -20,7 +20,7 @@ func TestAddValuesToExecuteFrame_NoReplacedTerms(t *testing.T) { ResultMetadataId: nil, Options: &message.QueryOptions{}, }) - prepareRequestInfo := NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, true), []*term{}, false, "", "") + prepareRequestInfo := NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, true), []*term{}, false, "", "", "") variablesMetadata := &message.VariablesMetadata{ PkIndices: nil, Columns: nil, @@ -42,7 +42,7 @@ func TestAddValuesToExecuteFrame_InvalidMessageType(t *testing.T) { Query: "SELECT * FROM asd WHERE a = :param1", Options: &message.QueryOptions{}, }) - prepareRequestInfo := NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, true), []*term{}, false, "", "") + prepareRequestInfo := NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, true), []*term{}, false, "", "", "") variablesMetadata := &message.VariablesMetadata{ PkIndices: nil, Columns: 
nil, @@ -205,7 +205,7 @@ func TestAddValuesToExecuteFrame_PositionalValues(t *testing.T) { Options: clonedQueryOpts, }) containsPositionalMarkers := ((len(requestPosVals) + len(replacedTerms)) > 0) && !test.prepareContainsNamedValues - prepareRequestInfo := NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, true), replacedTerms, containsPositionalMarkers, "", "") + prepareRequestInfo := NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, true), replacedTerms, containsPositionalMarkers, "", "", "") replacementTimeUuids := parameterModifier.generateTimeUuids(prepareRequestInfo) executeMsg, err := parameterModifier.AddValuesToExecuteFrame(f, prepareRequestInfo, vm, replacementTimeUuids) @@ -350,7 +350,7 @@ func TestAddValuesToExecuteFrame_NamedValues(t *testing.T) { ResultMetadataId: nil, Options: clonedQueryOpts, }) - prepareRequestInfo := NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, true), replacedTerms, false, "", "") + prepareRequestInfo := NewPrepareRequestInfo(NewGenericRequestInfo(forwardToBoth, false, true), replacedTerms, false, "", "", "") replacementTimeUuids := parameterModifier.generateTimeUuids(prepareRequestInfo) executeMsg, err := parameterModifier.AddValuesToExecuteFrame(f, prepareRequestInfo, vm, replacementTimeUuids) diff --git a/proxy/pkg/zdmproxy/proxy.go b/proxy/pkg/zdmproxy/proxy.go index ffd4f1c2..574982fb 100644 --- a/proxy/pkg/zdmproxy/proxy.go +++ b/proxy/pkg/zdmproxy/proxy.go @@ -264,6 +264,7 @@ func (p *ZdmProxy) initializeControlConnections(ctx context.Context) error { if err := originControlConn.Start(p.controlConnShutdownWg, ctx); err != nil { return fmt.Errorf("failed to initialize origin control connection: %w", err) } + originControlConn.CheckSuperUserAndWarn() p.lock.Lock() p.originControlConn = originControlConn @@ -276,6 +277,7 @@ func (p *ZdmProxy) initializeControlConnections(ctx context.Context) error { if err := targetControlConn.Start(p.controlConnShutdownWg, ctx); err != nil 
{ return fmt.Errorf("failed to initialize target control connection: %w", err) } + targetControlConn.CheckSuperUserAndWarn() p.lock.Lock() p.targetControlConn = targetControlConn diff --git a/proxy/pkg/zdmproxy/requestinfo.go b/proxy/pkg/zdmproxy/requestinfo.go index a079dbc0..8c360339 100644 --- a/proxy/pkg/zdmproxy/requestinfo.go +++ b/proxy/pkg/zdmproxy/requestinfo.go @@ -2,16 +2,27 @@ package zdmproxy import "fmt" +// WriteTarget identifies a keyspace and table that a write operation targets. +// Used for per-table write metrics tracking. +type WriteTarget struct { + Keyspace string + Table string +} + type RequestInfo interface { GetForwardDecision() forwardDecision ShouldAlsoBeSentAsync() bool ShouldBeTrackedInMetrics() bool + // GetWriteTargets returns the keyspace/table pairs targeted by this write request. + // Returns nil for non-write requests (reads, PREPARE, intercepted queries). + GetWriteTargets() []WriteTarget } type baseRequestInfo struct { forwardDecision forwardDecision shouldAlsoBeSentAsync bool trackMetrics bool + writeTargets []WriteTarget } func newBaseRequestInfo(decision forwardDecision, shouldBeSentAsync bool, trackMetrics bool) *baseRequestInfo { @@ -30,6 +41,10 @@ func (recv *baseRequestInfo) ShouldBeTrackedInMetrics() bool { return recv.trackMetrics } +func (recv *baseRequestInfo) GetWriteTargets() []WriteTarget { + return recv.writeTargets +} + type GenericRequestInfo struct { *baseRequestInfo } @@ -49,6 +64,7 @@ type PrepareRequestInfo struct { containsPositionalMarkers bool query string keyspace string + tableName string } func NewPrepareRequestInfo( @@ -56,13 +72,15 @@ func NewPrepareRequestInfo( replacedTerms []*term, containsPositionalMarkers bool, query string, - keyspace string) *PrepareRequestInfo { + keyspace string, + tableName string) *PrepareRequestInfo { return &PrepareRequestInfo{ baseRequestInfo: baseRequestInfo, replacedTerms: replacedTerms, containsPositionalMarkers: containsPositionalMarkers, query: query, - keyspace: 
keyspace} + keyspace: keyspace, + tableName: tableName} } func (recv *PrepareRequestInfo) String() string { @@ -78,6 +96,14 @@ func (recv *PrepareRequestInfo) ShouldBeTrackedInMetrics() bool { return false } +func (recv *PrepareRequestInfo) GetWriteTargets() []WriteTarget { + return nil // PREPARE doesn't execute a write +} + +func (recv *PrepareRequestInfo) GetTableName() string { + return recv.tableName +} + func (recv *PrepareRequestInfo) GetQuery() string { return recv.query } @@ -133,6 +159,16 @@ func (recv *ExecuteRequestInfo) ShouldBeTrackedInMetrics() bool { return recv.preparedData.GetPrepareRequestInfo().GetBaseRequestInfo().ShouldBeTrackedInMetrics() } +func (recv *ExecuteRequestInfo) GetWriteTargets() []WriteTarget { + pri := recv.preparedData.GetPrepareRequestInfo() + // WriteTargets were set on the base RequestInfo during PREPARE parsing + baseTargets := pri.GetBaseRequestInfo().GetWriteTargets() + if baseTargets != nil { + return baseTargets + } + return nil +} + // InterceptedRequestInfo on its own means that this intercepted request is a QUERY request. // This can also be the base request field of a PrepareRequestInfo object in which case the intercepted request will be // a PREPARE (or EXECUTE if it's a ExecuteRequestInfo). 
@@ -163,12 +199,17 @@ func (recv *InterceptedRequestInfo) GetParsedSelectClause() *selectClause { return recv.parsedSelectClause } +func (recv *InterceptedRequestInfo) GetWriteTargets() []WriteTarget { + return nil +} + type BatchRequestInfo struct { preparedDataByStmtIdx map[int]PreparedData + writeTargets []WriteTarget } -func NewBatchRequestInfo(preparedDataByStmtIdx map[int]PreparedData) *BatchRequestInfo { - return &BatchRequestInfo{preparedDataByStmtIdx: preparedDataByStmtIdx} +func NewBatchRequestInfo(preparedDataByStmtIdx map[int]PreparedData, writeTargets []WriteTarget) *BatchRequestInfo { + return &BatchRequestInfo{preparedDataByStmtIdx: preparedDataByStmtIdx, writeTargets: writeTargets} } func (recv *BatchRequestInfo) String() string { @@ -187,6 +228,10 @@ func (recv *BatchRequestInfo) ShouldBeTrackedInMetrics() bool { return true } +func (recv *BatchRequestInfo) GetWriteTargets() []WriteTarget { + return recv.writeTargets +} + func (recv *BatchRequestInfo) GetPreparedDataByStmtIdx() map[int]PreparedData { return recv.preparedDataByStmtIdx }