diff --git a/.gitignore b/.gitignore index 16257c2..a3f4a9b 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,10 @@ poc/ # Dev .vscode +openspec/ +.cache +.codex +openspec/ +docs/roadmap.md +docs/engineering-guidelines.md + diff --git a/config/config.yaml b/config/config.yaml index 222deb7..69c64ee 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -22,5 +22,9 @@ producers: auth: auth channel: test -conn_timeout: 45 -max_tcp_payload: 4096 +conn_timeout: 45 # idle I/O timeout in seconds for established connections. +max_tcp_payload: 4096 # bytes +dial_timeout: 5 # timeout in seconds for proxy target connection. + +capture_traffic: + enabled: false \ No newline at end of file diff --git a/config/rules.yaml b/config/rules.yaml index 6b5bb4c..d991241 100644 --- a/config/rules.yaml +++ b/config/rules.yaml @@ -35,6 +35,12 @@ rules: - match: tcp dst port 27017 type: conn_handler target: mongodb + - match: tcp dst port 9889 + type: proxy_tcp + target: 127.0.0.1:9889 + - match: tcp dst port 3306 + type: proxy_tcp + target: 127.0.0.1:3306 - match: tcp type: conn_handler target: tcp diff --git a/docs/configuration.md b/docs/configuration.md index 7d036a1..c5d711c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -13,8 +13,10 @@ This file holds the core settings for Glutton. Key configuration options include - **udp:** The UDP port for intercepted packets (default: `5001`). - **ssh:** Typically excluded from redirection to avoid interfering with SSH (default: `22`). - **interface:** The network interface Glutton listens on (default: `eth0`). -- **max_tcp_payload:** Maximum TCP payload size in bytes (default: `4096`). -- **conn_timeout:** The connection timeout duration in seconds (default: `45`). +- **conn_timeout:** Idle I/O timeout, in seconds, for established connections (default: `45`). +- **max_tcp_payload:** Maximum TCP payload size in bytes (default: `4096`). Proxy TCP uses this as the per-direction captured payload cap. +- **dial_timeout:** Timeout, in seconds, for opening outbound proxy TCP target connections (default: `5`). +- **capture_traffic.enabled:** Enables raw payload capture in logs and produced decoded events. When disabled, proxy TCP still forwards traffic and logs metadata, but raw payload bytes are omitted from decoded events. - **confpath:** The directory path where the configuration file resides. - **producers:** - **enabled**: Boolean flag to enable or disable logging/producer functionality. @@ -55,6 +57,10 @@ producers: conn_timeout: 45 max_tcp_payload: 4096 +dial_timeout: 5 + +capture_traffic: + enabled: false ``` ### config/rules.yaml @@ -63,8 +69,8 @@ This file defines the rules that Glutton uses to determine which protocol handle Key elements include: -- **type**: `conn_handler` to pass off to the appropriate protocol handler or `drop` to ignore packets. -- **target**: Indicates the protocol handler (e.g., "http", "ftp") to be used. +- **type**: `conn_handler` to pass off to the appropriate protocol handler, `proxy_tcp` to forward the TCP connection to an upstream target, or `drop` to ignore packets. +- **target**: For `conn_handler`, indicates the protocol handler (e.g., `http`, `ftp`) to use. For `proxy_tcp`, this must be the upstream target in `host:port` form. - **match**: Define criteria such as source IP ranges or destination ports to match incoming traffic, according to [BPF syntax](https://biot.com/capstats/bpf.html). Example rule: @@ -80,8 +86,14 @@ rules: - match: tcp dst port 6969 type: drop # drops any matching packets target: bittorrent + - name: Proxy TCP example + match: tcp dst port 9889 + type: proxy_tcp + target: 127.0.0.1:9889 ``` +`proxy_tcp` dials the configured `target` and forwards bytes in both directions between the incoming connection and the upstream service. Produced decoded events use the `proxy_tcp` protocol name and can include one captured payload entry per direction. Captured payloads are capped by `max_tcp_payload`; when a direction transfers more bytes than the cap, the decoded event is marked as truncated. + ## Configuration Loading Process Glutton uses the [Viper](https://github.com/spf13/viper) library to load configuration settings. The process works as follows: diff --git a/glutton.go b/glutton.go index c6ef637..74edea3 100644 --- a/glutton.go +++ b/glutton.go @@ -222,10 +222,18 @@ func (g *Glutton) tcpListen() { g.Logger.Error("Failed to set connection timeout", producer.ErrAttr(err)) } - if hfunc, ok := g.tcpProtocolHandlers[rule.Target]; ok { + var handlerName string + switch rule.Type { + case "proxy_tcp": + handlerName = rule.Type + default: + handlerName = rule.Target + } + + if hfunc, ok := g.tcpProtocolHandlers[handlerName]; ok { go func() { if err := hfunc(g.ctx, conn, md); err != nil { - g.Logger.Error("Failed to handle TCP connection", producer.ErrAttr(err), slog.String("handler", rule.Target)) + g.Logger.Error("Failed to handle ", producer.ErrAttr(err), slog.String("handler", handlerName)) } }() } diff --git a/mkdocs.yml b/mkdocs.yml index ecb8baf..c7b691f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -8,4 +8,3 @@ nav: - FAQs: faq.md theme: name: readthedocs - diff --git a/protocols/protocols.go b/protocols/protocols.go index d43b3ca..e9a8b2a 100644 --- a/protocols/protocols.go +++ b/protocols/protocols.go @@ -72,6 +72,9 @@ func MapTCPProtocolHandlers(log interfaces.Logger, h interfaces.Honeypot) map[st protocolHandlers["mongodb"] = func(ctx context.Context, conn net.Conn, md connection.Metadata) error { return tcp.HandleMongoDB(ctx, conn, md, log, h) } + protocolHandlers["proxy_tcp"] = func(ctx context.Context, conn net.Conn, md connection.Metadata) error { + return tcp.HandleProxyTCP(ctx, conn, md, log, h) + } protocolHandlers["tcp"] = func(ctx context.Context, conn net.Conn, md connection.Metadata) error { snip, bufConn, err := Peek(conn, 4) if err != nil { diff --git a/protocols/tcp/proxy_tcp.go b/protocols/tcp/proxy_tcp.go new file mode 100644 index 0000000..50e40da --- /dev/null +++ b/protocols/tcp/proxy_tcp.go @@ -0,0 +1,524 @@ +package tcp + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "log/slog" + "net" + "syscall" + "time" + + "github.com/mushorg/glutton/connection" + "github.com/mushorg/glutton/producer" + "github.com/mushorg/glutton/protocols/interfaces" + "github.com/spf13/viper" +) + +const handlerName = "proxy_tcp" + +// a proxy connection event sent to producers +type event struct { + Direction string `json:"direction,omitempty"` + Payload []byte `json:"payload,omitempty"` + PayloadHash string `json:"payload_hash,omitempty"` // Used for easier identification, can remove + Bytes int64 `json:"bytes,omitempty"` + Truncated bool `json:"truncated,omitempty"` +} + +// holds a proxy connection metadata +type session struct { + source string + target string + producer bool + idleTimeout time.Duration + payloadSize int +} + +// reader wrapps the source connection for idle deadline. +type reader struct { + conn net.Conn + idle time.Duration + name string +} + +func (r reader) Read(p []byte) (int, error) { + if r.idle > 0 { + if err := r.conn.SetReadDeadline(time.Now().Add(r.idle)); err != nil { + return 0, fmt.Errorf("%s set read deadline: %w", r.name, err) + } + } + + n, err := r.conn.Read(p) + // EOF is the normal way a stream says there are no more bytes to read. + if err != nil && !errors.Is(err, io.EOF) { + return n, fmt.Errorf("%s read: %w", r.name, err) + } + return n, err +} + +// writer wraps the destination connection for logging and idle deadline +type writer struct { + conn net.Conn + session *session + logger interfaces.Logger + dir string + name string + + written int64 + payload []byte +} + +func (w *writer) Write(p []byte) (int, error) { + if w.session.idleTimeout > 0 { + if err := w.conn.SetWriteDeadline(time.Now().Add(w.session.idleTimeout)); err != nil { + return 0, fmt.Errorf("%s set write deadline: %w", w.name, err) + } + } + + n, err := w.conn.Write(p) + if err != nil { + w.logger.Debug("proxy writer returned error", logAttrs( + slog.String("function", "writer.Write"), + slog.String("direction", w.dir), + slog.Int("bytes_written", n), + producer.ErrAttr(err), + )...) + return n, fmt.Errorf("%s write: %w", w.name, err) + } + + // A Write can partially succeed. Only p[:n] reached the destination, so only + // those bytes should affect logs, byte counts, and capture samples. + if n > 0 { + written := p[:n] + w.written += int64(n) + w.session.logPayload(w.dir, written, w.logger) + w.storePayload(written) + } + + if n != len(p) { + return n, fmt.Errorf("%s short write: %w", w.name, io.ErrShortWrite) + } + + return n, nil +} + +// emits a structured log for connection metadata including raw payload +func (s *session) logPayload(direction string, data []byte, logger interfaces.Logger) { + if len(data) == 0 { + return + } + + // always log transfer metadata excluding raw service data. + fields := logAttrs( + slog.String("direction", direction), + slog.Int("length", len(data)), + slog.String("payload_hash", fmt.Sprintf("%x", sha256.Sum256(data))), + ) + + // when caputure enabled then include raw payload data to logs + if s.producer && s.payloadSize > 0 { + sample := data + truncated := false + if len(sample) > s.payloadSize { + sample = sample[:s.payloadSize] + truncated = true + } + + if isLikelyText(sample) { + fields = append(fields, slog.String("payload", string(sample))) + } else { + fields = append(fields, slog.String("hex", hex.EncodeToString(sample))) + } + fields = append(fields, slog.Bool("payload_truncated", truncated)) + } + + logger.Info("proxy_tcp payload_transferred", fields...) +} + +// storePayload stores a bounded sample of written bytes for producer output. +func (w *writer) storePayload(p []byte) { + if !w.session.producer || w.session.payloadSize <= 0 { + return + } + + if len(w.payload) >= w.session.payloadSize { + return + } + + remaining := w.session.payloadSize - len(w.payload) + if len(p) > remaining { + p = p[:remaining] + } + w.payload = append(w.payload, p...) +} + +// event converts the stored payload writer-owned capture sample into a producer event. +func (w *writer) event() *event { + if !w.session.producer || len(w.payload) == 0 { + return nil + } + + payload := append([]byte(nil), w.payload...) + hash := sha256.Sum256(payload) + return &event{ + Direction: w.dir, + Payload: payload, + PayloadHash: fmt.Sprintf("%x", hash[:]), + Bytes: w.written, + Truncated: w.written > int64(len(payload)), + } +} + +// logAttrs adds common structured log fields to every proxy_tcp log. +func logAttrs(fields ...any) []any { + base := []any{ + slog.String("handler", handlerName), + } + return append(base, fields...) +} + +// isLikelyText checks whether a byte slice is mostly printable text. +func isLikelyText(data []byte) bool { + if len(data) == 0 { + return false + } + + printable := 0 + for _, b := range data { + if b >= 32 && b <= 126 || b == '\n' || b == '\r' || b == '\t' { + printable++ + } + } + + // Require more than 80% printable bytes so mostly-binary payloads stay as hex. + return (printable*100)/len(data) > 80 +} + +// pipeResult is the completion report from one directional pipe. +type pipeResult struct { + dir string + bytes int64 + event *event + err error +} + +// pipe copies bytes in one direction and sends a pipeResult when that direction ends. +func pipe(done chan<- pipeResult, dst, src net.Conn, session *session, logger interfaces.Logger) { + dir := getDirection(src, dst) + writer := &writer{ + conn: dst, + session: session, + logger: logger, + dir: dir, + name: dir + " dst", + } + + reader := reader{ + conn: src, + idle: session.idleTimeout, + name: dir + " src", + } + + _, err := io.Copy(writer, reader) + + // Tell the destination peer there will be no more bytes from this direction. + if closeErr := finishWriteSide(dst, logger, dir); closeErr != nil && err == nil { + err = fmt.Errorf("%s close write: %w", dir, closeErr) + } + + // Stop reading from the source side after this direction has completed. + if closeErr := finishReadSide(src, logger, dir); closeErr != nil && err == nil { + err = fmt.Errorf("%s close read: %w", dir, closeErr) + } + + done <- pipeResult{ + dir: dir, + bytes: writer.written, + event: writer.event(), + err: err, + } +} + +// pipeBothWays starts proxy connection between client and target +func pipeBothWays(client, target net.Conn, session *session, logger interfaces.Logger) []pipeResult { + logger.Debug("starting proxy bidirectional copy", logAttrs( + slog.String("function", "pipeBothWays"), + slog.String("source", session.source), + slog.String("target", session.target), + )...) + + done := make(chan pipeResult, 2) + go pipe(done, target, client, session, logger) + go pipe(done, client, target, session, logger) + + // Wait for both directions before returning + return []pipeResult{<-done, <-done} +} + +// eventsFromResults extracts captured producer events from completed pipe results. +func eventsFromResults(results []pipeResult) []event { + events := make([]event, 0, len(results)) + for _, result := range results { + if result.event != nil { + events = append(events, *result.event) + } + } + return events +} + +func getDirection(src, dst net.Conn) string { + return fmt.Sprintf("%s -> %s", src.RemoteAddr().String(), dst.RemoteAddr().String()) +} + +// logResult records the outcome from one directional pipe. +func logResult(logger interfaces.Logger, result pipeResult) { + fields := logAttrs( + slog.String("function", "logResult"), + slog.String("direction", result.dir), + slog.Int64("bytes", result.bytes), + ) + + // Include capture metadata without logging raw bytes again. + if result.event != nil { + fields = append(fields, + slog.Int("captured_bytes", len(result.event.Payload)), + slog.String("payload_hash", result.event.PayloadHash), + slog.Bool("truncated", result.event.Truncated), + ) + } + + // Attach the copy error before deciding the log level. + if result.err != nil { + fields = append(fields, producer.ErrAttr(result.err)) + } + + if expectedPipeError(result.err) { + logger.Debug("proxy pipe completed", fields...) + } else { + logger.Error("proxy pipe failed", fields...) + } +} + +// expectedPipeError classifies normal network shutdown results from io.Copy. +func expectedPipeError(err error) bool { + if err == nil { + return true + } + + // EOF and "use of closed network connection" are expected when scanners or + // targets close sockets during normal request/response exchanges. + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.ENOTCONN) { + return true + } + + // Deadline timeouts can be created by reader or shutdown fallbacks + var nerr net.Error + return errors.As(err, &nerr) && nerr.Timeout() +} + +type closeWriter interface { + CloseWrite() error +} + +type closeReader interface { + CloseRead() error +} + +// Use for used for half-closing a write side, if unavailable, it uses a deadline +func finishWriteSide(conn net.Conn, logger interfaces.Logger, dir string) error { + if cw, ok := conn.(closeWriter); ok { + logger.Debug("closing proxy write side", logAttrs( + slog.String("function", "finishWriteSide"), + slog.String("direction", dir), + slog.String("address", conn.RemoteAddr().String()), + )...) + if err := cw.CloseWrite(); err != nil && !errors.Is(err, net.ErrClosed) { + return err + } + return nil + } + + // Non-TCP net.Conn implementations may not support half-close. + logger.Debug("setting proxy write shutdown deadline", logAttrs( + slog.String("function", "finishWriteSide"), + slog.String("direction", dir), + slog.String("address", conn.RemoteAddr().String()), + )...) + return conn.SetDeadline(time.Now().Add(2 * time.Second)) +} + +// Use for used for half-closing a read side, if unavailable, it uses a deadline +func finishReadSide(conn net.Conn, logger interfaces.Logger, dir string) error { + if cr, ok := conn.(closeReader); ok { + logger.Debug("closing proxy read side", logAttrs( + slog.String("function", "finishReadSide"), + slog.String("direction", dir), + slog.String("address", conn.RemoteAddr().String()), + )...) + if err := cr.CloseRead(); err != nil && !errors.Is(err, net.ErrClosed) { + return err + } + return nil + } + + // Non-TCP net.Conn implementations may not support CloseRead. + logger.Debug("setting proxy read shutdown deadline", logAttrs( + slog.String("function", "finishReadSide"), + slog.String("direction", dir), + slog.String("address", conn.RemoteAddr().String()), + )...) + return conn.SetReadDeadline(time.Now().Add(2 * time.Second)) +} + +// It is best-effort to enables TCP keepalive for real TCP connections +func setKeepAlive(conn net.Conn, logger interfaces.Logger, name string) { + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return + } + if err := tcpConn.SetKeepAlive(true); err != nil { + logger.Debug("failed to enable proxy keepalive", logAttrs( + slog.String("function", "setKeepAlive"), + slog.String("name", name), + producer.ErrAttr(err), + )...) + } +} + +func stopProxyOnCancel(ctx context.Context, client, target net.Conn, logger interfaces.Logger) func() { + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + logger.Debug("closing proxy connections after context cancellation", logAttrs( + slog.String("function", "stopProxyOnCancel"), + producer.ErrAttr(ctx.Err()), + )...) + closeProxyConn(client, logger, "client") + closeProxyConn(target, logger, "target") + case <-done: + } + }() + + return func() { + close(done) + } +} + +func closeProxyConn(conn net.Conn, logger interfaces.Logger, name string) { + if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + logger.Debug("failed to close proxy connection", logAttrs( + slog.String("function", "closeProxyConn"), + slog.String("name", name), + producer.ErrAttr(err), + )...) + } +} + +func HandleProxyTCP(ctx context.Context, conn net.Conn, md connection.Metadata, logger interfaces.Logger, h interfaces.Honeypot) error { + srcAddr := conn.RemoteAddr().String() + + logger.Debug("entered proxy handler", logAttrs( + slog.String("function", "HandleProxyTCP"), + slog.String("source", srcAddr), + )...) + defer logger.Debug("leaving proxy handler", logAttrs( + slog.String("function", "HandleProxyTCP"), + slog.String("source", srcAddr), + )...) + + defer func() { + if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + logger.Error("failed to close incoming connection", logAttrs( + slog.String("function", "HandleProxyTCP"), + producer.ErrAttr(err), + )...) + } + }() + + // If missing metadata, close without panick + if md.Rule == nil { + logger.Error("missing proxy_tcp rule metadata", logAttrs( + slog.String("function", "HandleProxyTCP"), + )...) + return nil + } + + // If it is missing here, the handler cannot safely dial anything + if md.Rule.ProxyTarget == nil || md.Rule.ProxyTarget.DialAddress == "" { + logger.Error("missing proxy_tcp target metadata", logAttrs( + slog.String("function", "HandleProxyTCP"), + )...) + return nil + } + destAddr := md.Rule.ProxyTarget.DialAddress + + session := &session{ + source: srcAddr, + target: destAddr, + producer: viper.GetBool("capture_traffic.enabled"), + idleTimeout: time.Duration(viper.GetInt("conn_timeout")) * time.Second, + payloadSize: viper.GetInt("max_tcp_payload"), + } + + var results []pipeResult + + // capture is enabled, produces one final proxy_tcp event after connection closes. + defer func() { + var events []event + if session.producer { + events = eventsFromResults(results) + } + if err := h.ProduceTCP("proxy_tcp", conn, md, nil, events); err != nil { + logger.Error("failed to produce proxy_tcp message", logAttrs( + slog.String("function", "HandleProxyTCP"), + producer.ErrAttr(err), + )...) + } + }() + + // Enable keepalive for the client side + setKeepAlive(conn, logger, "client") + + dialerTimeout := time.Duration(viper.GetInt("dial_timeout")) * time.Second + dialer := net.Dialer{Timeout: dialerTimeout} + targetConn, err := dialer.DialContext(ctx, "tcp", destAddr) + if err != nil { + logger.Error("failed to connect to the target", logAttrs( + slog.String("function", "HandleProxyTCP"), + slog.String("target", destAddr), + producer.ErrAttr(err), + )...) + return nil + } + defer targetConn.Close() + stopCancelWatcher := stopProxyOnCancel(ctx, conn, targetConn, logger) + defer stopCancelWatcher() + + // Enable keepalive for the target side + setKeepAlive(targetConn, logger, "target") + + // At this point both sockets are open. The next step is full-duplex copying. + logger.Debug("starting proxy tcp", logAttrs( + slog.String("function", "HandleProxyTCP"), + slog.String("source", srcAddr), + slog.String("target", destAddr), + slog.Duration("idle_timeout", session.idleTimeout), + slog.Int("payload_size", session.payloadSize), + )...) + + // pipeBothWays waits for the success / failure of proxy session and return results + results = pipeBothWays(conn, targetConn, session, logger) + for _, result := range results { + logResult(logger, result) + } + + logger.Debug("proxy tcp completed successfully", logAttrs( + slog.String("function", "HandleProxyTCP"), + )...) + return nil +} diff --git a/protocols/tcp/proxytcp_test.go b/protocols/tcp/proxytcp_test.go new file mode 100644 index 0000000..97e4311 --- /dev/null +++ b/protocols/tcp/proxytcp_test.go @@ -0,0 +1,402 @@ +package tcp + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "net" + "sync" + "syscall" + "testing" + "time" + + "github.com/mushorg/glutton/connection" + "github.com/mushorg/glutton/rules" + "github.com/spf13/viper" + "github.com/stretchr/testify/require" +) + +type recordingLogger struct { + mtx sync.Mutex + infos []string + debugs []string + errs []string + warns []string + fields []any +} + +func (l *recordingLogger) Info(msg string, fields ...any) { + l.record(&l.infos, msg, fields...) +} + +func (l *recordingLogger) Debug(msg string, fields ...any) { + l.record(&l.debugs, msg, fields...) +} + +func (l *recordingLogger) Error(msg string, fields ...any) { + l.record(&l.errs, msg, fields...) +} + +func (l *recordingLogger) Warn(msg string, fields ...any) { + l.record(&l.warns, msg, fields...) +} + +func (l *recordingLogger) record(target *[]string, msg string, fields ...any) { + l.mtx.Lock() + defer l.mtx.Unlock() + + *target = append(*target, msg) + l.fields = append(l.fields, fields...) +} + +func (l *recordingLogger) hasAttr(key, value string) bool { + l.mtx.Lock() + defer l.mtx.Unlock() + + for _, field := range l.fields { + attr, ok := field.(slog.Attr) + if !ok { + continue + } + if attr.Key == key && attr.Value.String() == value { + return true + } + } + return false +} + +type producedTCP struct { + protocol string + decoded interface{} +} + +type fakeHoneypot struct { + produced chan producedTCP +} + +func newFakeHoneypot() *fakeHoneypot { + return &fakeHoneypot{produced: make(chan producedTCP, 4)} +} + +func (h *fakeHoneypot) ProduceTCP(protocol string, conn net.Conn, md connection.Metadata, payload []byte, decoded interface{}) error { + h.produced <- producedTCP{protocol: protocol, decoded: decoded} + return nil +} + +func (h *fakeHoneypot) ProduceUDP(handler string, srcAddr, dstAddr *net.UDPAddr, md connection.Metadata, payload []byte, decoded interface{}) error { + return nil +} + +func (h *fakeHoneypot) ConnectionByFlow([2]uint64) connection.Metadata { + return connection.Metadata{} +} + +func (h *fakeHoneypot) UpdateConnectionTimeout(context.Context, net.Conn) error { + return nil +} + +func (h *fakeHoneypot) MetadataByConnection(net.Conn) (connection.Metadata, error) { + return connection.Metadata{}, nil +} + +func setCapture(t *testing.T, enabled bool) { + t.Helper() + + previousCapture := viper.Get("capture_traffic.enabled") + previousMaxPayload := viper.Get("max_tcp_payload") + previousTimeout := viper.Get("conn_timeout") + viper.Set("capture_traffic.enabled", enabled) + viper.Set("max_tcp_payload", 4096) + viper.Set("conn_timeout", 0) + t.Cleanup(func() { + viper.Set("capture_traffic.enabled", previousCapture) + viper.Set("max_tcp_payload", previousMaxPayload) + viper.Set("conn_timeout", previousTimeout) + }) +} + +func proxyMetadata(target string) connection.Metadata { + return connection.Metadata{ + Rule: &rules.Rule{ + Type: "proxy_tcp", + ProxyTarget: &rules.ProxyTarget{ + Host: "127.0.0.1", + Port: 0, + DialAddress: target, + }, + }, + } +} + +func startTCPServer(t *testing.T, handle func(net.Conn)) string { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { + _ = listener.Close() + }) + + done := make(chan struct{}) + t.Cleanup(func() { + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("server at %s did not finish", listener.Addr().String()) + } + }) + + go func() { + defer close(done) + + conn, err := listener.Accept() + if err != nil { + return + } + defer conn.Close() + handle(conn) + }() + + return listener.Addr().String() +} + +func startProxyServer(t *testing.T, target string, hp *fakeHoneypot, logger *recordingLogger) string { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { + _ = listener.Close() + }) + + done := make(chan error, 1) + t.Cleanup(func() { + select { + case err := <-done: + require.NoError(t, err) + case <-time.After(2 * time.Second): + t.Fatalf("proxy at %s did not finish", listener.Addr().String()) + } + }) + + go func() { + conn, err := listener.Accept() + if err != nil { + done <- nil + return + } + done <- HandleProxyTCP(context.Background(), conn, proxyMetadata(target), logger, hp) + }() + + return listener.Addr().String() +} + +func waitProduced(t *testing.T, hp *fakeHoneypot) producedTCP { + t.Helper() + + select { + case event := <-hp.produced: + return event + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for produced TCP event") + } + return producedTCP{} +} + +func TestHandleProxyAfterCloseWrite(t *testing.T) { + setCapture(t, false) + + targetAddr := startTCPServer(t, func(conn net.Conn) { + request, err := io.ReadAll(conn) + require.NoError(t, err) + require.Equal(t, "request", string(request)) + + _, err = conn.Write([]byte("response")) + require.NoError(t, err) + }) + + hp := newFakeHoneypot() + logger := &recordingLogger{} + proxyAddr := startProxyServer(t, targetAddr, hp, logger) + + client, err := net.Dial("tcp", proxyAddr) + require.NoError(t, err) + defer client.Close() + + _, err = client.Write([]byte("request")) + require.NoError(t, err) + require.NoError(t, client.(*net.TCPConn).CloseWrite()) + + response, err := io.ReadAll(client) + require.NoError(t, err) + require.Equal(t, "response", string(response)) + + event := waitProduced(t, hp) + require.Equal(t, "proxy_tcp", event.protocol) + require.Empty(t, event.decoded) + require.True(t, logger.hasAttr("handler", "proxy_tcp")) +} + +func TestHandleProxyTCP(t *testing.T) { + setCapture(t, true) + + targetAddr := startTCPServer(t, func(conn net.Conn) { + request, err := io.ReadAll(conn) + require.NoError(t, err) + require.Equal(t, "client-payload", string(request)) + + _, err = conn.Write([]byte("target-response")) + require.NoError(t, err) + }) + + hp := newFakeHoneypot() + logger := &recordingLogger{} + proxyAddr := startProxyServer(t, targetAddr, hp, logger) + + client, err := net.Dial("tcp", proxyAddr) + require.NoError(t, err) + defer client.Close() + + _, err = client.Write([]byte("client-payload")) + require.NoError(t, err) + require.NoError(t, client.(*net.TCPConn).CloseWrite()) + + response, err := io.ReadAll(client) + require.NoError(t, err) + require.Equal(t, "target-response", string(response)) + + produced := waitProduced(t, hp) + events, ok := produced.decoded.([]event) + require.True(t, ok) + require.Len(t, events, 2) + + payloads := map[string]bool{} + for _, captured := range events { + payloads[string(captured.Payload)] = true + require.NotEmpty(t, captured.Direction) + require.NotEmpty(t, captured.PayloadHash) + } + require.True(t, payloads["client-payload"]) + require.True(t, payloads["target-response"]) +} + +func TestPartialWriter(t *testing.T) { + writeErr := errors.New("short write") + logger := &recordingLogger{} + session := &session{producer: true, payloadSize: 4096} + writer := &writer{ + conn: testWriteConn{written: 3, err: writeErr}, + session: session, + logger: logger, + dir: "client->target", + name: "client->target dst", + } + + n, err := writer.Write([]byte("abcdef")) + require.ErrorIs(t, err, writeErr) + require.Equal(t, 3, n) + + event := writer.event() + require.Nil(t, event) + require.Zero(t, writer.written) +} + +// test capture writer skipped failed writer +func TestFailedWrites(t *testing.T) { + writeErr := errors.New("write failed") + logger := &recordingLogger{} + session := &session{producer: true, payloadSize: 4096} + writer := &writer{ + conn: testWriteConn{written: 0, err: writeErr}, + session: session, + logger: logger, + dir: "client->target", + name: "client->target dst", + } + + n, err := writer.Write([]byte("abcdef")) + require.ErrorIs(t, err, writeErr) + require.Zero(t, n) + require.Nil(t, writer.event()) + require.Zero(t, writer.written) +} + +// test capture writer returns short write error +func TestShortWriteError(t *testing.T) { + logger := &recordingLogger{} + session := &session{producer: true, payloadSize: 4096} + writer := &writer{ + conn: testWriteConn{written: 3}, + session: session, + logger: logger, + dir: "client->target", + name: "client->target dst", + } + + n, err := writer.Write([]byte("abcdef")) + require.ErrorIs(t, err, io.ErrShortWrite) + require.Equal(t, 3, n) + event := writer.event() + require.Equal(t, "abc", string(event.Payload)) + require.Equal(t, int64(3), event.Bytes) + require.False(t, event.Truncated) +} + +// test capture writer caps captured bytes +func TestWriterByteCaps(t *testing.T) { + logger := &recordingLogger{} + session := &session{producer: true, payloadSize: 4} + writer := &writer{ + conn: testWriteConn{written: 6}, + session: session, + logger: logger, + dir: "client->target", + name: "client->target dst", + } + + n, err := writer.Write([]byte("abcdef")) + require.NoError(t, err) + require.Equal(t, 6, n) + event := writer.event() + require.Equal(t, "abcd", string(event.Payload)) + require.Equal(t, int64(6), event.Bytes) + require.True(t, event.Truncated) +} + +func TestExpectedPipeErrorAllowsNotConnected(t *testing.T) { + err := fmt.Errorf("close read: %w", syscall.ENOTCONN) + + require.True(t, expectedPipeError(err)) +} + +// test proxy handler closes connection on missing metadata +func TestMissingMetadata(t *testing.T) { + setCapture(t, true) + + client, server := net.Pipe() + defer client.Close() + + logger := &recordingLogger{} + err := HandleProxyTCP(context.Background(), server, connection.Metadata{}, logger, newFakeHoneypot()) + require.NoError(t, err) + require.True(t, logger.hasAttr("handler", "proxy_tcp")) + require.True(t, logger.hasAttr("function", "HandleProxyTCP")) + + err = client.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + if err == nil { + _, err = client.Write([]byte("x")) + } + require.Error(t, err) +} + +type testWriteConn struct { + net.Conn + written int + err error +} + +func (c testWriteConn) Write(p []byte) (int, error) { + return c.written, c.err +} diff --git a/rules/rules.go b/rules/rules.go index 0b4d595..0c367ac 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -5,6 +5,7 @@ import ( "io" "net" "strconv" + "strings" "time" "github.com/google/gopacket" @@ -17,6 +18,7 @@ type RuleType int const ( UserConnHandler RuleType = iota + ProxyTCP Drop ) @@ -31,10 +33,17 @@ type Rule struct { Target string `yaml:"target,omitempty"` Name string `yaml:"name,omitempty"` - isInit bool - ruleType RuleType - index int - matcher *pcap.BPF + isInit bool + RuleType RuleType + ProxyTarget *ProxyTarget `yaml:"-"` + index int + matcher *pcap.BPF +} + +type ProxyTarget struct { + Host string + Port uint16 + DialAddress string } func (r *Rule) String() string { @@ -59,13 +68,23 @@ func (rule *Rule) init(idx int) error { switch rule.Type { case "conn_handler": - rule.ruleType = UserConnHandler + rule.RuleType = UserConnHandler + case "proxy_tcp": + rule.RuleType = ProxyTCP case "drop": - rule.ruleType = Drop + rule.RuleType = Drop default: return fmt.Errorf("unknown rule type: %s", rule.Type) } + if rule.RuleType == ProxyTCP { + target, err := parseProxyTarget(rule.Target) + if err != nil { + return fmt.Errorf("invalid proxy_tcp target: %w", err) + } + rule.ProxyTarget = target + } + var err error if len(rule.Match) > 0 { rule.matcher, err = pcap.NewBPF(layers.LinkTypeEthernet, 65535, rule.Match) @@ -80,6 +99,35 @@ func (rule *Rule) init(idx int) error { return nil } +func parseProxyTarget(raw string) (*ProxyTarget, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil, fmt.Errorf("target is required") + } + + host, portValue, err := net.SplitHostPort(raw) + if err != nil { + return nil, err + } + if strings.TrimSpace(host) == "" { + return nil, fmt.Errorf("host is required") + } + + port, err := strconv.Atoi(portValue) + if err != nil { + return nil, fmt.Errorf("invalid port %q: %w", portValue, err) + } + if port < 1 || port > 65535 { + return nil, fmt.Errorf("port out of range: %d", port) + } + + return &ProxyTarget{ + Host: host, + Port: uint16(port), + DialAddress: net.JoinHostPort(host, strconv.Itoa(port)), + }, nil +} + func splitAddr(addr string) (string, uint16, error) { ip, port, err := net.SplitHostPort(addr) if err != nil { diff --git a/rules/rules_test.go b/rules/rules_test.go index c885de9..3c83ec2 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -3,6 +3,7 @@ package rules import ( "net" "os" + "strings" "testing" "time" @@ -43,6 +44,97 @@ func TestSplitAddr(t *testing.T) { require.Equal(t, uint16(8080), port) } +func TestParseProxyTarget(t *testing.T) { + tests := []struct { + name string + raw string + wantAddress string + wantErr string + }{ + { + name: "plain target", + raw: "127.0.0.1:9889", + wantAddress: "127.0.0.1:9889", + }, + { + name: "hostname target", + raw: "localhost:9889", + wantAddress: "localhost:9889", + }, + { + name: "empty target", + raw: "", + wantErr: "target is required", + }, + { + name: "scheme not supported", + raw: "tcp://127.0.0.1:9889", + wantErr: "too many colons", + }, + { + name: "missing port", + raw: "127.0.0.1", + wantErr: "missing port", + }, + { + name: "missing host", + raw: ":9889", + wantErr: "host is required", + }, + { + name: "non numeric port", + raw: "127.0.0.1:http", + wantErr: "invalid port", + }, + { + name: "port out of range", + raw: "127.0.0.1:70000", + wantErr: "port out of range", + }, + { + name: "path not supported", + raw: "127.0.0.1:9889/path", + wantErr: "invalid port", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + target, err := parseProxyTarget(tt.raw) + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + return + } + + require.NoError(t, err) + require.Equal(t, tt.wantAddress, target.DialAddress) + }) + } +} + +func TestInitProxyTCPRuleParsesTarget(t *testing.T) { + rules, err := Init(strings.NewReader(`rules: + - match: tcp dst port 9889 + type: proxy_tcp + target: 127.0.0.1:9889 +`)) + require.NoError(t, err) + require.Len(t, rules, 1) + require.NotNil(t, rules[0].ProxyTarget) + require.Equal(t, "127.0.0.1:9889", rules[0].ProxyTarget.DialAddress) +} + +func TestInitProxyTCPRuleRejectsInvalidTarget(t *testing.T) { + _, err := Init(strings.NewReader(`rules: + - match: tcp dst port 9889 + type: proxy_tcp + target: tcp://127.0.0.1:9889 +`)) + require.Error(t, err) + require.Contains(t, err.Error(), "too many colons") +} + func testConn(t *testing.T) (net.Conn, net.Listener) { ln, err := net.Listen("tcp", "127.0.0.1:1234") require.NoError(t, err) @@ -102,6 +194,21 @@ func TestRunMatchUDP(t *testing.T) { require.Equal(t, "test", match.Target) } +func TestRunMatchProxyTCP(t *testing.T) { + rules := parseRules(t) + require.NotEmpty(t, rules) + + srcAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 50000} + dstAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 80} + + match, err := rules.Match("tcp", srcAddr, dstAddr) + require.NoError(t, err) + require.NotNil(t, match) + require.Equal(t, "proxy_tcp", match.Type) + require.NotNil(t, match.ProxyTarget) + require.Equal(t, "127.0.0.1:9889", match.ProxyTarget.DialAddress) +} + func TestBPF(t *testing.T) { buf := make([]byte, 65535) bpfi, err := pcap.NewBPF(layers.LinkTypeEthernet, 65535, "icmp") diff --git a/rules/test.yaml b/rules/test.yaml index 38396f2..b8608dc 100644 --- a/rules/test.yaml +++ b/rules/test.yaml @@ -9,6 +9,9 @@ rules: - match: udp dst port 1234 type: conn_handler target: test + - match: tcp dst port 80 + type: proxy_tcp + target: 127.0.0.1:9889 - match: tcp type: conn_handler target: tcp