From 50c655170715c4bfcbc7bd5a42e505b8be5e8add Mon Sep 17 00:00:00 2001 From: Jan Scheres Date: Sun, 8 Feb 2026 18:23:25 +0000 Subject: [PATCH] go and cpp support decoding multiple messages in single UDP packet --- cpp/cdc/CDC.cpp | 68 ++++++------- cpp/core/Protocol.hpp | 20 ++-- cpp/shard/Shard.cpp | 90 ++++++++--------- go/client/client.go | 218 ++++++++++++++++++++++++------------------ 4 files changed, 217 insertions(+), 179 deletions(-) diff --git a/cpp/cdc/CDC.cpp b/cpp/cdc/CDC.cpp index 241a1b5c..bc28c879 100644 --- a/cpp/cdc/CDC.cpp +++ b/cpp/cdc/CDC.cpp @@ -668,46 +668,48 @@ struct CDCServer : Loop { void _processCDCMessages() { int startUpdateSize = _updateSize(); for (auto& msg: _channel.protocolMessages(CDC_REQ_PROTOCOL_VERSION)) { - // First, try to parse the header - CDCReqMsg cdcMsg; - try { - cdcMsg.unpack(msg.buf); - } catch (const BincodeException& err) { - LOG_ERROR(_env, "could not parse: %s", err.what()); - RAISE_ALERT(_env, "could not parse request from %s, dropping it.", msg.clientAddr); - continue; - } + while (msg.buf.remaining()) { + // First, try to parse the header + CDCReqMsg cdcMsg; + try { + cdcMsg.unpack(msg.buf); + } catch (const BincodeException& err) { + LOG_ERROR(_env, "could not parse: %s", err.what()); + RAISE_ALERT(_env, "could not parse request from %s, dropping it.", msg.clientAddr); + break; + } - LOG_DEBUG(_env, "received request id %s, kind %s", cdcMsg.id, cdcMsg.body.kind()); - auto receivedAt = ternNow(); + LOG_DEBUG(_env, "received request id %s, kind %s", cdcMsg.id, cdcMsg.body.kind()); + auto receivedAt = ternNow(); - if (unlikely(cdcMsg.body.kind() == CDCMessageKind::CDC_SNAPSHOT)) { - _processCDCSnapshotMessage(cdcMsg, msg); - continue; - } + if (unlikely(cdcMsg.body.kind() == CDCMessageKind::CDC_SNAPSHOT)) { + _processCDCSnapshotMessage(cdcMsg, msg); + continue; + } - // If we're already processing this request, drop it to try to not clog the queue - if (_inFlightCDCReqs.contains(InFlightCDCRequestKey(cdcMsg.id, msg.clientAddr))) { - LOG_DEBUG(_env, "dropping req id %s from %s since it's already being processed", cdcMsg.id, msg.clientAddr); - continue; - } + // If we're already processing this request, drop it to try to not clog the queue + if (_inFlightCDCReqs.contains(InFlightCDCRequestKey(cdcMsg.id, msg.clientAddr))) { + LOG_DEBUG(_env, "dropping req id %s from %s since it's already being processed", cdcMsg.id, msg.clientAddr); + continue; + } - if (unlikely(_shared.isLeader.load(std::memory_order_relaxed) == false)) { - LOG_DEBUG(_env, "dropping request since we're not the leader %s", cdcMsg); - continue; - } + if (unlikely(_shared.isLeader.load(std::memory_order_relaxed) == false)) { + LOG_DEBUG(_env, "dropping request since we're not the leader %s", cdcMsg); + continue; + } - auto& cdcReq = _cdcReqs.emplace_back(std::move(cdcMsg.body)); + auto& cdcReq = _cdcReqs.emplace_back(std::move(cdcMsg.body)); - _inFlightCDCReqs.insert(InFlightCDCRequestKey(cdcMsg.id, msg.clientAddr)); + _inFlightCDCReqs.insert(InFlightCDCRequestKey(cdcMsg.id, msg.clientAddr)); - LOG_DEBUG(_env, "CDC request %s successfully parsed, will process soon", cdcReq.kind()); - _cdcReqsInfo.emplace_back(CDCReqInfo{ - .reqId = cdcMsg.id, - .clientAddr = msg.clientAddr, - .receivedAt = receivedAt, - .sockIx = msg.socketIx, - }); + LOG_DEBUG(_env, "CDC request %s successfully parsed, will process soon", cdcReq.kind()); + _cdcReqsInfo.emplace_back(CDCReqInfo{ + .reqId = cdcMsg.id, + .clientAddr = msg.clientAddr, + .receivedAt = receivedAt, + .sockIx = msg.socketIx, + }); + } } } diff --git a/cpp/core/Protocol.hpp b/cpp/core/Protocol.hpp index 95696ac9..12da4579 100644 --- a/cpp/core/Protocol.hpp +++ b/cpp/core/Protocol.hpp @@ -93,13 +93,13 @@ inline std::ostream& operator<<(std::ostream& out, const ShardCheckPointedResp& return out << "checkPointIdx: " << resp.checkPointIdx << " resp: " << resp.resp; } -using ShardReqMsg = ProtocolMessage; -using ShardRespMsg = ProtocolMessage; -using CdcToShardReqMsg = SignedProtocolMessage; -using CdcToShardRespMsg = SignedProtocolMessage; -using CDCReqMsg = ProtocolMessage; -using CDCRespMsg = ProtocolMessage; -using LogReqMsg = SignedProtocolMessage; -using LogRespMsg = SignedProtocolMessage; -using ProxyShardReqMsg = SignedProtocolMessage; -using ProxyShardRespMsg = SignedProtocolMessage; +using ShardReqMsg = ProtocolMessage; +using ShardRespMsg = ProtocolMessage; +using CdcToShardReqMsg = SignedProtocolMessage; +using CdcToShardRespMsg = SignedProtocolMessage; +using CDCReqMsg = ProtocolMessage; +using CDCRespMsg = ProtocolMessage; +using LogReqMsg = SignedProtocolMessage; +using LogRespMsg = SignedProtocolMessage; +using ProxyShardReqMsg = SignedProtocolMessage; +using ProxyShardRespMsg = SignedProtocolMessage; diff --git a/cpp/shard/Shard.cpp b/cpp/shard/Shard.cpp index 1a5e5852..b4df0072 100644 --- a/cpp/shard/Shard.cpp +++ b/cpp/shard/Shard.cpp @@ -472,61 +472,63 @@ struct ShardServer : Loop { return; } - ShardReqMsg req; - try { - switch (protocol) { - case CDC_TO_SHARD_REQ_PROTOCOL_VERSION: - { - CdcToShardReqMsg signedReq; - signedReq.unpack(msg.buf, _expandedCDCKey); - req.id = signedReq.id; - req.body = std::move(signedReq.body); - } - break; - case SHARD_REQ_PROTOCOL_VERSION: - req.unpack(msg.buf); - if (isPrivilegedRequestKind((uint8_t)req.body.kind())) { - LOG_ERROR(_env, "Received unauthenticated request %s from %s", req.body.kind(), msg.clientAddr); - return; - } - break; - case PROXY_SHARD_REQ_PROTOCOL_VERSION: - { - ProxyShardReqMsg signedReq; - signedReq.unpack(msg.buf, _expandedShardKey); - req.id = signedReq.id; - req.body = std::move(signedReq.body); + while (msg.buf.remaining()) { + ShardReqMsg req; + try { + switch (protocol) { + case CDC_TO_SHARD_REQ_PROTOCOL_VERSION: + { + CdcToShardReqMsg signedReq; + signedReq.unpack(msg.buf, _expandedCDCKey); + req.id = signedReq.id; + req.body = std::move(signedReq.body); + } + break; + case SHARD_REQ_PROTOCOL_VERSION: + req.unpack(msg.buf); + if (isPrivilegedRequestKind((uint8_t)req.body.kind())) { + LOG_ERROR(_env, "Received unauthenticated request %s from %s", req.body.kind(), msg.clientAddr); + return; + } + break; + case PROXY_SHARD_REQ_PROTOCOL_VERSION: + { + ProxyShardReqMsg signedReq; + signedReq.unpack(msg.buf, _expandedShardKey); + req.id = signedReq.id; + req.body = std::move(signedReq.body); + } + break; + default: + ALWAYS_ASSERT(false, "Unknown protocol version"); } - break; - default: - ALWAYS_ASSERT(false, "Unknown protocol version"); + } catch (const BincodeException& err) { + LOG_ERROR(_env, "Could not parse: %s", err.what()); + RAISE_ALERT(_env, "could not parse request from %s, dropping it.", msg.clientAddr); + return; } - } catch (const BincodeException& err) { - LOG_ERROR(_env, "Could not parse: %s", err.what()); - RAISE_ALERT(_env, "could not parse request from %s, dropping it.", msg.clientAddr); - return; - } - auto t0 = ternNow(); + auto t0 = ternNow(); - LOG_DEBUG(_env, "received request id %s, kind %s, from %s", req.id, req.body.kind(), msg.clientAddr); + LOG_DEBUG(_env, "received request id %s, kind %s, from %s", req.id, req.body.kind(), msg.clientAddr); - if (bigRequest(req.body.kind())) { + if (bigRequest(req.body.kind())) { if (unlikely(_env._shouldLog(LogLevel::LOG_TRACE))) { LOG_TRACE(_env, "parsed request: %s", req); } else { LOG_DEBUG(_env, "parsed request: "); } - } else { - LOG_DEBUG(_env, "parsed request: %s", req); - } + } else { + LOG_DEBUG(_env, "parsed request: %s", req); + } - auto& entry = _requestNeedsConsistency(req.body.kind(), protocol) ? _writeReqs.emplace_back() : _readRequests.emplace_back(); - entry.sockIx = msg.socketIx; - entry.clientAddr = msg.clientAddr; - entry.receivedAt = t0; - entry.protocol = protocol; - entry.msg = std::move(req); + auto& entry = _requestNeedsConsistency(req.body.kind(), protocol) ? _writeReqs.emplace_back() : _readRequests.emplace_back(); + entry.sockIx = msg.socketIx; + entry.clientAddr = msg.clientAddr; + entry.receivedAt = t0; + entry.protocol = protocol; + entry.msg = std::move(req); + } } // All write requests fall into this category. Some read requests issues by CDC also need cross regional consistency diff --git a/go/client/client.go b/go/client/client.go index 8ab87977..18af8334 100644 --- a/go/client/client.go +++ b/go/client/client.go @@ -226,6 +226,7 @@ type rawMetadataResponse struct { kind uint8 respLen int buf *[]byte // the buf contains the header + offset int } type clientMetadata struct { @@ -368,115 +369,131 @@ func (cm *clientMetadata) processRequests(log *log.Logger) { log.Debug("got close request in request processor, winding down") } -func (cm *clientMetadata) parseResponse(log *log.Logger, req *metadataProcessorRequest, rawResp *rawMetadataResponse, dischargeBuf bool) { +func (cm *clientMetadata) handlePacket(log *log.Logger, rawResp *rawMetadataResponse) { + releaseBuf := true defer func() { - if !dischargeBuf { - return - } - select { - case cm.responsesBufs <- rawResp.buf: - default: - panic(fmt.Errorf("impossible: could not put back response buffer which we got from socket drainer")) + if releaseBuf && rawResp.buf != nil { + select { + case cm.responsesBufs <- rawResp.buf: + default: + panic(fmt.Errorf("impossible: could not put back response buffer which we got from socket drainer")) + } } }() - // check protocol - if req.shard < 0 { // CDC - if rawResp.protocol != msgs.CDC_RESP_PROTOCOL_VERSION { - log.RaiseAlert("got bad cdc protocol %v for request id %v, ignoring", rawResp.protocol, req.requestId) - return - } - } else { - if rawResp.protocol != msgs.SHARD_RESP_PROTOCOL_VERSION { - log.RaiseAlert("got bad shard protocol %v for request id %v, shard %v, ignoring", rawResp.protocol, req.shard, req.requestId) - return - } - } - // remove everywhere - delete(cm.earlyRequests, req.requestId) - if _, found := cm.requestsById[req.requestId]; found { - delete(cm.requestsById, req.requestId) - heap.Remove(&cm.requestsByTimeout, req.index) + + buf := *rawResp.buf + + if rawResp.offset == 0 { + rawResp.offset = 4 } - if rawResp.kind == msgs.ERROR { - var err error - if rawResp.respLen != 4+8+1+2 { - log.RaiseAlert("bad error response length %v, expected %v", rawResp.respLen, 4+8+1+2) - err = msgs.MALFORMED_RESPONSE - } else { - err = msgs.TernError(binary.LittleEndian.Uint16((*rawResp.buf)[4+8+1:])) - } - req.respCh <- &metadataProcessorResponse{ - requestId: req.requestId, - err: err, - extra: req.extra, - resp: nil, - } - } else { - // check kind - if req.shard < 0 { // CDC - expectedKind := req.req.(msgs.CDCRequest).CDCRequestKind() - if uint8(expectedKind) != rawResp.kind { - log.RaiseAlert("got bad cdc kind %v for request id %v, expected %v", msgs.CDCMessageKind(rawResp.kind), req.requestId, expectedKind) + + reader := bytes.NewReader(buf[rawResp.offset:rawResp.respLen]) + + for reader.Len() > 0 { + if reader.Len() < 9 { + log.RaiseAlert("got runt metadata message, expected at least %v bytes, got %v", 8+1, reader.Len()) + break + } + + var reqId uint64 + binary.Read(reader, binary.LittleEndian, &reqId) + var kind uint8 + binary.Read(reader, binary.LittleEndian, &kind) + + req, ok := cm.requestsById[reqId] + if !ok { + rawResp.offset = rawResp.respLen - reader.Len() - 9 + + cm.earlyRequests[reqId] = *rawResp + + releaseBuf = false + return + } + + if req.shard < 0 { // CDC + if rawResp.protocol != msgs.CDC_RESP_PROTOCOL_VERSION { + log.RaiseAlert("got bad cdc protocol %v for request id %v, ignoring", rawResp.protocol, req.requestId) + break + } + } else { + if rawResp.protocol != msgs.SHARD_RESP_PROTOCOL_VERSION { + log.RaiseAlert("got bad shard protocol %v for request id %v, shard %v, ignoring", rawResp.protocol, req.shard, req.requestId) + break + } + } + + delete(cm.requestsById, req.requestId) + heap.Remove(&cm.requestsByTimeout, req.index) + + if kind == msgs.ERROR { + var errCode uint16 + var err error + if binary.Read(reader, binary.LittleEndian, &errCode) != nil { + log.RaiseAlert("bad error response length for request id %v", reqId) + err = msgs.MALFORMED_RESPONSE + } else { + err = msgs.TernError(errCode) + } req.respCh <- &metadataProcessorResponse{ - requestId: req.requestId, - err: msgs.MALFORMED_RESPONSE, + requestId: reqId, + err: err, extra: req.extra, resp: nil, } - return - } - } else { - expectedKind := req.req.(msgs.ShardRequest).ShardRequestKind() - if uint8(expectedKind) != rawResp.kind { - log.RaiseAlert("got bad shard kind %v for request id %v, shard %v, expected %v", msgs.ShardMessageKind(rawResp.kind), req.requestId, req.shard, expectedKind) + } else { + // check kind + if req.shard < 0 { // CDC + expectedKind := req.req.(msgs.CDCRequest).CDCRequestKind() + if uint8(expectedKind) != kind { + log.RaiseAlert("got bad cdc kind %v for request id %v, expected %v", msgs.CDCMessageKind(rawResp.kind), req.requestId, expectedKind) + req.respCh <- &metadataProcessorResponse{ + requestId: reqId, + err: msgs.MALFORMED_RESPONSE, + extra: req.extra, + resp: nil, + } + break + } + } else { + expectedKind := req.req.(msgs.ShardRequest).ShardRequestKind() + if uint8(expectedKind) != kind { + log.RaiseAlert("got bad shard kind %v for request id %v, shard %v, expected %v", msgs.ShardMessageKind(rawResp.kind), req.requestId, req.shard, expectedKind) + req.respCh <- &metadataProcessorResponse{ + requestId: reqId, + err: msgs.MALFORMED_RESPONSE, + extra: req.extra, + resp: nil, + } + break + } + } + + // unpack + if err := req.resp.Unpack(reader); err != nil { + log.RaiseAlert("could not unpack resp %T for request id %v, shard %v: %v", req.resp, req.requestId, req.shard, err) + req.respCh <- &metadataProcessorResponse{ + requestId: req.requestId, + err: err, + extra: req.extra, + resp: nil, + } + return + } + log.Debug("received resp %T %v req id %v from shard %v", req.resp, req.resp, req.requestId, req.shard) + // done req.respCh <- &metadataProcessorResponse{ requestId: req.requestId, - err: msgs.MALFORMED_RESPONSE, + err: nil, extra: req.extra, - resp: nil, + resp: req.resp, } - return - } - } - // unpack - if err := bincode.Unpack((*rawResp.buf)[4+8+1:rawResp.respLen], req.resp); err != nil { - log.RaiseAlert("could not unpack resp %T for request id %v, shard %v: %v", req.resp, req.requestId, req.shard, err) - req.respCh <- &metadataProcessorResponse{ - requestId: req.requestId, - err: err, - extra: req.extra, - resp: nil, } - return - } - log.Debug("received resp %T %v req id %v from shard %v", req.resp, req.resp, req.requestId, req.shard) - // done - req.respCh <- &metadataProcessorResponse{ - requestId: req.requestId, - err: nil, - extra: req.extra, - resp: req.resp, - } } } func (cm *clientMetadata) processRawResponse(log *log.Logger, rawResp *rawMetadataResponse) { if rawResp.buf != nil { - if req, found := cm.requestsById[rawResp.requestId]; found { - // Common case, the request is already there - cm.parseResponse(log, req, rawResp, true) - } else { - // Uncommon case, the request is missing. In this (rare) case - // we still discharge the buffer immediately so that it's already - // available for use - buf := make([]byte, clientMtu) - select { - case cm.responsesBufs <- &buf: - default: - panic(fmt.Errorf("impossible: could not return buffer")) - } - cm.earlyRequests[rawResp.requestId] = *rawResp - } + cm.handlePacket(log, rawResp) } now := time.Now() // expire requests @@ -507,6 +524,13 @@ func (cm *clientMetadata) processRawResponse(log *log.Logger, rawResp *rawMetada for reqId, rawReq := range cm.earlyRequests { if now.Sub(rawReq.receivedAt) > 10*time.Minute { delete(cm.earlyRequests, reqId) + + // release buffer + select { + case cm.responsesBufs <- rawReq.buf: + default: + panic(fmt.Errorf("impossible: could not put back response buffer which we got from socket drainer")) + } } } } @@ -531,7 +555,16 @@ func (cm *clientMetadata) processResponses(log *log.Logger) { case req := <-cm.inFlight: if rawResp, found := cm.earlyRequests[req.requestId]; found { // uncommon case: we have a response for this already. - cm.parseResponse(log, req, &rawResp, false) + // remove from earlyRequests, handlePacket will take ownership + delete(cm.earlyRequests, req.requestId) + + if _, found := cm.requestsById[req.requestId]; found { + heap.Remove(&cm.requestsByTimeout, req.index) + } + cm.requestsById[req.requestId] = req + heap.Push(&cm.requestsByTimeout, req) + + cm.handlePacket(log, &rawResp) } else { // common case: we don't have the response yet, put it in the data structures and wait. // if the request was there before, we remove it from the heap so that we don't have @@ -584,6 +617,7 @@ func (cm *clientMetadata) drainSocket(log *log.Logger) { protocol: binary.LittleEndian.Uint32(*buf), requestId: binary.LittleEndian.Uint64((*buf)[4:]), kind: (*buf)[4+8], + offset: 4, } cm.rawResponses <- rawResp }