From bcbe26b75617c671a9eb115582d9b9565f5bf3e9 Mon Sep 17 00:00:00 2001 From: Johan Lindh Date: Thu, 30 Apr 2026 14:06:46 +0200 Subject: [PATCH] feat(mqtt): split signing and publish workers --- pkg/runner/mqtt.go | 130 ++++++++++---- pkg/runner/mqtt_test.go | 363 ++++++++++++++++++++++++++++++++++++++++ pkg/runner/runner.go | 18 +- 3 files changed, 471 insertions(+), 40 deletions(-) create mode 100644 pkg/runner/mqtt_test.go diff --git a/pkg/runner/mqtt.go b/pkg/runner/mqtt.go index a815e92..edee9c4 100644 --- a/pkg/runner/mqtt.go +++ b/pkg/runner/mqtt.go @@ -1,11 +1,13 @@ package runner import ( + "context" "crypto/tls" "crypto/x509" "fmt" "log/slog" "net/url" + "sync" "github.com/eclipse/paho.golang/autopaho" "github.com/eclipse/paho.golang/autopaho/queue/file" @@ -47,6 +49,12 @@ func (pel pahoErrorLogger) Printf(format string, v ...interface{}) { pel.logger.Error(fmt.Sprintf(format, v...)) } +type mqttConnectionManager interface { + AwaitConnection(context.Context) error + PublishViaQueue(context.Context, *autopaho.QueuePublish) error + Publish(context.Context, *paho.Publish) (*paho.PublishResponse, error) +} + func (edm *dnstapMinimiser) newAutoPahoClientConfig(caCertPool *x509.CertPool, server string, clientID string, mqttKeepAlive uint16, localFileQueue *file.Queue) (autopaho.ClientConfig, error) { u, err := url.Parse(server) if err != nil { @@ -88,19 +96,75 @@ func (edm *dnstapMinimiser) newAutoPahoClientConfig(caCertPool *x509.CertPool, s return cliCfg, nil } -func (edm *dnstapMinimiser) runAutoPaho(cm *autopaho.ConnectionManager, mqttJWK jwk.Key, usingFileQueue bool) { - defer edm.autopahoWg.Done() - +// startMQTTPipeline launches N JWS sign workers and 1 paho publisher. The +// previous design ran sign + publish in a single goroutine, which made +// jws.Sign a serialization bottleneck. Splitting them lets sign work +// parallelize across cores while the paho ConnectionManager's +// single-connection requirement is preserved by the lone publisher. +func (edm *dnstapMinimiser) startMQTTPipeline(cm mqttConnectionManager, mqttJWK jwk.Key, usingFileQueue bool, signWorkers int) { + if signWorkers <= 0 { + signWorkers = 1 + } topic := "events/up/" + mqttJWK.KeyID() + "/new_qname" - edm.log.Info("starting signing MQTT publisher", "jwk_id", mqttJWK.KeyID(), "jwk_alg", mqttJWK.Algorithm(), "topic", topic) + edm.log.Info("starting signing MQTT publisher", + "jwk_id", mqttJWK.KeyID(), + "jwk_alg", mqttJWK.Algorithm(), + "topic", topic, + "sign_workers", signWorkers, + ) + + // Sign workers: each independently reads unsigned bytes, JWS-signs, + // pushes the signed bytes onto the publisher's queue. When mqttPubCh + // is closed, each worker exits; when all are done, the last one + // closes mqttSignedCh so the publisher knows to drain and exit. + var signWg sync.WaitGroup + signWg.Add(signWorkers) + for i := 0; i < signWorkers; i++ { + go edm.mqttSignWorker(&signWg, mqttJWK) + } + + edm.autopahoWg.Add(1) + go func() { + defer edm.autopahoWg.Done() + signWg.Wait() + close(edm.mqttSignedCh) + }() + + edm.autopahoWg.Add(1) + go edm.mqttPublishWorker(cm, topic, usingFileQueue) +} + +// mqttSignWorker drains mqttPubCh, JWS-signs each message, and forwards to +// mqttSignedCh. Exits when mqttPubCh is closed. +func (edm *dnstapMinimiser) mqttSignWorker(wg *sync.WaitGroup, mqttJWK jwk.Key) { + defer wg.Done() + for unsignedMsg := range edm.mqttPubCh { + signedMsg, err := jws.Sign(unsignedMsg, jws.WithJSON(), jws.WithKey(mqttJWK.Algorithm(), mqttJWK)) + if err != nil { + edm.log.Error("mqttSignWorker: failed to create JWS message", "error", err) + continue + } + select { + case edm.mqttSignedCh <- signedMsg: + case <-edm.autopahoCtx.Done(): + return + } + } +} + +// mqttPublishWorker is the single goroutine that talks to paho. Single-writer +// matches paho's ConnectionManager expectations; signing remains parallel +// upstream while broker back-pressure is contained to this publisher. +func (edm *dnstapMinimiser) mqttPublishWorker(cm mqttConnectionManager, topic string, usingFileQueue bool) { + defer edm.autopahoWg.Done() + + var signedMsg []byte for { // We only need to wait for a server connection if we have no // local queue. Otherwise we can just start appending messages // to disk. if !usingFileQueue { - // AwaitConnection will return immediately if connection is up; adding this call stops publication whilst - // connection is unavailable. err := cm.AwaitConnection(edm.autopahoCtx) if err != nil { // Should only happen when context is cancelled edm.log.Error("publisher done", "AwaitConnection", err) @@ -108,22 +172,20 @@ func (edm *dnstapMinimiser) runAutoPaho(cm *autopaho.ConnectionManager, mqttJWK } } - // Wait for a message to publish - unsignedMsg := <-edm.mqttPubCh - if unsignedMsg == nil { - // The channel has been closed - edm.log.Info("runAutoPaho: message queue closed, exiting") + var ok bool + select { + case signedMsg, ok = <-edm.mqttSignedCh: + if !ok { + edm.log.Info("mqttPublishWorker: signed queue closed, exiting") + return + } + case <-edm.autopahoCtx.Done(): + edm.log.Info("mqttPublishWorker: context cancelled, exiting") return } - signedMsg, err := jws.Sign(unsignedMsg, jws.WithJSON(), jws.WithKey(mqttJWK.Algorithm(), mqttJWK)) - if err != nil { - edm.log.Error("runAutoPaho: failed to create JWS message", "error", err) - continue - } - if usingFileQueue { - err = cm.PublishViaQueue(edm.autopahoCtx, &autopaho.QueuePublish{ + err := cm.PublishViaQueue(edm.autopahoCtx, &autopaho.QueuePublish{ Publish: &paho.Publish{ QoS: 0, Topic: topic, @@ -134,23 +196,21 @@ func (edm *dnstapMinimiser) runAutoPaho(cm *autopaho.ConnectionManager, mqttJWK edm.log.Error("error writing message to queue", "error", err) } } else { - // Publish will block so we run it in a goroutine - go func(msg []byte) { - pr, err := cm.Publish(edm.autopahoCtx, &paho.Publish{ - QoS: 0, - Topic: topic, - Payload: msg, - }) - if err != nil { - edm.log.Error("error publishing", "error", err) - } else if pr != nil && pr.ReasonCode != 0 && pr.ReasonCode != 16 { // 16 = Server received message but there are no subscribers - // pr is only non-nil for QoS 1 and up - edm.log.Info("reason code received", "reason_code", pr.ReasonCode) - } - if edm.debug { - edm.log.Info("sent message", "content", string(msg)) - } - }(signedMsg) + pr, err := cm.Publish(edm.autopahoCtx, &paho.Publish{ + QoS: 0, + Topic: topic, + Payload: signedMsg, + }) + if err != nil { + edm.log.Error("error publishing", "error", err) + } else if pr != nil && pr.ReasonCode != 0 && pr.ReasonCode != 16 { + // pr is only non-nil for QoS 1 and up; + // 16 = "no subscribers" which is fine. + edm.log.Info("reason code received", "reason_code", pr.ReasonCode) + } + if edm.debug { + edm.log.Info("sent message", "content", string(signedMsg)) + } } select { diff --git a/pkg/runner/mqtt_test.go b/pkg/runner/mqtt_test.go new file mode 100644 index 0000000..25a2a3d --- /dev/null +++ b/pkg/runner/mqtt_test.go @@ -0,0 +1,363 @@ +package runner + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "io" + "log/slog" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/eclipse/paho.golang/autopaho" + "github.com/eclipse/paho.golang/paho" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" +) + +// newTestMQTTJWK builds an EdDSA jwk.Key suitable for the JWS pipeline. +// The key's algorithm/key-id are populated the same way the production +// loader (edDsaJWKFromFile) does. Returned alongside the corresponding +// public key so tests can verify signed messages. +func newTestMQTTJWK(t *testing.T) (priv jwk.Key, pub ed25519.PublicKey) { + t.Helper() + + pub, sk, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("ed25519.GenerateKey: %s", err) + } + + priv, err = jwk.FromRaw(sk) + if err != nil { + t.Fatalf("jwk.FromRaw: %s", err) + } + if err := priv.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil { + t.Fatalf("set Algorithm: %s", err) + } + if err := priv.Set(jwk.KeyIDKey, "test-mqtt-key"); err != nil { + t.Fatalf("set KeyID: %s", err) + } + return priv, pub +} + +func cleanupMQTTTestMinimiser(edm *dnstapMinimiser) { + if edm.stop != nil { + edm.stop() + } + if edm.fsWatcher != nil { + _ = edm.fsWatcher.Close() + edm.fsWatcher = nil + } +} + +// TestMqttSignWorkerSignsAndForwards covers the happy path: the worker +// reads an unsigned payload from mqttPubCh, JWS-signs it with the supplied +// JWK, and forwards the signed envelope on mqttSignedCh. We then verify +// the JWS using the matching public key to make sure the worker is +// actually signing the payload it received (not e.g. a constant or empty +// buffer). +// +// This pins the contract introduced in signing is +// parallelizable because each worker is independent of the others and of +// paho - a future refactor that, say, mutated the payload buffer in place +// across workers would corrupt the signature, and this test would catch +// it. +func TestMqttSignWorkerSignsAndForwards(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + edm, err := newDnstapMinimiser(logger, defaultTC) + if err != nil { + t.Fatalf("newDnstapMinimiser: %s", err) + } + t.Cleanup(func() { cleanupMQTTTestMinimiser(edm) }) + + // The worker ranges over mqttPubCh and selects on autopahoCtx.Done() + // for cancellation. Bind a fresh context so this test does not affect + // other tests and so we can guarantee cancellation on cleanup. + ctx, cancel := context.WithCancel(t.Context()) + edm.autopahoCtx = ctx + t.Cleanup(cancel) + + priv, pub := newTestMQTTJWK(t) + + var wg sync.WaitGroup + wg.Add(1) + go edm.mqttSignWorker(&wg, priv) + + payload := []byte(`{"qname":"example.com.","time":"2026-01-02T03:04:05Z"}`) + edm.mqttPubCh <- payload + + select { + case signed := <-edm.mqttSignedCh: + // Verify the signature using the matching public key. Verify + // returns the original payload bytes on success. + got, err := jws.Verify(signed, jws.WithKey(jwa.EdDSA, pub)) + if err != nil { + t.Fatalf("jws.Verify: %s", err) + } + if string(got) != string(payload) { + t.Fatalf("signed payload mismatch\n have: %s\n want: %s", got, payload) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for signed message on mqttSignedCh") + } + + // Closing the input channel must let the worker exit cleanly. If a + // future refactor accidentally introduced an unbounded inner loop the + // wg.Wait() below would hang and t.Cleanup-driven cancel() would not + // rescue us - so we wait with a timeout and fail loudly instead. + close(edm.mqttPubCh) + waitOrFail(t, &wg, 2*time.Second, "mqttSignWorker did not exit after mqttPubCh close") +} + +// TestMqttSignWorkerExitsOnContextCancelWhenSignedFull demonstrates the +// back-pressure escape hatch: if the publisher stalls and mqttSignedCh +// fills, the sign worker must not deadlock - it must return when +// autopahoCtx is cancelled. Without this, cancelling the run context +// would leave goroutines blocked on the channel send. +// +// Setup: replace mqttSignedCh with a *full* unbuffered-equivalent channel +// (capacity 1, pre-loaded) so the worker's send blocks. Then cancel the +// context and observe that the worker exits. +func TestMqttSignWorkerExitsOnContextCancelWhenSignedFull(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + edm, err := newDnstapMinimiser(logger, defaultTC) + if err != nil { + t.Fatalf("newDnstapMinimiser: %s", err) + } + t.Cleanup(func() { cleanupMQTTTestMinimiser(edm) }) + + ctx, cancel := context.WithCancel(t.Context()) + edm.autopahoCtx = ctx + + // Replace the default 1024-deep channel with a tiny pre-filled one so + // we can deterministically force the worker's send to block. + edm.mqttSignedCh = make(chan []byte, 1) + edm.mqttSignedCh <- []byte("placeholder") + + priv, _ := newTestMQTTJWK(t) + + var wg sync.WaitGroup + wg.Add(1) + go edm.mqttSignWorker(&wg, priv) + + // Hand the worker exactly one message; it will sign it and then block + // trying to enqueue on the (already full) signed channel. + edm.mqttPubCh <- []byte("payload") + + // Give the worker a moment to actually reach the blocked select. + // 50ms is generous on any real machine; we don't poll with a + // shorter, busier loop because we want to keep the test simple and + // the operation we're racing against is a cheap goroutine reaching + // a select. + time.Sleep(50 * time.Millisecond) + + cancel() + waitOrFail(t, &wg, 2*time.Second, "mqttSignWorker did not exit after context cancel") +} + +// TestMqttSignWorkerSkipsBadKey verifies the worker tolerates jws.Sign +// failures: a misconfigured/unsigned-eligible jwk causes the sign call to +// error, which the worker logs and then continues to the next message. +// We cannot easily construct a "broken" jwk.Key from an unsigned input, +// so we instead set a clearly mismatched algorithm on a valid Ed25519 key +// - jws.Sign with WithKey(, ) rejects the +// combination - and confirm: (a) the worker does not exit, (b) the next +// well-formed message after fixing the key still gets signed. +// +// Why this matters: the worker is a long-lived goroutine. If a transient +// signing error caused it to exit, every subsequent message would pile +// up unsigned in mqttPubCh (or be dropped on close) until the process +// restarted. The "continue past sign errors" behaviour was a deliberate +// design choice; this test pins it. +func TestMqttSignWorkerSkipsBadKey(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + edm, err := newDnstapMinimiser(logger, defaultTC) + if err != nil { + t.Fatalf("newDnstapMinimiser: %s", err) + } + t.Cleanup(func() { cleanupMQTTTestMinimiser(edm) }) + + ctx, cancel := context.WithCancel(t.Context()) + edm.autopahoCtx = ctx + t.Cleanup(cancel) + + priv, pub := newTestMQTTJWK(t) + + // Force a signing error by claiming the Ed25519 key uses RS256 - the + // jws library will refuse to sign. + if err := priv.Set(jwk.AlgorithmKey, jwa.RS256); err != nil { + t.Fatalf("set Algorithm: %s", err) + } + + var wg sync.WaitGroup + wg.Add(1) + go edm.mqttSignWorker(&wg, priv) + + // Push one "bad" message; the worker will fail to sign and continue. + edm.mqttPubCh <- []byte("bad-payload") + + // Drain attempt: there must be no signed message. We rely on a short + // wait because nothing else is feeding mqttSignedCh. + select { + case got := <-edm.mqttSignedCh: + t.Fatalf("mqttSignedCh unexpectedly received: %q", got) + case <-time.After(100 * time.Millisecond): + // expected: no signed output + } + + // Now flip the algorithm back to a valid one and push a real message. + // A correctly configured worker continues past the prior error and + // signs this one. + if err := priv.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil { + t.Fatalf("restore Algorithm: %s", err) + } + good := []byte(`{"ok":true}`) + edm.mqttPubCh <- good + + select { + case signed := <-edm.mqttSignedCh: + got, err := jws.Verify(signed, jws.WithKey(jwa.EdDSA, pub)) + if err != nil { + t.Fatalf("jws.Verify: %s", err) + } + if string(got) != string(good) { + t.Fatalf("payload mismatch have: %s want: %s", got, good) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for signed message after recovery") + } + + close(edm.mqttPubCh) + waitOrFail(t, &wg, 2*time.Second, "mqttSignWorker did not exit cleanly") +} + +type blockingMQTTConnectionManager struct { + publishStarted chan []byte + release chan struct{} + active atomic.Int32 + concurrent atomic.Bool +} + +func (cm *blockingMQTTConnectionManager) AwaitConnection(context.Context) error { + return nil +} + +func (cm *blockingMQTTConnectionManager) PublishViaQueue(context.Context, *autopaho.QueuePublish) error { + return nil +} + +func (cm *blockingMQTTConnectionManager) Publish(ctx context.Context, publish *paho.Publish) (*paho.PublishResponse, error) { + if cm.active.Add(1) > 1 { + cm.concurrent.Store(true) + } + defer cm.active.Add(-1) + + select { + case cm.publishStarted <- publish.Payload: + default: + } + + select { + case <-cm.release: + case <-ctx.Done(): + } + return nil, nil +} + +func TestMqttPublishWorkerPublishesSerially(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + edm, err := newDnstapMinimiser(logger, defaultTC) + if err != nil { + t.Fatalf("newDnstapMinimiser: %s", err) + } + t.Cleanup(func() { cleanupMQTTTestMinimiser(edm) }) + + ctx, cancel := context.WithCancel(t.Context()) + edm.autopahoCtx = ctx + t.Cleanup(cancel) + + edm.mqttSignedCh = make(chan []byte, 2) + cm := &blockingMQTTConnectionManager{ + publishStarted: make(chan []byte, 2), + release: make(chan struct{}), + } + + edm.autopahoWg.Add(1) + go edm.mqttPublishWorker(cm, "events/up/test/new_qname", false) + + edm.mqttSignedCh <- []byte("first") + select { + case got := <-cm.publishStarted: + if string(got) != "first" { + t.Fatalf("first publish payload have: %s, want: first", got) + } + case <-time.After(2 * time.Second): + t.Fatal("first publish did not start") + } + + edm.mqttSignedCh <- []byte("second") + select { + case got := <-cm.publishStarted: + t.Fatalf("second publish started before first publish completed: %s", got) + case <-time.After(100 * time.Millisecond): + } + if cm.concurrent.Load() { + t.Fatal("mqttPublishWorker called Publish concurrently") + } + + close(cm.release) + close(edm.mqttSignedCh) + waitOrFail(t, &edm.autopahoWg, 2*time.Second, "mqttPublishWorker did not drain and exit") +} + +// TestMqttPublishWorkerExitsOnContextCancel verifies that mqttPublishWorker +// exits when autopahoCtx is cancelled even if mqttSignedCh is empty. Without +// the fix, the goroutine blocks on the channel receive and can only exit when +// the channel is closed (the !ok path) or a message arrives, not on context +// cancellation. +func TestMqttPublishWorkerExitsOnContextCancel(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + edm, err := newDnstapMinimiser(logger, defaultTC) + if err != nil { + t.Fatalf("newDnstapMinimiser: %s", err) + } + t.Cleanup(func() { cleanupMQTTTestMinimiser(edm) }) + + ctx, cancel := context.WithCancel(t.Context()) + edm.autopahoCtx = ctx + + edm.mqttSignedCh = make(chan []byte) + + cm := &blockingMQTTConnectionManager{ + publishStarted: make(chan []byte, 1), + release: make(chan struct{}), + } + + edm.autopahoWg.Add(1) + go edm.mqttPublishWorker(cm, "events/up/test/new_qname", false) + + time.Sleep(50 * time.Millisecond) + + cancel() + waitOrFail(t, &edm.autopahoWg, 2*time.Second, "mqttPublishWorker did not exit after context cancel") +} + +// waitOrFail waits for wg with a deadline, calling t.Fatalf with the +// supplied message on timeout. Centralized so the timeout discipline is +// uniform across the MQTT worker tests. +func waitOrFail(t *testing.T, wg *sync.WaitGroup, d time.Duration, msg string) { + t.Helper() + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(d): + t.Fatal(msg) + } +} diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 17cb068..459c49a 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -90,6 +90,7 @@ type config struct { MQTTServer string `mapstructure:"mqtt-server" validate:"required_without=DisableMQTT"` MQTTCAFile string `mapstructure:"mqtt-ca-file"` MQTTKeepalive uint16 `mapstructure:"mqtt-keepalive" validate:"required_without=DisableMQTT"` + MQTTSignWorkers int `mapstructure:"mqtt-sign-workers"` QnameSeenEntries int `mapstructure:"qname-seen-entries"` CryptopanAddressEntries int `mapstructure:"cryptopan-address-entries" reload:"true"` NewQnameBuffer int `mapstructure:"newqname-buffer"` @@ -710,9 +711,12 @@ func (edm *dnstapMinimiser) setupMQTT() { os.Exit(1) } - // Connect to the broker - this will return immediately after initiating the connection process - edm.autopahoWg.Add(1) - go edm.runAutoPaho(autopahoCm, mqttJWK, mqttFileQueue != nil) + // Connect to the broker - this will return immediately after initiating the connection process. + signWorkers := conf.MQTTSignWorkers + if signWorkers <= 0 { + signWorkers = runtime.GOMAXPROCS(0) + } + edm.startMQTTPipeline(autopahoCm, mqttJWK, mqttFileQueue != nil, signWorkers) } func (edm *dnstapMinimiser) loadHTTPClientCert() error { @@ -1455,6 +1459,7 @@ type dnstapMinimiser struct { aggregSenderMutex sync.RWMutex aggregSender aggregateSender mqttPubCh chan []byte + mqttSignedCh chan []byte autopahoCtx context.Context autopahoCancel context.CancelFunc autopahoWg sync.WaitGroup @@ -1604,8 +1609,11 @@ func newDnstapMinimiser(logger *slog.Logger, edmConf edmConfiger) (*dnstapMinimi edm.httpClientCertStore = newCertStore() edm.mqttClientCertStore = newCertStore() - // Setup channel for reading messages to publish - edm.mqttPubCh = make(chan []byte, 100) + // Setup channels for the MQTT publish pipeline. mqttPubCh holds + // unsigned events from minimisers; mqttSignedCh holds signed + // envelopes ready for paho to publish. + edm.mqttPubCh = make(chan []byte, 1024) + edm.mqttSignedCh = make(chan []byte, 1024) // Setup channels for feeding writers and data senders that should do // their work outside the main minimiser loop. They are buffered to