Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions modules/postgres/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,18 @@ func decodeAuthMode(buf []byte) *AuthenticationMode {
}
}

// isValidPostgresError checks that a decoded PostgresError contains the minimum
// fields expected from a real PostgreSQL server: severity (or severity_v), code,
// and message. This guards against false-positive detections where a non-Postgres
// service happens to return data that superficially looks like a Postgres packet.
func isValidPostgresError(e *PostgresError) bool {
if e == nil {
return false
}
hasSeverity := (*e)["severity"] != "" || (*e)["severity_v"] != ""
return hasSeverity && (*e)["code"] != "" && (*e)["message"] != ""
}

// decodeError() decodes an 'E'-type tag into a map of friendly name -> value; see https://www.postgresql.org/docs/10/static/protocol-error-fields.html
func decodeError(buf []byte) *PostgresError {
partMap := map[byte]string{
Expand Down Expand Up @@ -411,12 +423,20 @@ func (scanner *Scanner) Scan(ctx context.Context, dialGroup *zgrab2.DialerGroup,
}

if response.Type != 'E' {
// No server should be allowing a 0.0 client...but if it does allow it, don't bail out
log.Debugf("Unexpected response from server: %s", response.ToString())
results.SupportedVersions = response.OutputValue()
} else {
results.SupportedVersions = strings.Trim(string(response.Body), "\x00\r\n ")
// A real Postgres server always returns an 'E' error for an unsupported
// protocol version. If we get anything else, this is not Postgres.
return zgrab2.SCAN_PROTOCOL_ERROR, &results, fmt.Errorf("expected Postgres error response to version probe, got message type '%c'", response.Type)
}
if response.Length > 0 {
// Standard structured error packet — validate it has real Postgres fields.
decoded := decodeError(response.Body)
if !isValidPostgresError(decoded) {
return zgrab2.SCAN_PROTOCOL_ERROR, &results, errors.New("server returned an 'E' packet without valid Postgres error fields")
}
}
// Length == 0 means a pre-startup error (raw \n\0-terminated string),
// which older Postgres versions use. Still a valid detection.
results.SupportedVersions = strings.Trim(string(response.Body), "\x00\r\n ")

if _, err := sql.ReadAll(); err != nil {
return err.Unpack(&results)
Expand All @@ -443,11 +463,19 @@ func (scanner *Scanner) Scan(ctx context.Context, dialGroup *zgrab2.DialerGroup,
}

if response.Type != 'E' {
// No server should be allowing a 255.255 client...but if it does allow it, don't bail out
log.Debugf("Unexpected response from server: %s", response.ToString())
results.ProtocolError = response.ToError()
return zgrab2.SCAN_PROTOCOL_ERROR, &results, fmt.Errorf("expected Postgres error response to high-version probe, got message type '%c'", response.Type)
}
if response.Length > 0 {
decoded := decodeError(response.Body)
if !isValidPostgresError(decoded) {
return zgrab2.SCAN_PROTOCOL_ERROR, &results, errors.New("server returned an 'E' packet without valid Postgres error fields")
}
results.ProtocolError = decoded
} else {
results.ProtocolError = decodeError(response.Body)
// Pre-startup raw error string from older Postgres versions
results.ProtocolError = &PostgresError{
"message": strings.Trim(string(response.Body), "\x00\r\n "),
}
}

if _, err := sql.ReadAll(); err != nil {
Expand Down
221 changes: 221 additions & 0 deletions modules/postgres/scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package postgres
import (
"context"
stdtls "crypto/tls"
"encoding/binary"
"net"
"strings"
"testing"

"github.com/zmap/zgrab2"
Expand Down Expand Up @@ -101,3 +103,222 @@ func TestPostgresHandshakeCompletedSuccessfully(t *testing.T) {
t.Error("expected HandshakeCompletedSuccessfully = true")
}
}

func TestIsValidPostgresError(t *testing.T) {
tests := []struct {
name string
err *PostgresError
valid bool
}{
{"nil error", nil, false},
{"empty error", &PostgresError{}, false},
{"severity only", &PostgresError{"severity": "FATAL"}, false},
{"severity and code", &PostgresError{"severity": "FATAL", "code": "08P01"}, false},
{"severity_v and code and message", &PostgresError{"severity_v": "FATAL", "code": "08P01", "message": "unsupported version"}, true},
{"full valid error", &PostgresError{"severity": "FATAL", "code": "08P01", "message": "unsupported version"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isValidPostgresError(tt.err); got != tt.valid {
t.Errorf("isValidPostgresError() = %v, want %v", got, tt.valid)
}
})
}
}

// makePostgresErrorPacket builds a raw Postgres 'E'-type response packet from
// a map of field tag -> value. The packet format is:
// byte 'E' | uint32 length | (byte tag + string value + \0)... | \0
func makePostgresErrorPacket(fields map[byte]string) []byte {
var body []byte
for tag, val := range fields {
body = append(body, tag)
body = append(body, []byte(val)...)
body = append(body, 0)
}
body = append(body, 0) // terminator

length := uint32(len(body) + 4) // length includes itself
pkt := make([]byte, 1+4+len(body))
pkt[0] = 'E'
binary.BigEndian.PutUint32(pkt[1:5], length)
copy(pkt[5:], body)
return pkt
}

// validPostgresError is a reusable valid Postgres error packet
var validPostgresError = makePostgresErrorPacket(map[byte]string{
'S': "FATAL",
'V': "FATAL",
'C': "08P01",
'M': "unsupported frontend protocol",
})

// makeConnPairFunc returns a function that, on each call, returns a new
// client-side net.Conn and starts a goroutine running serverFn on the server side.
func makeConnPairFunc(serverFn func(net.Conn)) func() net.Conn {
return func() net.Conn {
client, server := net.Pipe()
go serverFn(server)
return client
}
}

// makeMultiL4Dialer returns an L4Dialer that calls newConn() for each dial,
// allowing the postgres scanner to open multiple sequential connections.
func makeMultiL4Dialer(newConn func() net.Conn) func(*zgrab2.ScanTarget) func(context.Context, string, string) (net.Conn, error) {
return func(*zgrab2.ScanTarget) func(context.Context, string, string) (net.Conn, error) {
return func(context.Context, string, string) (net.Conn, error) {
return newConn(), nil
}
}
}

// drainAndRespond reads one request from conn, writes response, then closes.
func drainAndRespond(conn net.Conn, response []byte) {
defer conn.Close()
buf := make([]byte, 4096)
conn.Read(buf) //nolint:errcheck
if len(response) > 0 {
conn.Write(response) //nolint:errcheck
}
}

func newTestScanner() *Scanner {
return &Scanner{Config: &Flags{SkipSSL: true, ProtocolVersion: "3.0"}}
}

func TestFalsePositiveDetection_NonPostgresServer(t *testing.T) {
// Simulates a non-Postgres service that responds 'N' to any request.
// The scanner should bail with SCAN_PROTOCOL_ERROR since 'N' is not a
// valid Postgres 'E' error response.
newConn := makeConnPairFunc(func(conn net.Conn) {
drainAndRespond(conn, []byte{'N'})
})
scanner := newTestScanner()
target := &zgrab2.ScanTarget{IP: net.ParseIP("127.0.0.1"), Port: 5432}
dialGroup := &zgrab2.DialerGroup{
L4Dialer: makeMultiL4Dialer(newConn),
}

status, _, err := scanner.Scan(context.Background(), dialGroup, target)
if status == zgrab2.SCAN_SUCCESS {
t.Errorf("expected non-success status for non-Postgres server, got %s", status)
}
if status != zgrab2.SCAN_PROTOCOL_ERROR {
t.Errorf("expected SCAN_PROTOCOL_ERROR, got %s (err: %v)", status, err)
}
}

func TestFalsePositiveDetection_InvalidErrorFields(t *testing.T) {
// Server returns an 'E'-type packet but with no structured fields — just
// garbage data. Should fail isValidPostgresError.
badErrorPkt := makePostgresErrorPacket(map[byte]string{
'X': "unknown",
})
newConn := makeConnPairFunc(func(conn net.Conn) {
drainAndRespond(conn, badErrorPkt)
})
scanner := newTestScanner()
target := &zgrab2.ScanTarget{IP: net.ParseIP("127.0.0.1"), Port: 5432}
dialGroup := &zgrab2.DialerGroup{
L4Dialer: makeMultiL4Dialer(newConn),
}

status, _, _ := scanner.Scan(context.Background(), dialGroup, target)
if status == zgrab2.SCAN_SUCCESS {
t.Errorf("expected non-success for invalid error fields, got %s", status)
}
if status != zgrab2.SCAN_PROTOCOL_ERROR {
t.Errorf("expected SCAN_PROTOCOL_ERROR, got %s", status)
}
}

func TestValidPostgresServer_PassesVersionProbe(t *testing.T) {
// Server returns a valid Postgres error for the version 0.0 probe.
// The scanner should accept it and populate SupportedVersions.
newConn := makeConnPairFunc(func(conn net.Conn) {
drainAndRespond(conn, validPostgresError)
})
scanner := newTestScanner()
target := &zgrab2.ScanTarget{IP: net.ParseIP("127.0.0.1"), Port: 5432}
dialGroup := &zgrab2.DialerGroup{
L4Dialer: makeMultiL4Dialer(newConn),
}

status, result, _ := scanner.Scan(context.Background(), dialGroup, target)
if result == nil {
t.Fatal("expected non-nil result")
}
pgResult, ok := result.(*Results)
if !ok {
t.Fatal("expected *Results")
}
if pgResult.SupportedVersions == "" {
t.Error("expected SupportedVersions to be populated after valid error response")
}
// The scan may fail on subsequent connections (our fake server only handles
// one exchange per connection), but the first probe should pass validation.
// SCAN_PROTOCOL_ERROR from the detection check would be a regression.
if status == zgrab2.SCAN_PROTOCOL_ERROR && strings.Contains(pgResult.SupportedVersions, "unsupported") {
t.Error("valid Postgres error was rejected by detection check")
}
}

func TestFalsePositive_ServerClosesImmediately(t *testing.T) {
// Server accepts the connection but closes immediately without sending
// any data. Should not result in SCAN_SUCCESS.
newConn := makeConnPairFunc(func(conn net.Conn) {
conn.Close()
})
scanner := newTestScanner()
target := &zgrab2.ScanTarget{IP: net.ParseIP("127.0.0.1"), Port: 5432}
dialGroup := &zgrab2.DialerGroup{
L4Dialer: makeMultiL4Dialer(newConn),
}

status, _, _ := scanner.Scan(context.Background(), dialGroup, target)
if status == zgrab2.SCAN_SUCCESS {
t.Errorf("expected non-success for immediately-closing server, got %s", status)
}
}

func TestPreStartupError_OlderPostgres(t *testing.T) {
// Older Postgres versions (pre-9.6 in some configurations) respond to
// bogus version probes with a pre-startup error: a raw \n\0-terminated
// string rather than a structured 'E' packet with tagged fields.
// The scanner must accept these as valid detections.
preStartupError := []byte("FATAL: unsupported frontend protocol 0.0: server supports 1.0 to 3.0\n\x00")

newConn := makeConnPairFunc(func(conn net.Conn) {
defer conn.Close()
buf := make([]byte, 4096)
conn.Read(buf) //nolint:errcheck
// Send 'E' header byte followed by the raw error string.
// tryReadPacket sees length[0] > 0x00 and reads as pre-startup format.
response := append([]byte{'E'}, preStartupError...)
conn.Write(response) //nolint:errcheck
})
scanner := newTestScanner()
target := &zgrab2.ScanTarget{IP: net.ParseIP("127.0.0.1"), Port: 5432}
dialGroup := &zgrab2.DialerGroup{
L4Dialer: makeMultiL4Dialer(newConn),
}

status, result, scanErr := scanner.Scan(context.Background(), dialGroup, target)
if result == nil {
t.Fatal("expected non-nil result")
}
pgResult, ok := result.(*Results)
if !ok {
t.Fatal("expected *Results")
}
if pgResult.SupportedVersions == "" {
t.Error("expected SupportedVersions to be populated for pre-startup error")
}
// The scan will fail on subsequent connections, but the first probe must
// NOT fail with SCAN_PROTOCOL_ERROR from our detection check.
if status == zgrab2.SCAN_PROTOCOL_ERROR {
t.Errorf("pre-startup error from older Postgres was incorrectly rejected: %v", scanErr)
}
}
Loading