diff --git a/config/config.yaml b/config/config.yaml index 222deb7..c8c4219 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -24,3 +24,6 @@ producers: conn_timeout: 45 max_tcp_payload: 4096 + +capture_traffic: + enabled: false \ No newline at end of file diff --git a/config/rules.yaml b/config/rules.yaml index 6b5bb4c..88ab2c2 100644 --- a/config/rules.yaml +++ b/config/rules.yaml @@ -35,6 +35,9 @@ rules: - match: tcp dst port 27017 type: conn_handler target: mongodb + - match: tcp dst port 9889 + type: tcp_proxy + target: 127.0.0.1:9889 # Can use hostip:port for the required destination. - match: tcp type: conn_handler target: tcp diff --git a/glutton.go b/glutton.go index c6ef637..61e58e1 100644 --- a/glutton.go +++ b/glutton.go @@ -222,12 +222,22 @@ func (g *Glutton) tcpListen() { g.Logger.Error("Failed to set connection timeout", producer.ErrAttr(err)) } - if hfunc, ok := g.tcpProtocolHandlers[rule.Target]; ok { - go func() { - if err := hfunc(g.ctx, conn, md); err != nil { - g.Logger.Error("Failed to handle TCP connection", producer.ErrAttr(err), slog.String("handler", rule.Target)) - } - }() + if rule.Type == "tcp_proxy" { + if hfunc, ok := g.tcpProtocolHandlers[rule.Type]; ok { + go func() { + if err := hfunc(g.ctx, conn, md); err != nil { + g.Logger.Error("Failed to handle TCP passthrough", producer.ErrAttr(err), slog.String("handler", "tcp_proxy")) + } + }() + } + } else { + if hfunc, ok := g.tcpProtocolHandlers[rule.Target]; ok { + go func() { + if err := hfunc(g.ctx, conn, md); err != nil { + g.Logger.Error("Failed to handle TCP connection", producer.ErrAttr(err), slog.String("handler", rule.Target)) + } + }() + } } } } diff --git a/protocols/protocols.go b/protocols/protocols.go index d43b3ca..7e8f6ba 100644 --- a/protocols/protocols.go +++ b/protocols/protocols.go @@ -72,6 +72,9 @@ func MapTCPProtocolHandlers(log interfaces.Logger, h interfaces.Honeypot) map[st protocolHandlers["mongodb"] = func(ctx context.Context, conn net.Conn, md connection.Metadata) error { return tcp.HandleMongoDB(ctx, conn, md, log, h) } + protocolHandlers["tcp_proxy"] = func(ctx context.Context, conn net.Conn, md connection.Metadata) error { + return tcp.HandlePassThrough(ctx, conn, md, log, h) + } protocolHandlers["tcp"] = func(ctx context.Context, conn net.Conn, md connection.Metadata) error { snip, bufConn, err := Peek(conn, 4) if err != nil { diff --git a/protocols/tcp/passthrough.go b/protocols/tcp/passthrough.go new file mode 100644 index 0000000..232b9a7 --- /dev/null +++ b/protocols/tcp/passthrough.go @@ -0,0 +1,200 @@ +package tcp + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "log" + "log/slog" + "net" + "time" + + "github.com/mushorg/glutton/connection" + "github.com/mushorg/glutton/producer" + "github.com/mushorg/glutton/protocols/interfaces" + "github.com/spf13/viper" +) + +type parsedPassThrough struct { + Direction string `json:"direction,omitempty"` + Payload []byte `json:"payload,omitempty"` + PayloadHash string `json:"payload_hash,omitempty"` // Used for easier identification, can remove +} + +type passThroughServer struct { + events []parsedPassThrough + conn net.Conn + target string + source string +} + +type loggingWriter struct { + dst net.Conn + server *passThroughServer + logger interfaces.Logger + capture bool + dir string +} + +func (lw *loggingWriter) Write(p []byte) (int, error) { + lw.server.logPayload(lw.dir, p, lw.logger) + lw.server.recordEvent(lw.dir, p, lw.capture) + return lw.dst.Write(p) +} + +// checks whether the payload can be converted to text, to prevent expensive hex coding. +func (srv *passThroughServer) isLikelyText(data []byte) bool { + if len(data) == 0 { + return false + } + + printable := 0 + for _, b := range data { + if b >= 32 && b <= 126 || b == '\n' || b == '\r' || b == '\t' { + printable++ + } + } + + return (printable*100)/len(data) > 80 // threshold value --> 80% +} + +// logs the payload hex or payload text. +func (srv *passThroughServer) logPayload(direction string, data []byte, logger interfaces.Logger) { + if len(data) == 0 { + return + } + + fields := []any{ + slog.String("direction", direction), + slog.Int("length", len(data)), + slog.String("sha256", fmt.Sprintf("%x", sha256.Sum256(data))), + } + + if srv.isLikelyText(data) { + fields = append(fields, slog.String("payload", string(data))) + } else { + fields = append(fields, slog.String("hex", hex.EncodeToString(data))) + } + + logger.Info("payload_transferred", fields...) +} + +// records the events in the server +func (srv *passThroughServer) recordEvent(dir string, buf []byte, capture bool) { + if !capture { + return + } + hash := sha256.Sum256(buf) + + payload := append([]byte(nil), buf...) // defensive copy + + srv.events = append(srv.events, parsedPassThrough{ + Direction: dir, + Payload: payload, + PayloadHash: fmt.Sprintf("%x", hash[:]), + }) +} + +// pipeBidirectional handles data transfer between the two connections +func pipeBidirectional(src, dst net.Conn, server *passThroughServer, logger interfaces.Logger, capture bool, errChan chan error) { + direction := getDirection(src, dst) + writer := &loggingWriter{dst: dst, server: server, logger: logger, capture: capture, dir: direction} + + // source to target + go func() { + _, err := io.Copy(writer, src) + errChan <- err + }() + + revDirection := getDirection(dst, src) + revWriter := &loggingWriter{dst: src, server: server, logger: logger, capture: capture, dir: revDirection} + + // target to source + go func() { + _, err := io.Copy(revWriter, dst) + errChan <- err + }() +} + +// getDirection returns the direction as a string +func getDirection(src, dst net.Conn) string { + srcAddr := src.RemoteAddr().String() + dstAddr := dst.RemoteAddr().String() + return fmt.Sprintf("%s -> %s", srcAddr, dstAddr) +} + +// Dial to the source ip, acting as a proxy between the client and real source by piping the data back and forth w/o interfering w it. +func HandlePassThrough(ctx context.Context, conn net.Conn, md connection.Metadata, logger interfaces.Logger, h interfaces.Honeypot) error { + var err error + handler := "tcp_proxy" + + srcAddr := conn.RemoteAddr().String() + destAddr := md.Rule.Target + + host, _, err := net.SplitHostPort(destAddr) + if err != nil { + logger.Error("invalid address format", producer.ErrAttr(err)) + return nil + } + + if ip := net.ParseIP(host); ip == nil { + if _, err := net.LookupHost(host); err != nil { + return fmt.Errorf("invalid host: %w", err) + } + } + + server := &passThroughServer{ + events: []parsedPassThrough{}, + conn: conn, + target: destAddr, + source: srcAddr, + } + + var capture bool + if viper.GetBool("capture_traffic.enabled") { + capture = true + } + + defer func() { + var events []parsedPassThrough + if capture { + events = server.events + } + if err := h.ProduceTCP("passthrough", conn, md, nil, events); err != nil { + logger.Error("failed to produce passthrough message", producer.ErrAttr(err)) + } + if err := conn.Close(); err != nil { + logger.Error("failed to close incoming connection", slog.String("handler", handler), producer.ErrAttr(err)) + } + }() + + if destAddr == "" { + logger.Error("no target defined", slog.String("handler", handler)) + return nil + } + + timeout := 5 * time.Second + + targetConn, err := net.DialTimeout("tcp", destAddr, timeout) + if err != nil { + logger.Error("failed to connect to the target", slog.String("handler", handler), slog.String("target", string(destAddr)), producer.ErrAttr(err)) + return nil + } + defer targetConn.Close() + + logger.Info("starting passthrough", slog.String("source", srcAddr), slog.String("target", string(destAddr)), slog.String("handler", handler)) + + errChan := make(chan error, 2) + + go pipeBidirectional(conn, targetConn, server, logger, capture, errChan) + + // wait for either side to close + if err := <-errChan; err != nil { + log.Printf("connection closed: %v", err) + } + + logger.Info("Passthrough completed successfully") + return nil +} diff --git a/protocols/tcp/passthrough_test.go b/protocols/tcp/passthrough_test.go new file mode 100644 index 0000000..4f5569f --- /dev/null +++ b/protocols/tcp/passthrough_test.go @@ -0,0 +1,254 @@ +package tcp + +import ( + "context" + "crypto/rand" + "io" + "net" + "testing" + + "github.com/mushorg/glutton/protocols/interfaces" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type MockLogger struct { + mock.Mock +} + +func (m *MockLogger) Info(msg string, attrs ...interface{}) { + m.Called(msg, attrs) +} + +func (m *MockLogger) Debug(msg string, attrs ...interface{}) { + m.Called(msg, attrs) +} + +func (m *MockLogger) Error(msg string, attrs ...interface{}) { + m.Called(msg, attrs) +} + +func (m *MockLogger) Warn(msg string, attrs ...interface{}) { + m.Called(msg, attrs) +} + +func TestIsLikelyText(t *testing.T) { + tests := []struct { + name string + input []byte + expected bool + }{ + { + name: "Empty input", + input: []byte(""), + expected: false, + }, + { + name: "Simple ASCII text", + input: []byte("This is plain text"), + expected: true, + }, + { + name: "Text with whitespace", + input: []byte("Text with\nnewlines\tand tabs\r\n"), + expected: true, + }, + { + name: "Binary data", + input: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, + expected: false, + }, + { + name: "Mixed content with few non-printable", + input: []byte("Text\x00with\x01binary"), + expected: true, // checking threshold at 85.7% + }, + { + name: "Exactly 80% printable", + input: []byte("AAAA\x01"), // 4/5 = 80% + expected: false, + }, + { + name: "Just below 80% printable", + input: []byte("AAA\x01\x02"), // 3/5 = 60% + expected: false, + }, + } + + srv := &passThroughServer{} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := srv.isLikelyText(test.input) + require.Equal(t, test.expected, result, "unexpected result for test case: %s", test.name) + }) + } +} + +func TestRecording(t *testing.T) { + s := &passThroughServer{} + s.recordEvent("test", []byte("data"), true) + assert.Len(t, s.events, 1) +} + +func TestPipeBidirectional(t *testing.T) { + mockLogger := &MockLogger{} + mockLogger.On("Info", mock.Anything, mock.Anything).Return() + + mockServer := &passThroughServer{ + events: make([]parsedPassThrough, 0), + conn: nil, + target: "test-target:1234", + source: "test-source:5678", + } + + type args struct { + ctx context.Context + src net.Conn + dst net.Conn + server *passThroughServer + logger interfaces.Logger + capture bool + errChan chan error + } + tests := []struct { + name string + args args + setup func() (net.Conn, net.Conn) + wantErr bool + wantErrType error + verify func(t *testing.T, args args) + }{ + { + name: "successful data transfer with capture", + args: args{ + ctx: context.Background(), + server: mockServer, + logger: mockLogger, + capture: true, + errChan: make(chan error, 1), + }, + + setup: func() (net.Conn, net.Conn) { + client, server := net.Pipe() + go func() { + client.Write([]byte("test data")) + client.Close() + }() + return client, server + }, + verify: func(t *testing.T, args args) { + buf := make([]byte, 1024) + n, err := args.dst.Read(buf) + + require.NoError(t, err) + assert.Equal(t, "test data", string(buf[:n])) + + require.True(t, args.capture, "Capture should be enabled") + }, + }, + { + name: "read error from source", + args: args{ + ctx: context.Background(), + server: mockServer, + logger: mockLogger, + capture: false, + errChan: make(chan error, 1), + }, + setup: func() (net.Conn, net.Conn) { + client, server := net.Pipe() + client.Close() + return client, server + }, + wantErr: true, + wantErrType: io.EOF, + }, + { + name: "write error to destination", + args: args{ + ctx: context.Background(), + server: mockServer, + logger: mockLogger, + capture: false, + errChan: make(chan error, 1), + }, + setup: func() (net.Conn, net.Conn) { + client, server := net.Pipe() + server.Close() + return client, server + }, + wantErr: true, + }, + { + name: "zero byte read", + args: args{ + ctx: context.Background(), + server: mockServer, + logger: mockLogger, + capture: true, + errChan: make(chan error, 1), + }, + setup: func() (net.Conn, net.Conn) { + client, server := net.Pipe() + go func() { + client.Write([]byte{}) + client.Close() + }() + return client, server + }, + verify: func(t *testing.T, args args) { + assert.Empty(t, args.server.events) + }, + }, + { + name: "large data transfer", + args: args{ + ctx: context.Background(), + server: mockServer, + logger: mockLogger, + capture: true, + errChan: make(chan error, 1), + }, + setup: func() (net.Conn, net.Conn) { + client, server := net.Pipe() + largeData := make([]byte, 8192) + rand.Read(largeData) + go func() { + client.Write(largeData) + client.Close() + }() + return client, server + }, + verify: func(t *testing.T, args args) { + buf := make([]byte, 8192) + n, err := io.ReadFull(args.dst, buf) + require.NoError(t, err) + assert.Equal(t, 8192, n) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.args.src, tt.args.dst = tt.setup() + defer tt.args.src.Close() + defer tt.args.dst.Close() + } + + go pipeBidirectional( + tt.args.src, + tt.args.dst, + tt.args.server, + tt.args.logger, + tt.args.capture, + tt.args.errChan, + ) + + if tt.verify != nil { + tt.verify(t, tt.args) + } + }) + } +} diff --git a/rules/rules.go b/rules/rules.go index 0b4d595..b28e234 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -18,6 +18,7 @@ type RuleType int const ( UserConnHandler RuleType = iota Drop + Tcp_Proxy ) type Config struct { @@ -32,7 +33,7 @@ type Rule struct { Name string `yaml:"name,omitempty"` isInit bool - ruleType RuleType + RuleType RuleType index int matcher *pcap.BPF } @@ -59,9 +60,11 @@ func (rule *Rule) init(idx int) error { switch rule.Type { case "conn_handler": - rule.ruleType = UserConnHandler + rule.RuleType = UserConnHandler + case "tcp_proxy": + rule.RuleType = Tcp_Proxy case "drop": - rule.ruleType = Drop + rule.RuleType = Drop default: return fmt.Errorf("unknown rule type: %s", rule.Type) }