diff --git a/.gitignore b/.gitignore index 3f3c51ce..c629524b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ .DS_Store !/README.md /*.md +*.patch \ No newline at end of file diff --git a/internal/gtcpip/README.md b/gtcpip/README.md similarity index 100% rename from internal/gtcpip/README.md rename to gtcpip/README.md diff --git a/internal/gtcpip/checksum/checksum.go b/gtcpip/checksum/checksum.go similarity index 100% rename from internal/gtcpip/checksum/checksum.go rename to gtcpip/checksum/checksum.go diff --git a/internal/gtcpip/checksum/checksum_default.go b/gtcpip/checksum/checksum_default.go similarity index 100% rename from internal/gtcpip/checksum/checksum_default.go rename to gtcpip/checksum/checksum_default.go diff --git a/internal/gtcpip/checksum/checksum_ts.go b/gtcpip/checksum/checksum_ts.go similarity index 100% rename from internal/gtcpip/checksum/checksum_ts.go rename to gtcpip/checksum/checksum_ts.go diff --git a/internal/gtcpip/checksum/checksum_unsafe.go b/gtcpip/checksum/checksum_unsafe.go similarity index 100% rename from internal/gtcpip/checksum/checksum_unsafe.go rename to gtcpip/checksum/checksum_unsafe.go diff --git a/internal/gtcpip/errors.go b/gtcpip/errors.go similarity index 100% rename from internal/gtcpip/errors.go rename to gtcpip/errors.go diff --git a/internal/gtcpip/header/checksum.go b/gtcpip/header/checksum.go similarity index 97% rename from internal/gtcpip/header/checksum.go rename to gtcpip/header/checksum.go index 2c21e6d3..303502cc 100644 --- a/internal/gtcpip/header/checksum.go +++ b/gtcpip/header/checksum.go @@ -20,8 +20,8 @@ import ( "encoding/binary" "fmt" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" ) // PseudoHeaderChecksum calculates the pseudo-header checksum for the given diff --git a/internal/gtcpip/header/eth.go b/gtcpip/header/eth.go similarity index 99% rename from internal/gtcpip/header/eth.go rename to gtcpip/header/eth.go index 9d876ee6..613a72c6 100644 --- a/internal/gtcpip/header/eth.go +++ b/gtcpip/header/eth.go @@ -17,7 +17,7 @@ package header import ( "encoding/binary" - "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/gtcpip" ) const ( diff --git a/internal/gtcpip/header/icmpv4.go b/gtcpip/header/icmpv4.go similarity index 98% rename from internal/gtcpip/header/icmpv4.go rename to gtcpip/header/icmpv4.go index 580101c0..3b481041 100644 --- a/internal/gtcpip/header/icmpv4.go +++ b/gtcpip/header/icmpv4.go @@ -17,8 +17,8 @@ package header import ( "encoding/binary" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" ) // ICMPv4 represents an ICMPv4 header stored in a byte array. diff --git a/internal/gtcpip/header/icmpv6.go b/gtcpip/header/icmpv6.go similarity index 98% rename from internal/gtcpip/header/icmpv6.go rename to gtcpip/header/icmpv6.go index 520b4036..7eae97ab 100644 --- a/internal/gtcpip/header/icmpv6.go +++ b/gtcpip/header/icmpv6.go @@ -17,8 +17,8 @@ package header import ( "encoding/binary" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" ) // ICMPv6 represents an ICMPv6 header stored in a byte array. diff --git a/internal/gtcpip/header/interfaces.go b/gtcpip/header/interfaces.go similarity index 98% rename from internal/gtcpip/header/interfaces.go rename to gtcpip/header/interfaces.go index fc13100c..c0bb410c 100644 --- a/internal/gtcpip/header/interfaces.go +++ b/gtcpip/header/interfaces.go @@ -17,7 +17,7 @@ package header import ( "net/netip" - tcpip "github.com/sagernet/sing-tun/internal/gtcpip" + tcpip "github.com/sagernet/sing-tun/gtcpip" ) const ( diff --git a/internal/gtcpip/header/ipv4.go b/gtcpip/header/ipv4.go similarity index 99% rename from internal/gtcpip/header/ipv4.go rename to gtcpip/header/ipv4.go index ad06f38c..d5ffbf1d 100644 --- a/internal/gtcpip/header/ipv4.go +++ b/gtcpip/header/ipv4.go @@ -20,8 +20,8 @@ import ( "net/netip" "time" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" "github.com/sagernet/sing/common" ) diff --git a/internal/gtcpip/header/ipv6.go b/gtcpip/header/ipv6.go similarity index 99% rename from internal/gtcpip/header/ipv6.go rename to gtcpip/header/ipv6.go index 1a5a7a05..4de30737 100644 --- a/internal/gtcpip/header/ipv6.go +++ b/gtcpip/header/ipv6.go @@ -20,7 +20,7 @@ import ( "fmt" "net/netip" - "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/gtcpip" ) const ( diff --git a/internal/gtcpip/header/ipv6_extension_headers.go b/gtcpip/header/ipv6_extension_headers.go similarity index 99% rename from internal/gtcpip/header/ipv6_extension_headers.go rename to gtcpip/header/ipv6_extension_headers.go index 20064d8b..6c48b1bf 100644 --- a/internal/gtcpip/header/ipv6_extension_headers.go +++ b/gtcpip/header/ipv6_extension_headers.go @@ -20,7 +20,7 @@ import ( "fmt" "math" - "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/gtcpip" "github.com/sagernet/sing/common" ) diff --git a/internal/gtcpip/header/ipv6_fragment.go b/gtcpip/header/ipv6_fragment.go similarity index 99% rename from internal/gtcpip/header/ipv6_fragment.go rename to gtcpip/header/ipv6_fragment.go index 49aaca71..38f0b202 100644 --- a/internal/gtcpip/header/ipv6_fragment.go +++ b/gtcpip/header/ipv6_fragment.go @@ -17,7 +17,7 @@ package header import ( "encoding/binary" - "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/gtcpip" ) const ( diff --git a/internal/gtcpip/header/ndp_neighbor_advert.go b/gtcpip/header/ndp_neighbor_advert.go similarity index 98% rename from internal/gtcpip/header/ndp_neighbor_advert.go rename to gtcpip/header/ndp_neighbor_advert.go index 7a934cce..8f36765a 100644 --- a/internal/gtcpip/header/ndp_neighbor_advert.go +++ b/gtcpip/header/ndp_neighbor_advert.go @@ -14,7 +14,7 @@ package header -import "github.com/sagernet/sing-tun/internal/gtcpip" +import "github.com/sagernet/sing-tun/gtcpip" // NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will // only contain the body of an ICMPv6 packet. diff --git a/internal/gtcpip/header/ndp_neighbor_solicit.go b/gtcpip/header/ndp_neighbor_solicit.go similarity index 97% rename from internal/gtcpip/header/ndp_neighbor_solicit.go rename to gtcpip/header/ndp_neighbor_solicit.go index 61d61a8a..b4af20ce 100644 --- a/internal/gtcpip/header/ndp_neighbor_solicit.go +++ b/gtcpip/header/ndp_neighbor_solicit.go @@ -14,7 +14,7 @@ package header -import "github.com/sagernet/sing-tun/internal/gtcpip" +import "github.com/sagernet/sing-tun/gtcpip" // NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only // contain the body of an ICMPv6 packet. diff --git a/internal/gtcpip/header/ndp_options.go b/gtcpip/header/ndp_options.go similarity index 99% rename from internal/gtcpip/header/ndp_options.go rename to gtcpip/header/ndp_options.go index ba293398..365329a2 100644 --- a/internal/gtcpip/header/ndp_options.go +++ b/gtcpip/header/ndp_options.go @@ -23,7 +23,7 @@ import ( "math" "time" - "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/gtcpip" "github.com/sagernet/sing/common" ) diff --git a/internal/gtcpip/header/ndp_router_advert.go b/gtcpip/header/ndp_router_advert.go similarity index 100% rename from internal/gtcpip/header/ndp_router_advert.go rename to gtcpip/header/ndp_router_advert.go diff --git a/internal/gtcpip/header/ndp_router_solicit.go b/gtcpip/header/ndp_router_solicit.go similarity index 100% rename from internal/gtcpip/header/ndp_router_solicit.go rename to gtcpip/header/ndp_router_solicit.go diff --git a/internal/gtcpip/header/ndpoptionidentifier_string.go b/gtcpip/header/ndpoptionidentifier_string.go similarity index 100% rename from internal/gtcpip/header/ndpoptionidentifier_string.go rename to gtcpip/header/ndpoptionidentifier_string.go diff --git a/internal/gtcpip/header/netip.go b/gtcpip/header/netip.go similarity index 100% rename from internal/gtcpip/header/netip.go rename to gtcpip/header/netip.go diff --git a/internal/gtcpip/header/tcp.go b/gtcpip/header/tcp.go similarity index 99% rename from internal/gtcpip/header/tcp.go rename to gtcpip/header/tcp.go index 1b58df86..824b08c8 100644 --- a/internal/gtcpip/header/tcp.go +++ b/gtcpip/header/tcp.go @@ -17,9 +17,9 @@ package header import ( "encoding/binary" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - "github.com/sagernet/sing-tun/internal/gtcpip/seqnum" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/seqnum" "github.com/google/btree" ) diff --git a/internal/gtcpip/header/udp.go b/gtcpip/header/udp.go similarity index 98% rename from internal/gtcpip/header/udp.go rename to gtcpip/header/udp.go index a995a172..ce7708e1 100644 --- a/internal/gtcpip/header/udp.go +++ b/gtcpip/header/udp.go @@ -18,8 +18,8 @@ import ( "encoding/binary" "math" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" ) const ( diff --git a/internal/gtcpip/seqnum/seqnum.go b/gtcpip/seqnum/seqnum.go similarity index 100% rename from internal/gtcpip/seqnum/seqnum.go rename to gtcpip/seqnum/seqnum.go diff --git a/internal/gtcpip/tcpip.go b/gtcpip/tcpip.go similarity index 100% rename from internal/gtcpip/tcpip.go rename to gtcpip/tcpip.go diff --git a/internal/checksum_test/sum_bench_test.go b/internal/checksum_test/sum_bench_test.go index 35ee021c..2d07fff6 100644 --- a/internal/checksum_test/sum_bench_test.go +++ b/internal/checksum_test/sum_bench_test.go @@ -4,7 +4,7 @@ import ( "crypto/rand" "testing" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/checksum" "github.com/sagernet/sing-tun/internal/tschecksum" ) diff --git a/nfqueue_linux.go b/nfqueue_linux.go index baaefb54..9eed52fc 100644 --- a/nfqueue_linux.go +++ b/nfqueue_linux.go @@ -7,7 +7,7 @@ import ( "errors" "sync/atomic" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" diff --git a/ping/conn_interfaces.go b/ping/conn_interfaces.go new file mode 100644 index 00000000..96883a85 --- /dev/null +++ b/ping/conn_interfaces.go @@ -0,0 +1,11 @@ +package ping + +import "net/netip" + +type readMsgConn interface { + ReadMsg(b, oob []byte) (n, oobn int, addr netip.Addr, err error) +} + +type ttlSetter interface { + SetTTL(ttl uint8) +} diff --git a/ping/destination.go b/ping/destination.go index 60decb4c..8ac53491 100644 --- a/ping/destination.go +++ b/ping/destination.go @@ -9,25 +9,31 @@ import ( "sync" "time" - "github.com/sagernet/sing-tun" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + tun "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/gtcpip/header" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) var _ tun.DirectRouteDestination = (*Destination)(nil) type Destination struct { - conn *Conn - ctx context.Context - logger logger.ContextLogger - destination netip.Addr - routeContext tun.DirectRouteContext - timeout time.Duration - requestAccess sync.Mutex - requests map[pingRequest]time.Time + conn *Conn + errorListener *ErrorListener + ctx context.Context + logger logger.ContextLogger + destination netip.Addr + routeContext tun.DirectRouteContext + timeout time.Duration + requestAccess sync.Mutex + requests map[pingRequest]time.Time + originalSource common.TypedValue[netip.Addr] } type pingRequest struct { @@ -70,6 +76,15 @@ func ConnectDestination( timeout: timeout, requests: make(map[pingRequest]time.Time), } + + if errorListener := tryListenErrors(ctx, logger, controlFunc, destination); errorListener != nil { + d.errorListener = errorListener + go d.loopReadErrors() + logger.DebugContext(ctx, "ICMP error listener started") + } else { + logger.WarnContext(ctx, "ICMP error listener not available") + } + go d.loopRead() return d, nil } @@ -101,24 +116,24 @@ func (d *Destination) loopRead() { continue } icmpHdr := header.ICMPv4(ipHdr.Payload()) - if d.needFilter() { - if icmpHdr.Type() != header.ICMPv4EchoReply { - continue - } - var requestExists bool - request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()} - d.requestAccess.Lock() - _, loaded := d.requests[request] - if loaded { - requestExists = true - delete(d.requests, request) - } - d.requestAccess.Unlock() - if !requestExists { - continue + switch icmpHdr.Type() { + case header.ICMPv4EchoReply: + if d.needFilter() { + request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()} + d.requestAccess.Lock() + _, loaded := d.requests[request] + if loaded { + delete(d.requests, request) + } + d.requestAccess.Unlock() + if !loaded { + continue + } } + d.logger.TraceContext(d.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) + default: + continue } - d.logger.TraceContext(d.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) } else { ipHdr := header.IPv6(buffer.Bytes()) if !ipHdr.IsValid(buffer.Len()) { @@ -130,30 +145,196 @@ func (d *Destination) loopRead() { continue } icmpHdr := header.ICMPv6(ipHdr.Payload()) - if d.needFilter() { - if icmpHdr.Type() != header.ICMPv6EchoReply { + switch icmpHdr.Type() { + case header.ICMPv6EchoReply: + if d.needFilter() { + request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()} + d.requestAccess.Lock() + _, loaded := d.requests[request] + if loaded { + delete(d.requests, request) + } + d.requestAccess.Unlock() + if !loaded { + continue + } + } + d.logger.TraceContext(d.ctx, "read ICMPv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) + default: + continue + } + } + err = d.routeContext.WritePacket(buffer.Bytes()) + if err != nil { + d.logger.ErrorContext(d.ctx, E.Cause(err, "write ICMP echo reply")) + } + buffer.Release() + } +} + +func (d *Destination) loopReadErrors() { + defer d.errorListener.Close() + for { + buffer := buf.NewSize(1500) + if d.destination.Is6() { + // IPv6 raw sockets don't include the IPv6 header in received data; + // reserve space so we can prepend one later. + buffer.Advance(header.IPv6MinimumSize) + } + oob := make([]byte, 128) + n, oobn, addr, err := d.errorListener.ReadMsg(buffer.FreeBytes(), oob) + if err != nil { + buffer.Release() + if !E.IsClosed(err) { + d.logger.ErrorContext(d.ctx, E.Cause(err, "receive ICMP error")) + } + return + } + buffer.Truncate(n) + d.logger.DebugContext(d.ctx, "received raw ICMP packet from ", addr, " size ", n) + + if !d.destination.Is6() { + var ttl int + if oobn > 0 { + var cm ipv4.ControlMessage + err = cm.Parse(oob[:oobn]) + if err == nil { + ttl = cm.TTL + } + } + ipHdr := header.IPv4(buffer.Bytes()) + if !ipHdr.IsValid(n) { + continue + } + if ipHdr.PayloadLength() < header.ICMPv4MinimumSize { + continue + } + icmpHdr := header.ICMPv4(ipHdr.Payload()) + d.logger.DebugContext(d.ctx, "ICMPv4 error type ", uint8(icmpHdr.Type()), " from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr()) + switch icmpHdr.Type() { + case header.ICMPv4TimeExceeded, header.ICMPv4DstUnreachable: + if len(ipHdr.Payload()) < header.ICMPv4MinimumSize+header.IPv4MinimumSize+header.ICMPv4MinimumSize { + continue + } + innerIPHdr := header.IPv4(ipHdr.Payload()[header.ICMPv4MinimumSize:]) + if !innerIPHdr.IsValid(len(ipHdr.Payload()) - header.ICMPv4MinimumSize) { + continue + } + if innerIPHdr.PayloadLength() < header.ICMPv4MinimumSize { + continue + } + innerICMPHdr := header.ICMPv4(innerIPHdr.Payload()) + d.logger.DebugContext(d.ctx, "ICMPv4 error inner: src=", innerIPHdr.SourceAddr(), " dst=", innerIPHdr.DestinationAddr(), " ident=", innerICMPHdr.Ident(), " seq=", innerICMPHdr.Sequence()) + if d.needFilter() { + // The inner packet reflects the wire-level packet: source is the kernel's + // real IP (not the tunnel client IP) and ident is inverted (for privileged + // raw sockets). Invert ident back and match. + matchIdent := ^innerICMPHdr.Ident() + originalSource := d.originalSource.Load() + request := pingRequest{Source: originalSource, Destination: innerIPHdr.DestinationAddr(), Identifier: matchIdent, Sequence: innerICMPHdr.Sequence()} + d.requestAccess.Lock() + _, loaded := d.requests[request] + d.requestAccess.Unlock() + if !loaded { + d.logger.DebugContext(d.ctx, "ICMPv4 error: no matching request found") + continue + } + } + // Rewrite the error packet so it can be routed back through the tunnel: + // - outer destination → original client tunnel IP + // - inner source → original client tunnel IP + // - inner ident → original (pre-inversion) ident + originalSource := d.originalSource.Load() + if originalSource.IsValid() { + ipHdr.SetDestinationAddr(originalSource) + innerIPHdr.SetSourceAddr(originalSource) + innerIPHdr.SetChecksum(^innerIPHdr.CalculateChecksum()) + innerICMPHdr.SetIdent(^innerICMPHdr.Ident()) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + } else { + ipHdr.SetDestinationAddr(innerIPHdr.SourceAddr()) + } + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + d.logger.TraceContext(d.ctx, "read ICMPv4 error type ", uint8(icmpHdr.Type()), " from ", addr, " ttl ", ttl, " -> ", ipHdr.DestinationAddr()) + default: + continue + } + } else { + var hopLimit int + if oobn > 0 { + var cm *ipv6.ControlMessage + cm, err = parseIPv6ControlMessage(oob[:oobn]) + if err == nil && cm != nil { + hopLimit = cm.HopLimit + } + } + // IPv6 raw sockets return ICMPv6 payload only (no IPv6 header) + if n < header.ICMPv6MinimumSize { + continue + } + icmpHdr := header.ICMPv6(buffer.Bytes()) + d.logger.DebugContext(d.ctx, "ICMPv6 error type ", uint8(icmpHdr.Type()), " from ", addr) + switch icmpHdr.Type() { + case header.ICMPv6TimeExceeded, header.ICMPv6DstUnreachable: + if n < header.ICMPv6MinimumSize+header.IPv6MinimumSize+header.ICMPv6MinimumSize { continue } - var requestExists bool - request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()} - d.requestAccess.Lock() - _, loaded := d.requests[request] - if loaded { - requestExists = true - delete(d.requests, request) + innerIPHdr := header.IPv6(buffer.Bytes()[header.ICMPv6MinimumSize:]) + if !innerIPHdr.IsValid(n - header.ICMPv6MinimumSize) { + continue } - d.requestAccess.Unlock() - if !requestExists { + if innerIPHdr.PayloadLength() < header.ICMPv6MinimumSize { continue } + innerICMPHdr := header.ICMPv6(innerIPHdr.Payload()) + if d.needFilter() { + matchIdent := ^innerICMPHdr.Ident() + originalSource := d.originalSource.Load() + request := pingRequest{Source: originalSource, Destination: innerIPHdr.DestinationAddr(), Identifier: matchIdent, Sequence: innerICMPHdr.Sequence()} + d.requestAccess.Lock() + _, loaded := d.requests[request] + d.requestAccess.Unlock() + if !loaded { + continue + } + } + dstAddr := addr + originalSource := d.originalSource.Load() + if originalSource.IsValid() { + dstAddr = originalSource + innerIPHdr.SetSourceAddr(originalSource) + innerICMPHdr.SetIdent(^innerICMPHdr.Ident()) + innerICMPHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: innerICMPHdr, + Src: innerIPHdr.SourceAddressSlice(), + Dst: innerIPHdr.DestinationAddressSlice(), + })) + } else { + dstAddr = innerIPHdr.SourceAddr() + } + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: addr.AsSlice(), + Dst: dstAddr.AsSlice(), + })) + // Prepend synthesized IPv6 header + ipHdr := header.IPv6(buffer.ExtendHeader(header.IPv6MinimumSize)) + ipHdr.Encode(&header.IPv6Fields{ + PayloadLength: uint16(n), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: uint8(hopLimit), + SrcAddr: addr, + DstAddr: dstAddr, + }) + d.logger.TraceContext(d.ctx, "read ICMPv6 error type ", uint8(icmpHdr.Type()), " from ", addr, " hoplimit ", hopLimit, " -> ", dstAddr) + default: + continue } - d.logger.TraceContext(d.ctx, "read ICMPv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) } err = d.routeContext.WritePacket(buffer.Bytes()) if err != nil { - d.logger.ErrorContext(d.ctx, E.Cause(err, "write ICMP echo reply")) + d.logger.ErrorContext(d.ctx, E.Cause(err, "write ICMP error")) } - buffer.Release() } } @@ -167,9 +348,14 @@ func (d *Destination) WritePacket(packet *buf.Buffer) error { return E.New("invalid ICMPv4 header") } icmpHdr := header.ICMPv4(ipHdr.Payload()) + d.originalSource.Store(ipHdr.SourceAddr()) if d.needFilter() { d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}) } + ttl := ipHdr.TTL() + if ttl > 0 { + _ = d.conn.SetTTL(ttl) + } d.logger.TraceContext(d.ctx, "write ICMPv4 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) } else { ipHdr := header.IPv6(packet.Bytes()) @@ -180,9 +366,14 @@ func (d *Destination) WritePacket(packet *buf.Buffer) error { return E.New("invalid ICMPv6 header") } icmpHdr := header.ICMPv6(ipHdr.Payload()) + d.originalSource.Store(ipHdr.SourceAddr()) if d.needFilter() { d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}) } + hopLimit := ipHdr.HopLimit() + if hopLimit > 0 { + _ = d.conn.SetTTL(hopLimit) + } d.logger.TraceContext(d.ctx, "write ICMPv6 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) } return d.conn.WriteIP(packet) @@ -216,6 +407,9 @@ func (d *Destination) registerRequest(request pingRequest) { } func (d *Destination) Close() error { + if d.errorListener != nil { + _ = d.errorListener.Close() + } return d.conn.Close() } diff --git a/ping/destination_gvisor.go b/ping/destination_gvisor.go index 0bac8036..4d988960 100644 --- a/ping/destination_gvisor.go +++ b/ping/destination_gvisor.go @@ -13,7 +13,7 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/transport" "github.com/sagernet/gvisor/pkg/waiter" - "github.com/sagernet/sing-tun" + tun "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" @@ -23,12 +23,13 @@ import ( var _ tun.DirectRouteDestination = (*GVisorDestination)(nil) type GVisorDestination struct { - ctx context.Context - logger logger.ContextLogger - endpoint tcpip.Endpoint - conn *gonet.TCPConn - rewriter *SourceRewriter - timeout time.Duration + ctx context.Context + logger logger.ContextLogger + endpoint tcpip.Endpoint + conn *gonet.TCPConn + rewriter *SourceRewriter + errorListener *ErrorListener + timeout time.Duration } func ConnectGVisor( @@ -86,6 +87,14 @@ func ConnectGVisor( rewriter: rewriter, timeout: timeout, } + + // Try to create an error listener for receiving ICMP error messages + if errorListener := tryListenErrors(ctx, logger, nil, destinationAddress); errorListener != nil { + destination.errorListener = errorListener + go destination.loopReadErrors() + logger.DebugContext(ctx, "ICMP error listener started") + } + go destination.loopRead() return destination, nil } @@ -120,7 +129,37 @@ func (d *GVisorDestination) WritePacket(packet *buf.Buffer) error { return common.Error(d.conn.Write(packet.Bytes())) } +func (d *GVisorDestination) loopReadErrors() { + defer d.errorListener.Close() + for { + buffer := buf.NewSize(1500) + oob := make([]byte, 128) + n, _, _, err := d.errorListener.ReadMsg(buffer.FreeBytes(), oob) + if err != nil { + buffer.Release() + if !E.IsClosed(err) { + d.logger.ErrorContext(d.ctx, E.Cause(err, "receive ICMP error")) + } + return + } + buffer.Truncate(n) + + // The error listener receives raw IP packets, but we ignore the + // TTL/hop limit from cmsg since gvisor handles that separately. + // Pass the raw IP packet through the SourceRewriter which handles + // address rewriting for both Echo Replies and ICMP errors. + _, err = d.rewriter.WriteBack(buffer.Bytes()) + if err != nil { + d.logger.ErrorContext(d.ctx, E.Cause(err, "write ICMP error")) + } + buffer.Release() + } +} + func (d *GVisorDestination) Close() error { + if d.errorListener != nil { + _ = d.errorListener.Close() + } return d.conn.Close() } diff --git a/ping/destination_rewriter.go b/ping/destination_rewriter.go index a61e1556..26bb3551 100644 --- a/ping/destination_rewriter.go +++ b/ping/destination_rewriter.go @@ -4,7 +4,7 @@ import ( "net/netip" "github.com/sagernet/sing-tun" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common/buf" ) diff --git a/ping/error_listener.go b/ping/error_listener.go new file mode 100644 index 00000000..4736b21b --- /dev/null +++ b/ping/error_listener.go @@ -0,0 +1,29 @@ +package ping + +import ( + "context" + "net/netip" + + "github.com/sagernet/sing/common/control" + "github.com/sagernet/sing/common/logger" +) + +// tryListenErrors tries to create an ICMP error listener, first with +// privileged sockets, falling back to unprivileged if needed. +// Returns nil (no error) if neither mode succeeds. +func tryListenErrors( + ctx context.Context, + logger logger.ContextLogger, + controlFunc control.Func, + destination netip.Addr, +) *ErrorListener { + errorListener, err := listenErrors(true, controlFunc, destination) + if err != nil { + logger.DebugContext(ctx, "privileged error listener failed: ", err) + errorListener, err = listenErrors(false, controlFunc, destination) + if err != nil { + logger.DebugContext(ctx, "unprivileged error listener failed: ", err) + } + } + return errorListener +} diff --git a/ping/ping.go b/ping/ping.go index eab977b7..9d89f37f 100644 --- a/ping/ping.go +++ b/ping/ping.go @@ -9,7 +9,7 @@ import ( "sync/atomic" "time" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" @@ -70,8 +70,8 @@ func (c *Conn) connect(controlFunc control.Func) (err error) { } return } - } else if unprivilegedConn, isUnprivilegedConn := c.conn.(*UnprivilegedConn); isUnprivilegedConn { - c.readMsg = unprivilegedConn.ReadMsg + } else if rmConn, ok := c.conn.(readMsgConn); ok { + c.readMsg = rmConn.ReadMsg } else { return E.New("unsupported conn type: ", reflect.TypeOf(c.conn)) } diff --git a/ping/ping_test.go b/ping/ping_test.go index 7ec291a3..163dbd4c 100644 --- a/ping/ping_test.go +++ b/ping/ping_test.go @@ -9,7 +9,7 @@ import ( "time" "github.com/sagernet/gvisor/pkg/rand" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing-tun/ping" "github.com/sagernet/sing/common/buf" diff --git a/ping/socket_linux_unprivileged.go b/ping/socket_linux_unprivileged.go index 79fd682d..8a640d83 100644 --- a/ping/socket_linux_unprivileged.go +++ b/ping/socket_linux_unprivileged.go @@ -8,7 +8,7 @@ import ( "sync" "time" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" @@ -23,6 +23,7 @@ type UnprivilegedConn struct { destination netip.Addr receiveChan chan *unprivilegedResponse readDeadline pipe.Deadline + ttl uint8 mappingAccess sync.Mutex mapping map[uint16]net.Conn } @@ -105,7 +106,14 @@ func (c *UnprivilegedConn) Write(b []byte) (n int, err error) { go c.fetchResponse(conn.(*net.UDPConn), identifier) c.mapping[identifier] = conn } + ttl := c.ttl c.mappingAccess.Unlock() + if ttl > 0 { + err = c.setTTL(conn, ttl) + if err != nil { + return + } + } n, err = conn.Write(b) if err != nil { c.removeConn(conn.(*net.UDPConn), identifier) @@ -167,6 +175,12 @@ func (c *UnprivilegedConn) Close() error { return nil } +func (c *UnprivilegedConn) SetTTL(ttl uint8) { + c.mappingAccess.Lock() + c.ttl = ttl + c.mappingAccess.Unlock() +} + func (c *UnprivilegedConn) LocalAddr() net.Addr { return M.Socksaddr{} } diff --git a/ping/socket_unix.go b/ping/socket_unix.go index 1eec10a5..c7ca50e9 100644 --- a/ping/socket_unix.go +++ b/ping/socket_unix.go @@ -8,6 +8,7 @@ import ( "os" "runtime" "syscall" + "time" "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/control" @@ -108,3 +109,80 @@ func connect(privileged bool, controlFunc control.Func, destination netip.Addr) return conn, nil } } + +// ErrorListener listens for ICMP error messages (Time Exceeded, Destination Unreachable) +// on an unconnected raw ICMP socket. +type ErrorListener struct { + conn *net.IPConn + destination netip.Addr +} + +func listenErrors(privileged bool, controlFunc control.Func, destination netip.Addr) (*ErrorListener, error) { + var ( + fd int + err error + ) + if destination.Is4() { + fd, err = unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_ICMP) + } else { + fd, err = unix.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_ICMPV6) + } + if err != nil { + return nil, E.Cause(err, "socket()") + } + file := os.NewFile(uintptr(fd), "icmp-error-listener") + defer file.Close() + + if destination.Is4() { + err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_RECVTTL, 1) + if err != nil { + return nil, E.Cause(err, "setsockopt IP_RECVTTL") + } + } else { + err = unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_RECVHOPLIMIT, 1) + if err != nil { + return nil, E.Cause(err, "setsockopt IPV6_RECVHOPLIMIT") + } + err = unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1) + if err != nil { + return nil, E.Cause(err, "setsockopt IPV6_RECVTCLASS") + } + } + + var bindAddress netip.Addr + if !destination.Is6() { + bindAddress = netip.AddrFrom4([4]byte{}) + } else { + bindAddress = netip.AddrFrom16([16]byte{}) + } + err = unix.Bind(fd, M.AddrPortToSockaddr(netip.AddrPortFrom(bindAddress, 0))) + if err != nil { + return nil, E.Cause(err, "bind()") + } + + ipConn, err := net.FileConn(file) + if err != nil { + return nil, err + } + return &ErrorListener{ + conn: ipConn.(*net.IPConn), + destination: destination, + }, nil +} + +func (l *ErrorListener) ReadMsg(b, oob []byte) (n, oobn int, addr netip.Addr, err error) { + var ipAddr *net.IPAddr + n, oobn, _, ipAddr, err = l.conn.ReadMsgIP(b, oob) + if err == nil { + addr = M.AddrFromNet(ipAddr) + } + return +} + +func (l *ErrorListener) Close() error { + return l.conn.Close() +} + +func (l *ErrorListener) SetReadDeadline(t time.Time) error { + return l.conn.SetReadDeadline(t) +} diff --git a/ping/socket_windows_other.go b/ping/socket_windows_other.go new file mode 100644 index 00000000..c4d3870a --- /dev/null +++ b/ping/socket_windows_other.go @@ -0,0 +1,29 @@ +//go:build !unix + +package ping + +import ( + "net/netip" + "time" + + "github.com/sagernet/sing/common/control" +) + +type ErrorListener struct{} + +func listenErrors(privileged bool, controlFunc control.Func, destination netip.Addr) (*ErrorListener, error) { + // ICMP error listening not supported on non-Unix platforms + return nil, nil +} + +func (l *ErrorListener) ReadMsg(b, oob []byte) (n, oobn int, addr netip.Addr, err error) { + return 0, 0, netip.Addr{}, nil +} + +func (l *ErrorListener) Close() error { + return nil +} + +func (l *ErrorListener) SetReadDeadline(t time.Time) error { + return nil +} diff --git a/ping/source_rewriter.go b/ping/source_rewriter.go index 480c6a78..8360a3d3 100644 --- a/ping/source_rewriter.go +++ b/ping/source_rewriter.go @@ -5,17 +5,29 @@ import ( "net/netip" "sync" - "github.com/sagernet/sing-tun" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + tun "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common/logger" ) +// sourceKey identifies an outgoing session by protocol, port (or ICMP ident), +// and destination, for reverse lookup when rewriting ICMP errors or echo replies. +type sourceKey struct { + Protocol uint8 + Port uint16 // ICMP ident, or TCP/UDP source port + Destination netip.Addr +} + type SourceRewriter struct { - ctx context.Context - logger logger.ContextLogger - access sync.RWMutex - sessions map[tun.DirectRouteSession]tun.DirectRouteContext - sourceAddress map[uint16]netip.Addr + ctx context.Context + logger logger.ContextLogger + access sync.RWMutex + // sessions tracks active DirectRoute sessions for routing writeback packets. + sessions map[tun.DirectRouteSession]tun.DirectRouteContext + // sourceAddress maps {protocol, port/ident, destination} → original client + // address, used to rewrite ICMP echo replies and error destinations back + // to the TUN client. Covers ICMP, UDP, and TCP inner packets. + sourceAddress map[sourceKey]netip.Addr inet4Address netip.Addr inet6Address netip.Addr } @@ -25,7 +37,7 @@ func NewSourceRewriter(ctx context.Context, logger logger.ContextLogger, inet4Ad ctx: ctx, logger: logger, sessions: make(map[tun.DirectRouteSession]tun.DirectRouteContext), - sourceAddress: make(map[uint16]netip.Addr), + sourceAddress: make(map[sourceKey]netip.Addr), inet4Address: inet4Address, inet6Address: inet6Address, } @@ -65,7 +77,7 @@ func (m *SourceRewriter) RewritePacket(packet []byte) { case header.ICMPv4ProtocolNumber: icmpHdr := header.ICMPv4(ipHdr.Payload()) m.access.Lock() - m.sourceAddress[icmpHdr.Ident()] = sourceAddr + m.sourceAddress[sourceKey{Protocol: uint8(header.ICMPv4ProtocolNumber), Port: icmpHdr.Ident(), Destination: ipHdr.DestinationAddr()}] = sourceAddr m.access.Unlock() m.logger.TraceContext(m.ctx, "write ICMPv4 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) case header.ICMPv6ProtocolNumber: @@ -76,10 +88,61 @@ func (m *SourceRewriter) RewritePacket(packet []byte) { Dst: ipHdr.DestinationAddressSlice(), })) m.access.Lock() - m.sourceAddress[icmpHdr.Ident()] = sourceAddr + m.sourceAddress[sourceKey{Protocol: uint8(header.ICMPv6ProtocolNumber), Port: icmpHdr.Ident(), Destination: ipHdr.DestinationAddr()}] = sourceAddr m.access.Unlock() m.logger.TraceContext(m.ctx, "write ICMPv6 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) + case header.UDPProtocolNumber: + if len(ipHdr.Payload()) >= header.UDPMinimumSize { + udpHdr := header.UDP(ipHdr.Payload()) + m.access.Lock() + m.sourceAddress[sourceKey{Protocol: uint8(header.UDPProtocolNumber), Port: udpHdr.SourcePort(), Destination: ipHdr.DestinationAddr()}] = sourceAddr + m.access.Unlock() + } + case header.TCPProtocolNumber: + if len(ipHdr.Payload()) >= header.TCPMinimumSize { + tcpHdr := header.TCP(ipHdr.Payload()) + m.access.Lock() + m.sourceAddress[sourceKey{Protocol: uint8(header.TCPProtocolNumber), Port: tcpHdr.SourcePort(), Destination: ipHdr.DestinationAddr()}] = sourceAddr + m.access.Unlock() + } + } +} + +// resolveInnerSource looks up the original client address from the inner +// transport header of an ICMP error. Returns the source address and true +// if found. +func (m *SourceRewriter) resolveInnerSource( + innerProto uint8, + innerPayloadLen uint16, + innerPayload []byte, + innerDst netip.Addr, +) (netip.Addr, bool) { + var minSize int + switch innerProto { + case uint8(header.ICMPv4ProtocolNumber), uint8(header.ICMPv6ProtocolNumber): + minSize = header.ICMPv4MinimumSize + case uint8(header.UDPProtocolNumber): + minSize = header.UDPMinimumSize + case uint8(header.TCPProtocolNumber): + minSize = header.TCPMinimumSize + default: + return netip.Addr{}, false + } + if innerPayloadLen < uint16(minSize) { + return netip.Addr{}, false + } + var port uint16 + switch innerProto { + case uint8(header.ICMPv4ProtocolNumber), uint8(header.ICMPv6ProtocolNumber): + port = header.ICMPv4(innerPayload).Ident() // offset 4 + default: + port = header.UDP(innerPayload).SourcePort() // offset 0 (same for TCP) } + key := sourceKey{Protocol: innerProto, Port: port, Destination: innerDst} + m.access.RLock() + source, loaded := m.sourceAddress[key] + m.access.RUnlock() + return source, loaded } func (m *SourceRewriter) WriteBack(packet []byte) (bool, error) { @@ -95,34 +158,108 @@ func (m *SourceRewriter) WriteBack(packet []byte) (bool, error) { default: return false, nil } + var echoKey sourceKey + var isEchoReply bool + var resolvedSource netip.Addr switch ipHdr.TransportProtocol() { case header.ICMPv4ProtocolNumber: icmpHdr := header.ICMPv4(ipHdr.Payload()) - m.access.Lock() - ident := icmpHdr.Ident() - source, loaded := m.sourceAddress[ident] - if !loaded { - m.access.Unlock() + switch icmpHdr.Type() { + case header.ICMPv4EchoReply: + echoKey = sourceKey{Protocol: uint8(header.ICMPv4ProtocolNumber), Port: icmpHdr.Ident(), Destination: ipHdr.SourceAddr()} + isEchoReply = true + case header.ICMPv4TimeExceeded, header.ICMPv4DstUnreachable: + if len(ipHdr.Payload()) < header.ICMPv4MinimumSize+header.IPv4MinimumSize { + return false, nil + } + innerIPHdr := header.IPv4(ipHdr.Payload()[header.ICMPv4MinimumSize:]) + if !innerIPHdr.IsValid(len(ipHdr.Payload()) - header.ICMPv4MinimumSize) { + return false, nil + } + routeSession.Destination = innerIPHdr.DestinationAddr() + source, loaded := m.resolveInnerSource( + uint8(innerIPHdr.TransportProtocol()), + innerIPHdr.PayloadLength(), + innerIPHdr.Payload(), + innerIPHdr.DestinationAddr(), + ) + if !loaded { + if innerIPHdr.TransportProtocol() == header.ICMPv4ProtocolNumber { + // ICMP echo errors are optional — don't abort + break + } + return false, nil + } + innerIPHdr.SetSourceAddr(source) + innerIPHdr.SetChecksum(^innerIPHdr.CalculateChecksum()) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + resolvedSource = source + default: return false, nil } - delete(m.sourceAddress, icmpHdr.Ident()) - m.access.Unlock() - routeSession.Source = source case header.ICMPv6ProtocolNumber: icmpHdr := header.ICMPv6(ipHdr.Payload()) - m.access.Lock() - ident := icmpHdr.Ident() - source, loaded := m.sourceAddress[ident] - if !loaded { - m.access.Unlock() + switch icmpHdr.Type() { + case header.ICMPv6EchoReply: + echoKey = sourceKey{Protocol: uint8(header.ICMPv6ProtocolNumber), Port: icmpHdr.Ident(), Destination: ipHdr.SourceAddr()} + isEchoReply = true + case header.ICMPv6TimeExceeded, header.ICMPv6DstUnreachable: + if len(ipHdr.Payload()) < header.ICMPv6MinimumSize+header.IPv6MinimumSize { + return false, nil + } + innerIPHdr := header.IPv6(ipHdr.Payload()[header.ICMPv6MinimumSize:]) + if !innerIPHdr.IsValid(len(ipHdr.Payload()) - header.ICMPv6MinimumSize) { + return false, nil + } + routeSession.Destination = innerIPHdr.DestinationAddr() + source, loaded := m.resolveInnerSource( + uint8(innerIPHdr.TransportProtocol()), + innerIPHdr.PayloadLength(), + innerIPHdr.Payload(), + innerIPHdr.DestinationAddr(), + ) + if !loaded { + if innerIPHdr.TransportProtocol() == header.ICMPv6ProtocolNumber { + break + } + return false, nil + } + innerIPHdr.SetSourceAddr(source) + if innerIPHdr.TransportProtocol() == header.ICMPv6ProtocolNumber { + innerICMPHdr := header.ICMPv6(innerIPHdr.Payload()) + innerICMPHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: innerICMPHdr, + Src: innerIPHdr.SourceAddressSlice(), + Dst: innerIPHdr.DestinationAddressSlice(), + })) + } + resolvedSource = source + default: return false, nil } - delete(m.sourceAddress, icmpHdr.Ident()) - m.access.Unlock() - routeSession.Source = source default: return false, nil } + var source netip.Addr + if resolvedSource.IsValid() { + source = resolvedSource + } else { + var loaded bool + m.access.RLock() + source, loaded = m.sourceAddress[echoKey] + m.access.RUnlock() + if !loaded { + return false, nil + } + // Only delete the mapping for EchoReply, not for error messages + // (multiple errors may arrive for the same ident, e.g. traceroute) + if isEchoReply { + m.access.Lock() + delete(m.sourceAddress, echoKey) + m.access.Unlock() + } + } + routeSession.Source = source m.access.RLock() context, loaded := m.sessions[routeSession] m.access.RUnlock() @@ -136,7 +273,12 @@ func (m *SourceRewriter) WriteBack(packet []byte) (bool, error) { switch ipHdr.TransportProtocol() { case header.ICMPv4ProtocolNumber: icmpHdr := header.ICMPv4(ipHdr.Payload()) - m.logger.TraceContext(m.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) + switch icmpHdr.Type() { + case header.ICMPv4EchoReply: + m.logger.TraceContext(m.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) + case header.ICMPv4TimeExceeded, header.ICMPv4DstUnreachable: + m.logger.TraceContext(m.ctx, "read ICMPv4 error type ", uint8(icmpHdr.Type()), " from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr()) + } case header.ICMPv6ProtocolNumber: icmpHdr := header.ICMPv6(ipHdr.Payload()) icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ @@ -144,7 +286,12 @@ func (m *SourceRewriter) WriteBack(packet []byte) (bool, error) { Src: ipHdr.SourceAddressSlice(), Dst: ipHdr.DestinationAddressSlice(), })) - m.logger.TraceContext(m.ctx, "read ICMPv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) + switch icmpHdr.Type() { + case header.ICMPv6EchoReply: + m.logger.TraceContext(m.ctx, "read ICMPv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence()) + case header.ICMPv6TimeExceeded, header.ICMPv6DstUnreachable: + m.logger.TraceContext(m.ctx, "read ICMPv6 error type ", uint8(icmpHdr.Type()), " from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr()) + } } return true, context.WritePacket(packet) } diff --git a/ping/source_rewriter_internal_test.go b/ping/source_rewriter_internal_test.go new file mode 100644 index 00000000..c3aad28f --- /dev/null +++ b/ping/source_rewriter_internal_test.go @@ -0,0 +1,76 @@ +package ping + +import ( + "context" + "net/netip" + "testing" + + "github.com/sagernet/sing/common/logger" + + "github.com/stretchr/testify/require" +) + +func TestSourceRewriterTimeoutCleanup(t *testing.T) { + t.Parallel() + + rewriter := NewSourceRewriter( + context.Background(), + logger.NOP(), + netip.MustParseAddr("10.0.0.1"), + netip.Addr{}, + ) + + addr := netip.MustParseAddr("192.168.1.1") + dest := netip.MustParseAddr("1.1.1.1") + + // Insert an entry + key := sourceKey{Protocol: 1, Port: 100, Destination: dest} + rewriter.access.Lock() + rewriter.sourceAddress[key] = addr + rewriter.access.Unlock() + + // The entry should be found + rewriter.access.RLock() + found, ok := rewriter.sourceAddress[key] + rewriter.access.RUnlock() + require.True(t, ok, "entry should be found") + require.Equal(t, addr, found) + + // Delete and verify gone + rewriter.access.Lock() + delete(rewriter.sourceAddress, key) + rewriter.access.Unlock() + + rewriter.access.RLock() + _, ok = rewriter.sourceAddress[key] + rewriter.access.RUnlock() + require.False(t, ok, "deleted entry should not be found") +} + +func TestSourceRewriterCapacityLimit(t *testing.T) { + t.Parallel() + + rewriter := NewSourceRewriter( + context.Background(), + logger.NOP(), + netip.MustParseAddr("10.0.0.1"), + netip.Addr{}, + ) + + addr := netip.MustParseAddr("192.168.1.1") + dest := netip.MustParseAddr("1.1.1.1") + + // Insert multiple entries with different keys + for i := uint16(0); i < 100; i++ { + rewriter.access.Lock() + rewriter.sourceAddress[sourceKey{Protocol: 1, Port: i, Destination: dest}] = addr + rewriter.access.Unlock() + } + require.Equal(t, 100, len(rewriter.sourceAddress)) + + // The newest entry should exist + rewriter.access.RLock() + _, ok := rewriter.sourceAddress[sourceKey{Protocol: 1, Port: 99, Destination: dest}] + rewriter.access.RUnlock() + require.True(t, ok, "entry should exist") +} diff --git a/ping/source_rewriter_test.go b/ping/source_rewriter_test.go new file mode 100644 index 00000000..11ae3c2d --- /dev/null +++ b/ping/source_rewriter_test.go @@ -0,0 +1,762 @@ +package ping_test + +import ( + "context" + "encoding/binary" + "net/netip" + "testing" + + tun "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/gtcpip/header" + "github.com/sagernet/sing-tun/ping" + "github.com/sagernet/sing/common/logger" + + "github.com/stretchr/testify/require" +) + +// mockDirectRouteContext captures packets written back for verification. +type mockDirectRouteContext struct { + packets [][]byte +} + +func (m *mockDirectRouteContext) WritePacket(packet []byte) error { + copied := make([]byte, len(packet)) + copy(copied, packet) + m.packets = append(m.packets, copied) + return nil +} + +var _ tun.DirectRouteContext = (*mockDirectRouteContext)(nil) + +// --- SourceRewriter Tests --- + +func TestSourceRewriterEchoRequest(t *testing.T) { + t.Parallel() + + serverAddr := netip.MustParseAddr("192.168.10.254") + clientAddr := netip.MustParseAddr("10.0.0.2") + destAddr := netip.MustParseAddr("1.1.1.1") + + rewriter := ping.NewSourceRewriter( + context.Background(), + logger.NOP(), + serverAddr, + netip.Addr{}, + ) + + // Build an ICMP echo request from client to destination + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + packet := make([]byte, totalLen) + ipHdr := header.IPv4(packet) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + TTL: 64, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: clientAddr, + DstAddr: destAddr, + }) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + + icmpHdr := header.ICMPv4(packet[header.IPv4MinimumSize:]) + icmpHdr.SetType(header.ICMPv4Echo) + icmpHdr.SetIdent(1234) + icmpHdr.SetSequence(1) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + + rewriter.RewritePacket(packet) + + // After rewrite, source should be serverAddr + ipHdr = header.IPv4(packet) + require.Equal(t, serverAddr, ipHdr.SourceAddr(), + "source should be rewritten to server bind address") + require.Equal(t, destAddr, ipHdr.DestinationAddr(), + "destination should remain unchanged") +} + +func TestSourceRewriterEchoReply(t *testing.T) { + t.Parallel() + + serverAddr := netip.MustParseAddr("192.168.10.254") + clientAddr := netip.MustParseAddr("10.0.0.2") + destAddr := netip.MustParseAddr("1.1.1.1") + + ctx := &mockDirectRouteContext{} + rewriter := ping.NewSourceRewriter( + context.Background(), + logger.NOP(), + serverAddr, + netip.Addr{}, + ) + + session := tun.DirectRouteSession{ + Source: clientAddr, + Destination: destAddr, + } + rewriter.CreateSession(session, ctx) + + // First, simulate outgoing request to register the ident mapping + reqPacket := buildICMPv4EchoRequest(t, clientAddr, destAddr, 1234, 1) + rewriter.RewritePacket(reqPacket) + + // Now simulate incoming reply + replyPacket := buildICMPv4EchoReply(t, destAddr, serverAddr, 1234, 1) + ok, err := rewriter.WriteBack(replyPacket) + require.NoError(t, err) + require.True(t, ok, "WriteBack should succeed for matching echo reply") + + require.Len(t, ctx.packets, 1, "one packet should be forwarded") + rewrittenIP := header.IPv4(ctx.packets[0]) + require.Equal(t, clientAddr, rewrittenIP.DestinationAddr(), + "reply destination should be rewritten to client tunnel IP") +} + +func TestSourceRewriterICMPv4TimeExceeded(t *testing.T) { + t.Parallel() + + // ICMP errors use the inner packet's destination for session lookup, + // so errors from both intermediate routers and the destination itself + // should match the session correctly. + + serverAddr := netip.MustParseAddr("192.168.10.254") + clientAddr := netip.MustParseAddr("10.0.0.2") + destAddr := netip.MustParseAddr("1.1.1.1") + routerAddr := netip.MustParseAddr("192.168.1.1") // intermediate router + + ctx := &mockDirectRouteContext{} + rewriter := ping.NewSourceRewriter( + context.Background(), + logger.NOP(), + serverAddr, + netip.Addr{}, + ) + + session := tun.DirectRouteSession{ + Source: clientAddr, + Destination: destAddr, + } + rewriter.CreateSession(session, ctx) + + // Register ident mapping via outgoing echo request + reqPacket := buildICMPv4EchoRequest(t, clientAddr, destAddr, 5678, 1) + rewriter.RewritePacket(reqPacket) + + // Error from intermediate router: should match via inner destination + errorPacket := buildICMPv4ErrorWithInnerICMP(t, + routerAddr, serverAddr, + serverAddr, destAddr, + 5678, 1, + header.ICMPv4TimeExceeded, + ) + + ok, err := rewriter.WriteBack(errorPacket) + require.NoError(t, err) + require.True(t, ok, "WriteBack should succeed for error from intermediate router") + + require.Len(t, ctx.packets, 1) + rewrittenIP := header.IPv4(ctx.packets[0]) + require.Equal(t, clientAddr, rewrittenIP.DestinationAddr()) + require.Equal(t, routerAddr, rewrittenIP.SourceAddr()) + + // Error from the destination itself should also match + errorFromDest := buildICMPv4ErrorWithInnerICMP(t, + destAddr, serverAddr, + serverAddr, destAddr, + 5678, 1, + header.ICMPv4DstUnreachable, + ) + ok, err = rewriter.WriteBack(errorFromDest) + require.NoError(t, err) + require.True(t, ok, "WriteBack should succeed for error from destination itself") + + require.Len(t, ctx.packets, 2) + rewrittenIP = header.IPv4(ctx.packets[1]) + require.Equal(t, clientAddr, rewrittenIP.DestinationAddr()) + + // Verify inner IP source was rewritten + icmpPayload := rewrittenIP.Payload() + innerIP := header.IPv4(icmpPayload[header.ICMPv4MinimumSize:]) + require.Equal(t, clientAddr, innerIP.SourceAddr(), + "inner IP source should be rewritten to client tunnel IP") +} + +func TestSourceRewriterICMPv4ErrorPreservesIdent(t *testing.T) { + t.Parallel() + + // After processing an error, the ident mapping should NOT be deleted + // (unlike echo reply), so multiple errors can arrive for the same ident. + // In the SourceRewriter path, only errors from the destination match, + // so we use destAddr as the error source. + + serverAddr := netip.MustParseAddr("192.168.10.254") + clientAddr := netip.MustParseAddr("10.0.0.2") + destAddr := netip.MustParseAddr("1.1.1.1") + + ctx := &mockDirectRouteContext{} + rewriter := ping.NewSourceRewriter( + context.Background(), + logger.NOP(), + serverAddr, + netip.Addr{}, + ) + + session := tun.DirectRouteSession{ + Source: clientAddr, + Destination: destAddr, + } + rewriter.CreateSession(session, ctx) + + // Register ident mapping + reqPacket := buildICMPv4EchoRequest(t, clientAddr, destAddr, 9999, 1) + rewriter.RewritePacket(reqPacket) + + // First error from destination (DstUnreachable) + errorPacket1 := buildICMPv4ErrorWithInnerICMP(t, + destAddr, serverAddr, + serverAddr, destAddr, + 9999, 1, + header.ICMPv4DstUnreachable, + ) + ok, err := rewriter.WriteBack(errorPacket1) + require.NoError(t, err) + require.True(t, ok) + + // Second error for same ident (ident mapping should still exist) + errorPacket2 := buildICMPv4ErrorWithInnerICMP(t, + destAddr, serverAddr, + serverAddr, destAddr, + 9999, 1, + header.ICMPv4DstUnreachable, + ) + ok, err = rewriter.WriteBack(errorPacket2) + require.NoError(t, err) + require.True(t, ok, "second error should still match (ident not deleted for errors)") + + require.Len(t, ctx.packets, 2, "both errors should be forwarded") +} + +func TestSourceRewriterUnknownIdent(t *testing.T) { + t.Parallel() + + serverAddr := netip.MustParseAddr("192.168.10.254") + + rewriter := ping.NewSourceRewriter( + context.Background(), + logger.NOP(), + serverAddr, + netip.Addr{}, + ) + + // Build a reply with an ident that was never registered + replyPacket := buildICMPv4EchoReply(t, + netip.MustParseAddr("1.1.1.1"), + serverAddr, + 42, 1, + ) + ok, err := rewriter.WriteBack(replyPacket) + require.NoError(t, err) + require.False(t, ok, "WriteBack should return false for unknown ident") +} + +func TestSourceRewriterSessionManagement(t *testing.T) { + t.Parallel() + + serverAddr := netip.MustParseAddr("192.168.10.254") + clientAddr := netip.MustParseAddr("10.0.0.2") + destAddr := netip.MustParseAddr("1.1.1.1") + + ctx := &mockDirectRouteContext{} + rewriter := ping.NewSourceRewriter( + context.Background(), + logger.NOP(), + serverAddr, + netip.Addr{}, + ) + + session := tun.DirectRouteSession{ + Source: clientAddr, + Destination: destAddr, + } + rewriter.CreateSession(session, ctx) + + // Register ident mapping + reqPacket := buildICMPv4EchoRequest(t, clientAddr, destAddr, 1111, 1) + rewriter.RewritePacket(reqPacket) + + // Delete session + rewriter.DeleteSession(session) + + // Now reply should fail (no session) + replyPacket := buildICMPv4EchoReply(t, destAddr, serverAddr, 1111, 1) + ok, err := rewriter.WriteBack(replyPacket) + require.NoError(t, err) + require.False(t, ok, "WriteBack should return false after session deletion") +} + +func TestSourceRewriterNonICMPPacket(t *testing.T) { + t.Parallel() + + serverAddr := netip.MustParseAddr("192.168.10.254") + + rewriter := ping.NewSourceRewriter( + context.Background(), + logger.NOP(), + serverAddr, + netip.Addr{}, + ) + + // Build a TCP packet (should be ignored by WriteBack) + totalLen := header.IPv4MinimumSize + 20 // minimal TCP header + packet := make([]byte, totalLen) + ipHdr := header.IPv4(packet) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + TTL: 64, + Protocol: 6, // TCP + SrcAddr: netip.MustParseAddr("1.1.1.1"), + DstAddr: serverAddr, + }) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + + ok, err := rewriter.WriteBack(packet) + require.NoError(t, err) + require.False(t, ok, "WriteBack should return false for non-ICMP packets") +} + +func TestSourceRewriterDstUnreachable(t *testing.T) { + t.Parallel() + + serverAddr := netip.MustParseAddr("192.168.10.254") + clientAddr := netip.MustParseAddr("10.0.0.2") + destAddr := netip.MustParseAddr("1.1.1.1") + routerAddr := netip.MustParseAddr("1.1.1.1") // destination itself sends port unreachable + + ctx := &mockDirectRouteContext{} + rewriter := ping.NewSourceRewriter( + context.Background(), + logger.NOP(), + serverAddr, + netip.Addr{}, + ) + + session := tun.DirectRouteSession{ + Source: clientAddr, + Destination: destAddr, + } + rewriter.CreateSession(session, ctx) + + reqPacket := buildICMPv4EchoRequest(t, clientAddr, destAddr, 7777, 1) + rewriter.RewritePacket(reqPacket) + + errorPacket := buildICMPv4ErrorWithInnerICMP(t, + routerAddr, serverAddr, + serverAddr, destAddr, + 7777, 1, + header.ICMPv4DstUnreachable, + ) + + ok, err := rewriter.WriteBack(errorPacket) + require.NoError(t, err) + require.True(t, ok) + + require.Len(t, ctx.packets, 1) + rewrittenIP := header.IPv4(ctx.packets[0]) + require.Equal(t, clientAddr, rewrittenIP.DestinationAddr()) +} + +// --- Helper functions for building test packets --- + +func buildICMPv4EchoRequest(t *testing.T, src, dst netip.Addr, ident, seq uint16) []byte { + t.Helper() + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + packet := make([]byte, totalLen) + + ipHdr := header.IPv4(packet) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + TTL: 64, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: src, + DstAddr: dst, + }) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + + icmpHdr := header.ICMPv4(packet[header.IPv4MinimumSize:]) + icmpHdr.SetType(header.ICMPv4Echo) + icmpHdr.SetIdent(ident) + icmpHdr.SetSequence(seq) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + + return packet +} + +func buildICMPv4EchoReply(t *testing.T, src, dst netip.Addr, ident, seq uint16) []byte { + t.Helper() + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + packet := make([]byte, totalLen) + + ipHdr := header.IPv4(packet) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + TTL: 64, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: src, + DstAddr: dst, + }) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + + icmpHdr := header.ICMPv4(packet[header.IPv4MinimumSize:]) + icmpHdr.SetType(header.ICMPv4EchoReply) + icmpHdr.SetIdent(ident) + icmpHdr.SetSequence(seq) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + + return packet +} + +// buildICMPv4ErrorWithInnerICMP builds an ICMP error (TimeExceeded/DstUnreachable) +// containing an inner IPv4+ICMP echo request. +func buildICMPv4ErrorWithInnerICMP( + t *testing.T, + outerSrc, outerDst netip.Addr, + innerSrc, innerDst netip.Addr, + innerIdent, innerSeq uint16, + icmpType header.ICMPv4Type, +) []byte { + t.Helper() + + innerIPLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + innerIPLen + + packet := make([]byte, totalLen) + + // Outer IPv4 + outerIP := header.IPv4(packet) + outerIP.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + TTL: 64, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: outerSrc, + DstAddr: outerDst, + }) + outerIP.SetChecksum(^outerIP.CalculateChecksum()) + + // ICMP error header + icmpOffset := header.IPv4MinimumSize + icmpHdr := header.ICMPv4(packet[icmpOffset:]) + icmpHdr.SetType(icmpType) + icmpHdr.SetCode(0) + + // Inner IPv4 + ICMP echo request + innerIPOffset := icmpOffset + header.ICMPv4MinimumSize + innerIP := header.IPv4(packet[innerIPOffset:]) + innerIP.Encode(&header.IPv4Fields{ + TotalLength: uint16(innerIPLen), + TTL: 1, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: innerSrc, + DstAddr: innerDst, + }) + innerIP.SetChecksum(^innerIP.CalculateChecksum()) + + innerICMPOffset := innerIPOffset + header.IPv4MinimumSize + innerICMP := header.ICMPv4(packet[innerICMPOffset:]) + innerICMP.SetType(header.ICMPv4Echo) + innerICMP.SetIdent(innerIdent) + innerICMP.SetSequence(innerSeq) + innerICMP.SetChecksum(header.ICMPv4Checksum(innerICMP, 0)) + + // Calculate outer ICMP checksum + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + + return packet +} + +// buildICMPv4ErrorWithInnerICMPInvertedIdent builds an ICMP error with an inner +// ICMP echo whose ident is inverted (as seen on the wire for privileged raw sockets). +func buildICMPv4ErrorWithInnerICMPInvertedIdent( + t *testing.T, + outerSrc, outerDst netip.Addr, + innerSrc, innerDst netip.Addr, + originalIdent, seq uint16, + icmpType header.ICMPv4Type, +) []byte { + t.Helper() + packet := buildICMPv4ErrorWithInnerICMP(t, + outerSrc, outerDst, + innerSrc, innerDst, + ^originalIdent, seq, // wire-level inverted ident + icmpType, + ) + return packet +} + +// --- Destination ICMP error rewrite tests (destination.go loopReadErrors logic) --- + +func TestDestinationICMPv4ErrorRewriteWithIdentInversion(t *testing.T) { + t.Parallel() + + // Simulate the exact scenario from the bug: + // Client sends ping with ident=1234. Server raw socket inverts to ^1234 on wire. + // Router sends ICMP TTL Exceeded. Inner ICMP has wire-level ^1234. + // loopReadErrors should match via ^innerIdent == ^(^1234) == 1234. + // Then rewrite: outer dst → client, inner src → client, inner ident → un-inverted. + + clientAddr := netip.MustParseAddr("10.0.0.2") + serverAddr := netip.MustParseAddr("192.168.10.254") + destAddr := netip.MustParseAddr("1.1.1.1") + routerAddr := netip.MustParseAddr("192.168.1.1") + var originalIdent uint16 = 1234 + + // Build error with inverted ident (as it appears on the wire) + packet := buildICMPv4ErrorWithInnerICMPInvertedIdent(t, + routerAddr, serverAddr, + serverAddr, destAddr, + originalIdent, 1, + header.ICMPv4TimeExceeded, + ) + + // Apply the same rewrite logic as destination.go loopReadErrors + outerIP := header.IPv4(packet) + icmpHdr := header.ICMPv4(outerIP.Payload()) + innerIP := header.IPv4(outerIP.Payload()[header.ICMPv4MinimumSize:]) + innerICMP := header.ICMPv4(innerIP.Payload()) + + // Verify the inner ident is inverted on wire + require.Equal(t, ^originalIdent, innerICMP.Ident(), "inner ident should be inverted on wire") + + // The matching logic: ^innerIdent should equal original ident + matchIdent := ^innerICMP.Ident() + require.Equal(t, originalIdent, matchIdent, "inverted ident should match original") + + // Apply rewrite + outerIP.SetDestinationAddr(clientAddr) + innerIP.SetSourceAddr(clientAddr) + innerICMP.SetIdent(^innerICMP.Ident()) // un-invert + innerIP.SetChecksum(^innerIP.CalculateChecksum()) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + outerIP.SetChecksum(^outerIP.CalculateChecksum()) + + // Verify final state + require.Equal(t, clientAddr, outerIP.DestinationAddr()) + require.Equal(t, routerAddr, outerIP.SourceAddr()) + require.Equal(t, clientAddr, innerIP.SourceAddr()) + require.Equal(t, destAddr, innerIP.DestinationAddr()) + require.Equal(t, originalIdent, innerICMP.Ident(), + "inner ident should be restored to original value") +} + +func TestDestinationICMPv4ErrorDstUnreachableRewrite(t *testing.T) { + t.Parallel() + + clientAddr := netip.MustParseAddr("10.0.0.2") + serverAddr := netip.MustParseAddr("192.168.10.254") + destAddr := netip.MustParseAddr("1.1.1.1") + var originalIdent uint16 = 4321 + + // Build DstUnreachable (e.g., port unreachable from 1.1.1.1) + packet := buildICMPv4ErrorWithInnerICMPInvertedIdent(t, + destAddr, serverAddr, // destination itself sends the error + serverAddr, destAddr, + originalIdent, 5, + header.ICMPv4DstUnreachable, + ) + + outerIP := header.IPv4(packet) + icmpHdr := header.ICMPv4(outerIP.Payload()) + innerIP := header.IPv4(outerIP.Payload()[header.ICMPv4MinimumSize:]) + innerICMP := header.ICMPv4(innerIP.Payload()) + + require.Equal(t, header.ICMPv4DstUnreachable, icmpHdr.Type()) + + // Apply rewrite + outerIP.SetDestinationAddr(clientAddr) + innerIP.SetSourceAddr(clientAddr) + innerICMP.SetIdent(^innerICMP.Ident()) + innerIP.SetChecksum(^innerIP.CalculateChecksum()) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + outerIP.SetChecksum(^outerIP.CalculateChecksum()) + + require.Equal(t, clientAddr, outerIP.DestinationAddr()) + require.Equal(t, clientAddr, innerIP.SourceAddr()) + require.Equal(t, originalIdent, innerICMP.Ident()) +} + +// --- ErrorListener tests --- + +func TestErrorListenerCreation(t *testing.T) { + t.Parallel() + // This test requires root on Linux/Darwin + if !canCreateRawSocket(t) { + t.SkipNow() + } + + // The ConnectDestination function internally creates an ErrorListener. + // Verify the destination can be created and closed cleanly. + dest, err := ping.ConnectDestination( + context.Background(), + logger.NOP(), + nil, + netip.MustParseAddr("127.0.0.1"), + nil, + 30*1e9, // 30s + ) + require.NoError(t, err) + require.False(t, dest.IsClosed()) + err = dest.Close() + require.NoError(t, err) + require.True(t, dest.IsClosed()) +} + +// --- IPv4 checksum validation after rewrite --- + +func TestICMPv4ErrorChecksumAfterRewrite(t *testing.T) { + t.Parallel() + + // Verify that checksums are valid after a complete rewrite cycle + packet := buildICMPv4ErrorWithInnerICMP(t, + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("192.168.10.254"), + netip.MustParseAddr("192.168.10.254"), + netip.MustParseAddr("1.1.1.1"), + 1234, 1, + header.ICMPv4TimeExceeded, + ) + + clientAddr := netip.MustParseAddr("10.0.0.2") + + outerIP := header.IPv4(packet) + icmpHdr := header.ICMPv4(outerIP.Payload()) + innerIP := header.IPv4(outerIP.Payload()[header.ICMPv4MinimumSize:]) + + // Apply rewrite + outerIP.SetDestinationAddr(clientAddr) + innerIP.SetSourceAddr(clientAddr) + innerIP.SetChecksum(0) + innerIP.SetChecksum(^innerIP.CalculateChecksum()) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + outerIP.SetChecksum(0) + outerIP.SetChecksum(^outerIP.CalculateChecksum()) + + // Validate outer IP checksum: calculate and verify it's consistent + savedChecksum := outerIP.Checksum() + outerIP.SetChecksum(0) + calculatedChecksum := ^outerIP.CalculateChecksum() + require.Equal(t, savedChecksum, calculatedChecksum, "outer IP checksum should be valid") + + // Validate inner IP checksum + savedInnerChecksum := innerIP.Checksum() + innerIP.SetChecksum(0) + calculatedInnerChecksum := ^innerIP.CalculateChecksum() + require.Equal(t, savedInnerChecksum, calculatedInnerChecksum, "inner IP checksum should be valid") +} + +// --- Helper to detect raw socket capability --- + +func canCreateRawSocket(t *testing.T) bool { + t.Helper() + // Try to create a destination - it will fail if no raw socket permission + dest, err := ping.ConnectDestination( + context.Background(), + logger.NOP(), + nil, + netip.MustParseAddr("127.0.0.1"), + nil, + 1e9, + ) + if err != nil { + return false + } + dest.Close() + return true +} + +// --- ICMP error with truncated inner packet (edge case) --- + +func TestSourceRewriterTruncatedInnerPacket(t *testing.T) { + t.Parallel() + + serverAddr := netip.MustParseAddr("192.168.10.254") + + rewriter := ping.NewSourceRewriter( + context.Background(), + logger.NOP(), + serverAddr, + netip.Addr{}, + ) + + // Build an ICMP error where the inner payload is too short to contain + // a valid inner IPv4 + ICMP header + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + 4 // only 4 bytes of inner data (too short) + packet := make([]byte, totalLen) + + ipHdr := header.IPv4(packet) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + TTL: 64, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: netip.MustParseAddr("192.168.1.1"), + DstAddr: serverAddr, + }) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + + icmpHdr := header.ICMPv4(packet[header.IPv4MinimumSize:]) + icmpHdr.SetType(header.ICMPv4TimeExceeded) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + + ok, err := rewriter.WriteBack(packet) + require.NoError(t, err) + require.False(t, ok, "should return false for truncated inner packet") +} + +// --- Unused helper for potential future IPv6 tests --- +// (IPv6 error rewrite tests could be added here when IPv6 environment is available) + +func buildICMPv4ErrorWithInnerUDP( + t *testing.T, + outerSrc, outerDst netip.Addr, + innerSrc, innerDst netip.Addr, + innerSrcPort, innerDstPort uint16, + icmpType header.ICMPv4Type, +) []byte { + t.Helper() + + innerIPLen := header.IPv4MinimumSize + header.UDPMinimumSize + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + innerIPLen + + packet := make([]byte, totalLen) + + outerIP := header.IPv4(packet) + outerIP.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + TTL: 64, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: outerSrc, + DstAddr: outerDst, + }) + outerIP.SetChecksum(^outerIP.CalculateChecksum()) + + icmpOffset := header.IPv4MinimumSize + icmpHdr := header.ICMPv4(packet[icmpOffset:]) + icmpHdr.SetType(icmpType) + icmpHdr.SetCode(0) + + innerIPOffset := icmpOffset + header.ICMPv4MinimumSize + innerIP := header.IPv4(packet[innerIPOffset:]) + innerIP.Encode(&header.IPv4Fields{ + TotalLength: uint16(innerIPLen), + TTL: 1, + Protocol: uint8(header.UDPProtocolNumber), + SrcAddr: innerSrc, + DstAddr: innerDst, + }) + innerIP.SetChecksum(^innerIP.CalculateChecksum()) + + innerUDPOffset := innerIPOffset + header.IPv4MinimumSize + binary.BigEndian.PutUint16(packet[innerUDPOffset:], innerSrcPort) + binary.BigEndian.PutUint16(packet[innerUDPOffset+2:], innerDstPort) + binary.BigEndian.PutUint16(packet[innerUDPOffset+4:], header.UDPMinimumSize) + + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + + return packet +} diff --git a/ping/tcp_destination.go b/ping/tcp_destination.go new file mode 100644 index 00000000..f673bac7 --- /dev/null +++ b/ping/tcp_destination.go @@ -0,0 +1,167 @@ +package ping + +import ( + "context" + "net" + "net/netip" + "sync/atomic" + "time" + + tun "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/gtcpip/header" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/contrab/freelru" + "github.com/sagernet/sing/contrab/maphash" +) + +var _ tun.DirectRouteDestination = (*TCPDestination)(nil) + +const tcpRequestsCapacity = 4096 + +// TCPDestination sends raw TCP SYN packets (preserving TTL) and +// receives ICMP Time Exceeded / Destination Unreachable errors, +// enabling mtr --tcp / traceroute -T through a DirectRoute outbound. +type TCPDestination struct { + conn *tcpRawConn + errorListener *ErrorListener + ctx context.Context + logger logger.ContextLogger + destination netip.Addr + routeContext tun.DirectRouteContext + timeout time.Duration + closed atomic.Bool + + // Track active TCP source ports for ICMP error matching + requests freelru.Cache[uint16, struct{}] + originalSource common.TypedValue[netip.Addr] +} + +func ConnectTCPDestination( + ctx context.Context, + logger logger.ContextLogger, + controlFunc control.Func, + destination netip.Addr, + routeContext tun.DirectRouteContext, + timeout time.Duration, +) (tun.DirectRouteDestination, error) { + rawConn, err := connectTCPRaw(controlFunc, destination) + if err != nil { + return nil, err + } + d := &TCPDestination{ + conn: rawConn, + ctx: ctx, + logger: logger, + destination: destination, + routeContext: routeContext, + timeout: timeout, + requests: common.Must1(freelru.NewSynced[uint16, struct{}](tcpRequestsCapacity, maphash.NewHasher[uint16]().Hash32)), + } + d.requests.SetLifetime(timeout) + + if errorListener := tryListenErrors(ctx, logger, controlFunc, destination); errorListener != nil { + d.errorListener = errorListener + go d.loopReadErrors() + logger.DebugContext(ctx, "TCP ICMP error listener started") + } else { + logger.WarnContext(ctx, "TCP ICMP error listener not available") + } + + return d, nil +} + +func (d *TCPDestination) WritePacket(packet *buf.Buffer) error { + if !d.destination.Is6() { + ipHdr := header.IPv4(packet.Bytes()) + if !ipHdr.IsValid(packet.Len()) { + return E.New("invalid IPv4 header") + } + if ipHdr.TransportProtocol() != header.TCPProtocolNumber { + return E.New("not a TCP packet") + } + if ipHdr.PayloadLength() < header.TCPMinimumSize { + return E.New("invalid TCP header") + } + tcpHdr := header.TCP(ipHdr.Payload()) + srcPort := tcpHdr.SourcePort() + + d.originalSource.Store(ipHdr.SourceAddr()) + d.requests.Add(srcPort, struct{}{}) + + d.logger.TraceContext(d.ctx, "write TCP SYN from ", ipHdr.SourceAddr(), ":", srcPort, + " to ", ipHdr.DestinationAddr(), ":", tcpHdr.DestinationPort(), " ttl ", ipHdr.TTL()) + + // For IPv4 IPPROTO_RAW (IP_HDRINCL), set source to 0.0.0.0 + // so the kernel fills in the correct outgoing IP address. + // Without this, the source would be a private WG tunnel address + // that is not routable on the internet, so ICMP errors from + // intermediate routers would never reach us. + ipHdr.SetSourceAddr(netip.IPv4Unspecified()) + ipHdr.SetChecksum(0) // kernel recomputes when source is 0 + destAddr := &net.IPAddr{ + IP: d.destination.AsSlice(), + } + _, err := d.conn.conn.WriteTo(packet.Bytes(), destAddr) + packet.Release() + return err + } else { + ipHdr := header.IPv6(packet.Bytes()) + if !ipHdr.IsValid(packet.Len()) { + return E.New("invalid IPv6 header") + } + if ipHdr.TransportProtocol() != header.TCPProtocolNumber { + return E.New("not a TCP packet") + } + if ipHdr.PayloadLength() < header.TCPMinimumSize { + return E.New("invalid TCP header") + } + tcpHdr := header.TCP(ipHdr.Payload()) + srcPort := tcpHdr.SourcePort() + + d.originalSource.Store(ipHdr.SourceAddr()) + d.requests.Add(srcPort, struct{}{}) + + hopLimit := ipHdr.HopLimit() + if hopLimit > 0 { + _ = d.conn.SetHopLimit(hopLimit) + } + + d.logger.TraceContext(d.ctx, "write TCP SYN from ", ipHdr.SourceAddr(), ":", srcPort, + " to ", ipHdr.DestinationAddr(), ":", tcpHdr.DestinationPort(), " hoplimit ", hopLimit) + + // For IPv6, send only the TCP segment (kernel adds IPv6 header). + destAddr := &net.IPAddr{ + IP: d.destination.AsSlice(), + } + _, err := d.conn.conn.WriteTo(ipHdr.Payload(), destAddr) + packet.Release() + return err + } +} + +func (d *TCPDestination) loopReadErrors() { + transportErrorLoop( + d.errorListener, d.destination.Is6(), + d.requests, &d.originalSource, d.routeContext, + d.ctx, d.logger, + header.TCPProtocolNumber, header.TCPMinimumSize, "TCP", + func(srcPort, _ uint16) uint16 { return srcPort }, + nil, + ) +} + +func (d *TCPDestination) Close() error { + d.closed.Store(true) + if d.errorListener != nil { + _ = d.errorListener.Close() + } + return d.conn.Close() +} + +func (d *TCPDestination) IsClosed() bool { + return d.closed.Load() +} diff --git a/ping/tcp_socket_other.go b/ping/tcp_socket_other.go new file mode 100644 index 00000000..77d19962 --- /dev/null +++ b/ping/tcp_socket_other.go @@ -0,0 +1,24 @@ +//go:build !unix + +package ping + +import ( + "net/netip" + + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" +) + +type tcpRawConn struct{} + +func connectTCPRaw(controlFunc control.Func, destination netip.Addr) (*tcpRawConn, error) { + return nil, E.New("TCP traceroute not supported on this platform") +} + +func (c *tcpRawConn) SetHopLimit(hopLimit uint8) error { + return nil +} + +func (c *tcpRawConn) Close() error { + return nil +} diff --git a/ping/tcp_socket_unix.go b/ping/tcp_socket_unix.go new file mode 100644 index 00000000..ef934515 --- /dev/null +++ b/ping/tcp_socket_unix.go @@ -0,0 +1,86 @@ +//go:build unix + +package ping + +import ( + "net" + "net/netip" + "os" + "syscall" + + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + + "golang.org/x/sys/unix" +) + +// tcpRawConn wraps a raw socket for sending TCP SYN packets. +// IPv4: IPPROTO_RAW (IP_HDRINCL automatic, sends full IP+TCP packet). +// IPv6: IPPROTO_TCP raw socket (kernel adds IPv6 header, sends TCP segment only). +type tcpRawConn struct { + fd int + file *os.File + conn net.PacketConn + isIPv6 bool +} + +func connectTCPRaw(controlFunc control.Func, destination netip.Addr) (*tcpRawConn, error) { + var ( + network string + fd int + err error + ) + isIPv6 := destination.Is6() + if !isIPv6 { + network = "ip4" + fd, err = unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_RAW) + } else { + network = "ip6" + fd, err = unix.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_TCP) + } + if err != nil { + return nil, E.Cause(err, "socket()") + } + + file := os.NewFile(uintptr(fd), "tcp-raw") + + if controlFunc != nil { + var syscallConn syscall.RawConn + syscallConn, err = file.SyscallConn() + if err != nil { + file.Close() + return nil, err + } + err = controlFunc(network, destination.String(), syscallConn) + if err != nil { + file.Close() + return nil, err + } + } + + packetConn, err := net.FilePacketConn(file) + if err != nil { + file.Close() + return nil, err + } + + return &tcpRawConn{ + fd: fd, + file: file, + conn: packetConn, + isIPv6: isIPv6, + }, nil +} + +func (c *tcpRawConn) SetHopLimit(hopLimit uint8) error { + rawConn, err := c.conn.(*net.IPConn).SyscallConn() + if err != nil { + return err + } + return setsockoptTTL(rawConn, c.isIPv6, hopLimit) +} + +func (c *tcpRawConn) Close() error { + c.file.Close() + return c.conn.Close() +} diff --git a/ping/transport_error_loop.go b/ping/transport_error_loop.go new file mode 100644 index 00000000..695ba311 --- /dev/null +++ b/ping/transport_error_loop.go @@ -0,0 +1,209 @@ +package ping + +import ( + "context" + "net/netip" + + tun "github.com/sagernet/sing-tun" + tcpip "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/header" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/contrab/freelru" + + "golang.org/x/net/ipv4" +) + +// transportErrorLoop reads ICMP Time Exceeded and Destination Unreachable +// errors from an ErrorListener and delivers matched ones back through +// routeContext. It handles both IPv4 and IPv6, including address rewriting +// and checksum recalculation. +// +// This is shared by UDPDestination and TCPDestination. The caller provides: +// - protocol/minSize: the expected inner transport protocol and minimum header size +// - name: protocol name for logging ("UDP" or "TCP") +// - matchPort: given (srcPort, dstPort), returns the port to check against requests +// - rewritePort: optional func to rewrite ports before checksum recalculation; +// receives the inner transport header as a byte slice +func transportErrorLoop( + el *ErrorListener, + isIPv6 bool, + requests freelru.Cache[uint16, struct{}], + originalSource *common.TypedValue[netip.Addr], + routeCtx tun.DirectRouteContext, + ctx context.Context, + log logger.ContextLogger, + protocol tcpip.TransportProtocolNumber, + minSize int, + name string, + matchPort func(srcPort, dstPort uint16) uint16, + rewritePort func(innerPayload []byte), +) { + defer el.Close() + for { + buffer := buf.NewSize(1500) + if isIPv6 { + buffer.Advance(header.IPv6MinimumSize) + } + oob := make([]byte, 128) + n, oobn, addr, err := el.ReadMsg(buffer.FreeBytes(), oob) + if err != nil { + buffer.Release() + if !E.IsClosed(err) { + log.ErrorContext(ctx, E.Cause(err, "receive ", name, " ICMP error")) + } + return + } + buffer.Truncate(n) + + forward := func() bool { + if !isIPv6 { + var ttl int + if oobn > 0 { + var cm ipv4.ControlMessage + if cm.Parse(oob[:oobn]) == nil { + ttl = cm.TTL + } + } + ipHdr := header.IPv4(buffer.Bytes()) + if !ipHdr.IsValid(n) { + return false + } + if ipHdr.PayloadLength() < header.ICMPv4MinimumSize { + return false + } + icmpHdr := header.ICMPv4(ipHdr.Payload()) + switch icmpHdr.Type() { + case header.ICMPv4TimeExceeded, header.ICMPv4DstUnreachable: + default: + return false + } + if len(ipHdr.Payload()) < header.ICMPv4MinimumSize+header.IPv4MinimumSize+minSize { + return false + } + innerIPHdr := header.IPv4(ipHdr.Payload()[header.ICMPv4MinimumSize:]) + if !innerIPHdr.IsValid(len(ipHdr.Payload()) - header.ICMPv4MinimumSize) { + return false + } + if innerIPHdr.TransportProtocol() != protocol { + return false + } + if innerIPHdr.PayloadLength() < uint16(minSize) { + return false + } + innerPayload := innerIPHdr.Payload() + srcPort := header.UDP(innerPayload).SourcePort() + dstPort := header.UDP(innerPayload).DestinationPort() + + log.DebugContext(ctx, name, " ICMPv4 error type ", uint8(icmpHdr.Type()), + " from ", addr, " inner: ", innerIPHdr.SourceAddr(), ":", srcPort, + " -> ", innerIPHdr.DestinationAddr(), ":", dstPort) + + if !requests.Contains(matchPort(srcPort, dstPort)) { + return false + } + + originalSrc := originalSource.Load() + if originalSrc.IsValid() { + ipHdr.SetDestinationAddr(originalSrc) + innerIPHdr.SetSourceAddr(originalSrc) + if rewritePort != nil { + rewritePort(innerPayload) + } + innerIPHdr.SetChecksum(0) + innerIPHdr.SetChecksum(^innerIPHdr.CalculateChecksum()) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + } else { + ipHdr.SetDestinationAddr(innerIPHdr.SourceAddr()) + } + ipHdr.SetChecksum(0) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + + log.TraceContext(ctx, "read ", name, " ICMPv4 error type ", uint8(icmpHdr.Type()), + " from ", addr, " ttl ", ttl, " -> ", ipHdr.DestinationAddr()) + return true + } + + // IPv6 + var hopLimit int + if oobn > 0 { + cm, cmErr := parseIPv6ControlMessage(oob[:oobn]) + if cmErr == nil && cm != nil { + hopLimit = cm.HopLimit + } + } + if n < header.ICMPv6MinimumSize { + return false + } + icmpHdr := header.ICMPv6(buffer.Bytes()) + switch icmpHdr.Type() { + case header.ICMPv6TimeExceeded, header.ICMPv6DstUnreachable: + default: + return false + } + if n < header.ICMPv6MinimumSize+header.IPv6MinimumSize+minSize { + return false + } + innerIPHdr := header.IPv6(buffer.Bytes()[header.ICMPv6MinimumSize:]) + if !innerIPHdr.IsValid(n - header.ICMPv6MinimumSize) { + return false + } + if innerIPHdr.TransportProtocol() != protocol { + return false + } + if innerIPHdr.PayloadLength() < uint16(minSize) { + return false + } + innerPayload := innerIPHdr.Payload() + srcPort := header.UDP(innerPayload).SourcePort() + dstPort := header.UDP(innerPayload).DestinationPort() + + log.DebugContext(ctx, name, " ICMPv6 error type ", uint8(icmpHdr.Type()), + " from ", addr, " inner: ", innerIPHdr.SourceAddr(), ":", srcPort, + " -> ", innerIPHdr.DestinationAddr(), ":", dstPort) + + if !requests.Contains(matchPort(srcPort, dstPort)) { + return false + } + + dstAddr := addr + originalSrc := originalSource.Load() + if originalSrc.IsValid() { + dstAddr = originalSrc + innerIPHdr.SetSourceAddr(originalSrc) + if rewritePort != nil { + rewritePort(innerPayload) + } + } else { + dstAddr = innerIPHdr.SourceAddr() + } + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: addr.AsSlice(), + Dst: dstAddr.AsSlice(), + })) + ipHdr := header.IPv6(buffer.ExtendHeader(header.IPv6MinimumSize)) + ipHdr.Encode(&header.IPv6Fields{ + PayloadLength: uint16(n), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: uint8(hopLimit), + SrcAddr: addr, + DstAddr: dstAddr, + }) + + log.TraceContext(ctx, "read ", name, " ICMPv6 error type ", uint8(icmpHdr.Type()), + " from ", addr, " hoplimit ", hopLimit, " -> ", dstAddr) + return true + }() + if forward { + if err = routeCtx.WritePacket(buffer.Bytes()); err != nil { + log.ErrorContext(ctx, E.Cause(err, "write ", name, " ICMP error")) + } + } + buffer.Release() + } +} diff --git a/ping/ttl_unix.go b/ping/ttl_unix.go new file mode 100644 index 00000000..f33d1524 --- /dev/null +++ b/ping/ttl_unix.go @@ -0,0 +1,59 @@ +//go:build unix + +package ping + +import ( + "net" + "syscall" +) + +// setsockoptTTL sets IP_TTL or IPV6_UNICAST_HOPS on a raw connection. +func setsockoptTTL(rawConn syscall.RawConn, isIPv6 bool, ttl uint8) error { + var sockErr error + var err error + if !isIPv6 { + err = rawConn.Control(func(fd uintptr) { + sockErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TTL, int(ttl)) + }) + } else { + err = rawConn.Control(func(fd uintptr) { + sockErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, int(ttl)) + }) + } + if err != nil { + return err + } + return sockErr +} + +func (c *Conn) SetTTL(ttl uint8) error { + if setter, ok := c.conn.(ttlSetter); ok { + setter.SetTTL(ttl) + return nil + } + syscallConn, ok := c.conn.(interface { + SyscallConn() (syscall.RawConn, error) + }) + if !ok { + return nil + } + rawConn, err := syscallConn.SyscallConn() + if err != nil { + return err + } + return setsockoptTTL(rawConn, c.destination.Is6(), ttl) +} + +func (c *UnprivilegedConn) setTTL(conn net.Conn, ttl uint8) error { + syscallConn, ok := conn.(interface { + SyscallConn() (syscall.RawConn, error) + }) + if !ok { + return nil + } + rawConn, err := syscallConn.SyscallConn() + if err != nil { + return err + } + return setsockoptTTL(rawConn, c.destination.Is6(), ttl) +} diff --git a/ping/ttl_windows.go b/ping/ttl_windows.go new file mode 100644 index 00000000..c6c60743 --- /dev/null +++ b/ping/ttl_windows.go @@ -0,0 +1,49 @@ +package ping + +import ( + "net" + "syscall" + + "golang.org/x/sys/windows" +) + +// setsockoptTTL sets IP_TTL or IPV6_UNICAST_HOPS on a raw connection. +func setsockoptTTL(rawConn syscall.RawConn, isIPv6 bool, ttl uint8) error { + var sockErr error + var err error + if !isIPv6 { + err = rawConn.Control(func(fd uintptr) { + sockErr = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, windows.IP_TTL, int(ttl)) + }) + } else { + err = rawConn.Control(func(fd uintptr) { + sockErr = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, windows.IPV6_UNICAST_HOPS, int(ttl)) + }) + } + if err != nil { + return err + } + return sockErr +} + +func (c *Conn) SetTTL(ttl uint8) error { + if setter, ok := c.conn.(ttlSetter); ok { + setter.SetTTL(ttl) + return nil + } + syscallConn, ok := c.conn.(interface { + SyscallConn() (syscall.RawConn, error) + }) + if !ok { + return nil + } + rawConn, err := syscallConn.SyscallConn() + if err != nil { + return err + } + return setsockoptTTL(rawConn, c.destination.Is6(), ttl) +} + +func (c *UnprivilegedConn) setTTL(conn net.Conn, ttl uint8) error { + return nil +} diff --git a/ping/udp_destination.go b/ping/udp_destination.go new file mode 100644 index 00000000..18663d7e --- /dev/null +++ b/ping/udp_destination.go @@ -0,0 +1,184 @@ +package ping + +import ( + "context" + "net" + "net/netip" + "sync/atomic" + "time" + + tun "github.com/sagernet/sing-tun" + "github.com/sagernet/sing-tun/gtcpip/header" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/contrab/freelru" + "github.com/sagernet/sing/contrab/maphash" +) + +var _ tun.DirectRouteDestination = (*UDPDestination)(nil) + +const udpRequestsCapacity = 4096 + +type UDPDestination struct { + conn *net.UDPConn + errorListener *ErrorListener + ctx context.Context + logger logger.ContextLogger + destination netip.Addr + routeContext tun.DirectRouteContext + timeout time.Duration + closed atomic.Bool + + // Track active UDP destination ports for ICMP error matching + requests freelru.Cache[uint16, struct{}] + originalSource common.TypedValue[netip.Addr] + originalSourcePort uint16 // client's original UDP source port + localPort uint16 // kernel-assigned local UDP source port +} + +func ConnectUDPDestination( + ctx context.Context, + logger logger.ContextLogger, + controlFunc control.Func, + destination netip.Addr, + routeContext tun.DirectRouteContext, + timeout time.Duration, +) (tun.DirectRouteDestination, error) { + udpConn, err := connectUDP(controlFunc, destination) + if err != nil { + return nil, err + } + d := &UDPDestination{ + conn: udpConn, + ctx: ctx, + logger: logger, + destination: destination, + routeContext: routeContext, + timeout: timeout, + requests: common.Must1(freelru.NewSynced[uint16, struct{}](udpRequestsCapacity, maphash.NewHasher[uint16]().Hash32)), + } + d.requests.SetLifetime(timeout) + if localAddr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok { + d.localPort = uint16(localAddr.Port) + } + + if errorListener := tryListenErrors(ctx, logger, controlFunc, destination); errorListener != nil { + d.errorListener = errorListener + go d.loopReadErrors() + logger.DebugContext(ctx, "UDP ICMP error listener started") + } else { + logger.WarnContext(ctx, "UDP ICMP error listener not available") + } + + return d, nil +} + +func (d *UDPDestination) WritePacket(packet *buf.Buffer) error { + if !d.destination.Is6() { + ipHdr := header.IPv4(packet.Bytes()) + if !ipHdr.IsValid(packet.Len()) { + return E.New("invalid IPv4 header") + } + if ipHdr.TransportProtocol() != header.UDPProtocolNumber { + return E.New("not a UDP packet") + } + if ipHdr.PayloadLength() < header.UDPMinimumSize { + return E.New("invalid UDP header") + } + udpHdr := header.UDP(ipHdr.Payload()) + srcPort := udpHdr.SourcePort() + dstPort := udpHdr.DestinationPort() + + d.originalSource.Store(ipHdr.SourceAddr()) + d.originalSourcePort = srcPort + + // Register destination port for ICMP error matching + d.requests.Add(dstPort, struct{}{}) + + ttl := ipHdr.TTL() + if ttl > 0 { + _ = setUDPTTL(d.conn, false, ttl) + } + + d.logger.TraceContext(d.ctx, "write UDP from ", ipHdr.SourceAddr(), ":", srcPort, " to ", ipHdr.DestinationAddr(), ":", dstPort, " ttl ", ttl) + + // Send UDP payload to the correct destination port + destAddr := &net.UDPAddr{ + IP: d.destination.AsSlice(), + Port: int(dstPort), + } + udpPayload := udpHdr.Payload() + _, err := d.conn.WriteTo(udpPayload, destAddr) + packet.Release() + return err + } else { + ipHdr := header.IPv6(packet.Bytes()) + if !ipHdr.IsValid(packet.Len()) { + return E.New("invalid IPv6 header") + } + if ipHdr.TransportProtocol() != header.UDPProtocolNumber { + return E.New("not a UDP packet") + } + if ipHdr.PayloadLength() < header.UDPMinimumSize { + return E.New("invalid UDP header") + } + udpHdr := header.UDP(ipHdr.Payload()) + srcPort := udpHdr.SourcePort() + dstPort := udpHdr.DestinationPort() + + d.originalSource.Store(ipHdr.SourceAddr()) + d.originalSourcePort = srcPort + + // Register destination port for ICMP error matching + d.requests.Add(dstPort, struct{}{}) + + hopLimit := ipHdr.HopLimit() + if hopLimit > 0 { + _ = setUDPTTL(d.conn, true, hopLimit) + } + + d.logger.TraceContext(d.ctx, "write UDP from ", ipHdr.SourceAddr(), ":", srcPort, " to ", ipHdr.DestinationAddr(), ":", dstPort, " hoplimit ", hopLimit) + + destAddr := &net.UDPAddr{ + IP: d.destination.AsSlice(), + Port: int(dstPort), + } + udpPayload := udpHdr.Payload() + _, err := d.conn.WriteTo(udpPayload, destAddr) + packet.Release() + return err + } +} + +func (d *UDPDestination) loopReadErrors() { + transportErrorLoop( + d.errorListener, d.destination.Is6(), + d.requests, &d.originalSource, d.routeContext, + d.ctx, d.logger, + header.UDPProtocolNumber, header.UDPMinimumSize, "UDP", + func(_, dstPort uint16) uint16 { return dstPort }, + func(innerPayload []byte) { + if d.originalSourcePort != 0 && d.localPort != 0 { + innerUDP := header.UDP(innerPayload) + if innerUDP.SourcePort() == d.localPort { + innerUDP.SetSourcePort(d.originalSourcePort) + } + } + }, + ) +} + +func (d *UDPDestination) Close() error { + d.closed.Store(true) + if d.errorListener != nil { + _ = d.errorListener.Close() + } + return d.conn.Close() +} + +func (d *UDPDestination) IsClosed() bool { + return d.closed.Load() +} diff --git a/ping/udp_destination_internal_test.go b/ping/udp_destination_internal_test.go new file mode 100644 index 00000000..dcc7424b --- /dev/null +++ b/ping/udp_destination_internal_test.go @@ -0,0 +1,85 @@ +package ping + +import ( + "net/netip" + "testing" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/contrab/freelru" + "github.com/sagernet/sing/contrab/maphash" + + "github.com/stretchr/testify/require" +) + +func TestRegisterRequest(t *testing.T) { + t.Parallel() + requests := common.Must1(freelru.NewSynced[uint16, struct{}](udpRequestsCapacity, maphash.NewHasher[uint16]().Hash32)) + requests.SetLifetime(5 * time.Second) + d := &UDPDestination{ + timeout: 5 * time.Second, + requests: requests, + } + + d.requests.Add(33434, struct{}{}) + d.requests.Add(33435, struct{}{}) + d.requests.Add(33436, struct{}{}) + + require.Equal(t, 3, d.requests.Len()) + require.True(t, d.requests.Contains(33434)) + require.True(t, d.requests.Contains(33435)) + require.True(t, d.requests.Contains(33436)) +} + +func TestRegisterRequestExpiry(t *testing.T) { + t.Parallel() + requests := common.Must1(freelru.NewSynced[uint16, struct{}](udpRequestsCapacity, maphash.NewHasher[uint16]().Hash32)) + requests.SetLifetime(50 * time.Millisecond) + d := &UDPDestination{ + timeout: 50 * time.Millisecond, + requests: requests, + } + + d.requests.Add(33434, struct{}{}) + time.Sleep(100 * time.Millisecond) + + require.False(t, d.requests.Contains(33434), "expired request should not be found") + + d.requests.Add(33435, struct{}{}) + require.True(t, d.requests.Contains(33435), "new request should exist") +} + +func TestRegisterRequestLimit(t *testing.T) { + t.Parallel() + requests := common.Must1(freelru.NewSynced[uint16, struct{}](udpRequestsCapacity, maphash.NewHasher[uint16]().Hash32)) + requests.SetLifetime(1 * time.Hour) + d := &UDPDestination{ + timeout: 1 * time.Hour, + requests: requests, + } + + // Fill beyond capacity — LRU eviction keeps size bounded + for i := uint16(0); i < udpRequestsCapacity+100; i++ { + d.requests.Add(i, struct{}{}) + } + + require.LessOrEqual(t, d.requests.Len(), udpRequestsCapacity, + "request count should not exceed capacity") +} + +func TestOriginalSourcePort(t *testing.T) { + t.Parallel() + requests := common.Must1(freelru.NewSynced[uint16, struct{}](udpRequestsCapacity, maphash.NewHasher[uint16]().Hash32)) + d := &UDPDestination{ + timeout: 5 * time.Second, + requests: requests, + localPort: 23674, + } + + d.originalSource.Store(netip.MustParseAddr("10.0.0.2")) + d.originalSourcePort = 60183 + + require.Equal(t, uint16(23674), d.localPort) + require.Equal(t, uint16(60183), d.originalSourcePort) + require.Equal(t, netip.MustParseAddr("10.0.0.2"), d.originalSource.Load()) +} diff --git a/ping/udp_destination_test.go b/ping/udp_destination_test.go new file mode 100644 index 00000000..3b1c9ea8 --- /dev/null +++ b/ping/udp_destination_test.go @@ -0,0 +1,449 @@ +package ping_test + +import ( + "context" + "encoding/binary" + "net" + "net/netip" + "os" + "runtime" + "testing" + "time" + + "github.com/sagernet/sing-tun/gtcpip/header" + "github.com/sagernet/sing-tun/ping" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/logger" + + "github.com/stretchr/testify/require" +) + +func TestUDPDestinationIsClosed(t *testing.T) { + t.Parallel() + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.SkipNow() + } + if runtime.GOOS != "windows" && os.Getuid() != 0 { + t.SkipNow() + } + destination, err := ping.ConnectUDPDestination( + context.Background(), + logger.NOP(), + nil, + netip.MustParseAddr("127.0.0.1"), + nil, + 30*time.Second, + ) + require.NoError(t, err) + defer destination.Close() + + require.False(t, destination.IsClosed()) + destination.Close() + require.True(t, destination.IsClosed()) +} + +func TestUDPDestinationWritePacketIPv4(t *testing.T) { + t.Parallel() + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.SkipNow() + } + if runtime.GOOS != "windows" && os.Getuid() != 0 { + t.SkipNow() + } + + // Start a local UDP listener to receive the probe + listener, err := net.ListenPacket("udp4", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + listenerAddr := listener.LocalAddr().(*net.UDPAddr) + + destination, err := ping.ConnectUDPDestination( + context.Background(), + logger.NOP(), + nil, + netip.MustParseAddr("127.0.0.1"), + nil, + 30*time.Second, + ) + require.NoError(t, err) + defer destination.Close() + + // Build an IPv4+UDP packet + payload := []byte("traceroute-probe") + pkt := buildIPv4UDPPacket(t, + netip.MustParseAddr("10.0.0.2"), + netip.MustParseAddr("127.0.0.1"), + 12345, + uint16(listenerAddr.Port), + 64, + payload, + ) + + err = destination.WritePacket(pkt) + require.NoError(t, err) + + // Read from the listener + recvBuf := make([]byte, 1500) + require.NoError(t, listener.SetReadDeadline(time.Now().Add(3*time.Second))) + n, _, readErr := listener.ReadFrom(recvBuf) + require.NoError(t, readErr) + require.Equal(t, payload, recvBuf[:n]) +} + +func TestUDPDestinationWritePacketInvalidHeader(t *testing.T) { + t.Parallel() + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.SkipNow() + } + if runtime.GOOS != "windows" && os.Getuid() != 0 { + t.SkipNow() + } + + destination, err := ping.ConnectUDPDestination( + context.Background(), + logger.NOP(), + nil, + netip.MustParseAddr("127.0.0.1"), + nil, + 30*time.Second, + ) + require.NoError(t, err) + defer destination.Close() + + t.Run("too-short", func(t *testing.T) { + pkt := buf.As([]byte{0x45, 0x00}).ToOwned() + err := destination.WritePacket(pkt) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid IPv4 header") + }) + + t.Run("not-udp", func(t *testing.T) { + // Build an IPv4 packet with ICMP protocol instead of UDP + pkt := buildIPv4ICMPPacket(t, + netip.MustParseAddr("10.0.0.2"), + netip.MustParseAddr("127.0.0.1"), + 64, + ) + err := destination.WritePacket(pkt) + require.Error(t, err) + require.Contains(t, err.Error(), "not a UDP packet") + }) +} + +func TestUDPDestinationMultiplePortProbes(t *testing.T) { + t.Parallel() + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.SkipNow() + } + if runtime.GOOS != "windows" && os.Getuid() != 0 { + t.SkipNow() + } + + // Start multiple listeners on different ports (simulating traceroute behavior) + const numProbes = 3 + listeners := make([]net.PacketConn, numProbes) + for i := range numProbes { + l, err := net.ListenPacket("udp4", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + listeners[i] = l + } + + destination, err := ping.ConnectUDPDestination( + context.Background(), + logger.NOP(), + nil, + netip.MustParseAddr("127.0.0.1"), + nil, + 30*time.Second, + ) + require.NoError(t, err) + defer destination.Close() + + // Send probes with different TTLs and destination ports (like mtr --udp) + for i, l := range listeners { + lAddr := l.LocalAddr().(*net.UDPAddr) + payload := []byte{byte(i + 1)} // simple payload with hop number + pkt := buildIPv4UDPPacket(t, + netip.MustParseAddr("10.0.0.2"), + netip.MustParseAddr("127.0.0.1"), + 60183, + uint16(lAddr.Port), + uint8(i+1), // TTL 1, 2, 3 + payload, + ) + err := destination.WritePacket(pkt) + require.NoError(t, err) + } + + // All probes should arrive at their respective listeners + for i, l := range listeners { + recvBuf := make([]byte, 64) + require.NoError(t, l.SetReadDeadline(time.Now().Add(3*time.Second))) + n, _, readErr := l.ReadFrom(recvBuf) + require.NoError(t, readErr, "probe %d should arrive", i+1) + require.Equal(t, []byte{byte(i + 1)}, recvBuf[:n]) + } +} + +// buildIPv4UDPPacket constructs a valid IPv4+UDP packet as a buf.Buffer. +func buildIPv4UDPPacket( + t *testing.T, + src, dst netip.Addr, + srcPort, dstPort uint16, + ttl uint8, + payload []byte, +) *buf.Buffer { + t.Helper() + + udpLen := uint16(header.UDPMinimumSize + len(payload)) + totalLen := uint16(header.IPv4MinimumSize) + udpLen + + packet := make([]byte, totalLen) + + // Encode IPv4 header + ipHdr := header.IPv4(packet) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: totalLen, + TTL: ttl, + Protocol: uint8(header.UDPProtocolNumber), + SrcAddr: src, + DstAddr: dst, + }) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + + // Encode UDP header + udpHdr := header.UDP(packet[header.IPv4MinimumSize:]) + udpHdr.Encode(&header.UDPFields{ + SrcPort: srcPort, + DstPort: dstPort, + Length: udpLen, + }) + copy(packet[header.IPv4MinimumSize+header.UDPMinimumSize:], payload) + + return buf.As(packet).ToOwned() +} + +// buildIPv4ICMPPacket constructs a minimal IPv4+ICMP packet (for testing non-UDP rejection). +func buildIPv4ICMPPacket( + t *testing.T, + src, dst netip.Addr, + ttl uint8, +) *buf.Buffer { + t.Helper() + + icmpLen := header.ICMPv4MinimumSize + totalLen := uint16(header.IPv4MinimumSize + icmpLen) + + packet := make([]byte, totalLen) + + ipHdr := header.IPv4(packet) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: totalLen, + TTL: ttl, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: src, + DstAddr: dst, + }) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + + icmpHdr := header.ICMPv4(packet[header.IPv4MinimumSize:]) + icmpHdr.SetType(header.ICMPv4Echo) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + + return buf.As(packet).ToOwned() +} + +// buildICMPv4ErrorWithUDP constructs an ICMP Time Exceeded error containing an inner IPv4+UDP packet. +// This simulates what a router sends when TTL expires on a UDP probe. +func buildICMPv4ErrorWithUDP( + t *testing.T, + outerSrc, outerDst netip.Addr, + innerSrc, innerDst netip.Addr, + innerSrcPort, innerDstPort uint16, + icmpType header.ICMPv4Type, +) []byte { + t.Helper() + + // Inner: IPv4 + UDP header (minimum, no payload as per RFC) + innerIPLen := header.IPv4MinimumSize + header.UDPMinimumSize + // ICMP header (8 bytes) + inner IP+UDP + icmpPayloadLen := header.ICMPv4MinimumSize + innerIPLen + totalLen := header.IPv4MinimumSize + icmpPayloadLen + + packet := make([]byte, totalLen) + + // Outer IPv4 + outerIP := header.IPv4(packet) + outerIP.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + TTL: 64, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: outerSrc, + DstAddr: outerDst, + }) + outerIP.SetChecksum(^outerIP.CalculateChecksum()) + + // ICMP header + icmpOffset := header.IPv4MinimumSize + icmpHdr := header.ICMPv4(packet[icmpOffset:]) + icmpHdr.SetType(icmpType) + icmpHdr.SetCode(0) + + // Inner IPv4 (inside ICMP error) + innerIPOffset := icmpOffset + header.ICMPv4MinimumSize + innerIP := header.IPv4(packet[innerIPOffset:]) + innerIP.Encode(&header.IPv4Fields{ + TotalLength: uint16(innerIPLen), + TTL: 1, + Protocol: uint8(header.UDPProtocolNumber), + SrcAddr: innerSrc, + DstAddr: innerDst, + }) + innerIP.SetChecksum(^innerIP.CalculateChecksum()) + + // Inner UDP header + innerUDPOffset := innerIPOffset + header.IPv4MinimumSize + innerUDP := header.UDP(packet[innerUDPOffset:]) + binary.BigEndian.PutUint16(packet[innerUDPOffset:], innerSrcPort) + binary.BigEndian.PutUint16(packet[innerUDPOffset+2:], innerDstPort) + _ = innerUDP + binary.BigEndian.PutUint16(packet[innerUDPOffset+4:], header.UDPMinimumSize) // length + + // Calculate ICMP checksum + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + + return packet +} + +func TestBuildIPv4UDPPacket(t *testing.T) { + t.Parallel() + payload := []byte("hello") + pkt := buildIPv4UDPPacket(t, + netip.MustParseAddr("10.0.0.2"), + netip.MustParseAddr("1.1.1.1"), + 12345, 33434, 5, payload, + ) + defer pkt.Release() + + ipHdr := header.IPv4(pkt.Bytes()) + require.True(t, ipHdr.IsValid(pkt.Len())) + require.Equal(t, netip.MustParseAddr("10.0.0.2"), ipHdr.SourceAddr()) + require.Equal(t, netip.MustParseAddr("1.1.1.1"), ipHdr.DestinationAddr()) + require.Equal(t, uint8(5), ipHdr.TTL()) + require.Equal(t, header.UDPProtocolNumber, ipHdr.TransportProtocol()) + + udpHdr := header.UDP(ipHdr.Payload()) + require.Equal(t, uint16(12345), udpHdr.SourcePort()) + require.Equal(t, uint16(33434), udpHdr.DestinationPort()) + require.Equal(t, payload, udpHdr.Payload()) +} + +func TestBuildICMPv4ErrorWithUDP(t *testing.T) { + t.Parallel() + packet := buildICMPv4ErrorWithUDP(t, + netip.MustParseAddr("192.168.1.1"), // router + netip.MustParseAddr("192.168.10.254"), // server + netip.MustParseAddr("192.168.10.254"), // inner src (server's real IP) + netip.MustParseAddr("1.1.1.1"), // inner dst + 23674, // inner src port (kernel port) + 33434, // inner dst port + header.ICMPv4TimeExceeded, + ) + + outerIP := header.IPv4(packet) + require.True(t, outerIP.IsValid(len(packet))) + require.Equal(t, header.ICMPv4ProtocolNumber, outerIP.TransportProtocol()) + + icmpHdr := header.ICMPv4(outerIP.Payload()) + require.Equal(t, header.ICMPv4TimeExceeded, icmpHdr.Type()) + + innerIP := header.IPv4(outerIP.Payload()[header.ICMPv4MinimumSize:]) + require.True(t, innerIP.IsValid(len(outerIP.Payload())-header.ICMPv4MinimumSize)) + require.Equal(t, header.UDPProtocolNumber, innerIP.TransportProtocol()) + + innerUDP := header.UDP(innerIP.Payload()) + require.Equal(t, uint16(23674), innerUDP.SourcePort()) + require.Equal(t, uint16(33434), innerUDP.DestinationPort()) +} + +func TestICMPv4ErrorRewriteAddresses(t *testing.T) { + t.Parallel() + + // Simulate the rewrite logic from loopReadErrors (IPv4 path) + // Scenario: ICMP TTL Exceeded from router 192.168.1.1 → server 192.168.10.254 + // Inner packet: server 192.168.10.254:23674 → 1.1.1.1:33434 + // After rewrite: outer dst → 10.0.0.2, inner src → 10.0.0.2, inner UDP src port → 60183 + packet := buildICMPv4ErrorWithUDP(t, + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("192.168.10.254"), + netip.MustParseAddr("192.168.10.254"), + netip.MustParseAddr("1.1.1.1"), + 23674, 33434, + header.ICMPv4TimeExceeded, + ) + + originalSource := netip.MustParseAddr("10.0.0.2") + var originalSourcePort uint16 = 60183 + var localPort uint16 = 23674 + + // Apply the same rewrite logic as loopReadErrors + outerIP := header.IPv4(packet) + icmpHdr := header.ICMPv4(outerIP.Payload()) + innerIP := header.IPv4(outerIP.Payload()[header.ICMPv4MinimumSize:]) + innerUDP := header.UDP(innerIP.Payload()) + + outerIP.SetDestinationAddr(originalSource) + innerIP.SetSourceAddr(originalSource) + if originalSourcePort != 0 && localPort != 0 && innerUDP.SourcePort() == localPort { + innerUDP.SetSourcePort(originalSourcePort) + } + innerIP.SetChecksum(^innerIP.CalculateChecksum()) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) + outerIP.SetChecksum(^outerIP.CalculateChecksum()) + + // Verify rewrites + require.Equal(t, netip.MustParseAddr("10.0.0.2"), outerIP.DestinationAddr(), + "outer dst should be rewritten to client tunnel IP") + require.Equal(t, netip.MustParseAddr("192.168.1.1"), outerIP.SourceAddr(), + "outer src should remain as router IP") + require.Equal(t, netip.MustParseAddr("10.0.0.2"), innerIP.SourceAddr(), + "inner src should be rewritten to client tunnel IP") + require.Equal(t, netip.MustParseAddr("1.1.1.1"), innerIP.DestinationAddr(), + "inner dst should remain as original destination") + require.Equal(t, uint16(60183), innerUDP.SourcePort(), + "inner UDP src port should be rewritten from kernel port to client port") + require.Equal(t, uint16(33434), innerUDP.DestinationPort(), + "inner UDP dst port should remain unchanged") +} + +func TestICMPv4ErrorNoPortRewriteWhenMismatch(t *testing.T) { + t.Parallel() + + // When inner src port doesn't match localPort, no rewrite should happen + packet := buildICMPv4ErrorWithUDP(t, + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("192.168.10.254"), + netip.MustParseAddr("192.168.10.254"), + netip.MustParseAddr("1.1.1.1"), + 99999%65536, 33434, // different port from localPort + header.ICMPv4TimeExceeded, + ) + + var localPort uint16 = 23674 + var originalSourcePort uint16 = 60183 + + outerIP := header.IPv4(packet) + innerIP := header.IPv4(outerIP.Payload()[header.ICMPv4MinimumSize:]) + innerUDP := header.UDP(innerIP.Payload()) + + originalInnerSrcPort := innerUDP.SourcePort() + + // Apply conditional rewrite + if originalSourcePort != 0 && localPort != 0 && innerUDP.SourcePort() == localPort { + innerUDP.SetSourcePort(originalSourcePort) + } + + require.Equal(t, originalInnerSrcPort, innerUDP.SourcePort(), + "inner UDP src port should NOT be rewritten when it doesn't match localPort") +} diff --git a/ping/udp_socket_other.go b/ping/udp_socket_other.go new file mode 100644 index 00000000..a39ad18b --- /dev/null +++ b/ping/udp_socket_other.go @@ -0,0 +1,19 @@ +//go:build !unix + +package ping + +import ( + "net" + "net/netip" + + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" +) + +func connectUDP(controlFunc control.Func, destination netip.Addr) (*net.UDPConn, error) { + return nil, E.New("UDP traceroute not supported on this platform") +} + +func setUDPTTL(conn *net.UDPConn, isIPv6 bool, ttl uint8) error { + return nil +} diff --git a/ping/udp_socket_unix.go b/ping/udp_socket_unix.go new file mode 100644 index 00000000..136dfebf --- /dev/null +++ b/ping/udp_socket_unix.go @@ -0,0 +1,76 @@ +//go:build unix + +package ping + +import ( + "net" + "net/netip" + "syscall" + + "os" + + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + + "golang.org/x/sys/unix" +) + +func connectUDP(controlFunc control.Func, destination netip.Addr) (*net.UDPConn, error) { + var ( + network string + fd int + err error + ) + if destination.Is4() { + network = "udp4" + fd, err = unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP) + } else { + network = "udp6" + fd, err = unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP) + } + if err != nil { + return nil, E.Cause(err, "socket()") + } + + file := os.NewFile(uintptr(fd), "udp-traceroute") + defer file.Close() + + if controlFunc != nil { + var syscallConn syscall.RawConn + syscallConn, err = file.SyscallConn() + if err != nil { + return nil, err + } + err = controlFunc(network, destination.String(), syscallConn) + if err != nil { + return nil, err + } + } + + var bindAddress netip.Addr + if !destination.Is6() { + bindAddress = netip.AddrFrom4([4]byte{}) + } else { + bindAddress = netip.AddrFrom16([16]byte{}) + } + + err = unix.Bind(fd, M.AddrPortToSockaddr(netip.AddrPortFrom(bindAddress, 0))) + if err != nil { + return nil, E.Cause(err, "bind()") + } + + packetConn, err := net.FilePacketConn(file) + if err != nil { + return nil, err + } + return packetConn.(*net.UDPConn), nil +} + +func setUDPTTL(conn *net.UDPConn, isIPv6 bool, ttl uint8) error { + rawConn, err := conn.SyscallConn() + if err != nil { + return err + } + return setsockoptTTL(rawConn, isIPv6, ttl) +} diff --git a/redirect_linux.go b/redirect_linux.go index e9c892c8..d575d192 100644 --- a/redirect_linux.go +++ b/redirect_linux.go @@ -44,6 +44,8 @@ type autoRedirect struct { nfqueueEnabled bool redirectRouteTableIndex int redirectInterfaces []control.Interface + dockerFirewallMonitor *nftables.Monitor + dockerFirewallDone chan struct{} } func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) { diff --git a/redirect_nftables.go b/redirect_nftables.go index 266bbe91..f17e1f36 100644 --- a/redirect_nftables.go +++ b/redirect_nftables.go @@ -283,11 +283,15 @@ func (r *autoRedirect) setupNFTables() error { if err != nil { return E.Cause(err, "configure openwrt firewall4") } - err = nft.Flush() if err != nil { return E.Cause(err, "flush nftables") } + r.startDockerFirewallMonitor() + err = r.configureDockerFirewall(false) + if err != nil && r.logger != nil { + r.logger.Warn("configure docker firewall: ", err) + } r.networkListener = r.networkMonitor.RegisterCallback(func() { err = r.nftablesUpdateLocalAddressSet() @@ -361,6 +365,7 @@ func (r *autoRedirect) cleanupNFTables() { if r.networkListener != nil { r.networkMonitor.UnregisterCallback(r.networkListener) } + r.stopDockerFirewallMonitor() nft, err := nftables.New() if err != nil { return @@ -372,6 +377,10 @@ func (r *autoRedirect) cleanupNFTables() { _ = r.configureOpenWRTFirewall4(nft, true) _ = nft.Flush() _ = nft.CloseLasting() + err = r.configureDockerFirewall(true) + if err != nil && r.logger != nil { + r.logger.Warn("cleanup docker firewall: ", err) + } } func (r *autoRedirect) nftablesCreatePreMatchChains(nft *nftables.Conn, table *nftables.Table) error { diff --git a/redirect_nftables_docker.go b/redirect_nftables_docker.go new file mode 100644 index 00000000..0c6fc896 --- /dev/null +++ b/redirect_nftables_docker.go @@ -0,0 +1,266 @@ +//go:build linux + +package tun + +import ( + "bytes" + "strings" + + "github.com/sagernet/nftables" + "github.com/sagernet/nftables/expr" + "github.com/sagernet/nftables/userdata" + E "github.com/sagernet/sing/common/exceptions" +) + +const ( + nftablesDockerFilterTable = "filter" + nftablesDockerUserChain = "DOCKER-USER" +) + +func (r *autoRedirect) startDockerFirewallMonitor() { + if r.dockerFirewallMonitor != nil { + return + } + doneCh := make(chan struct{}) + r.dockerFirewallDone = doneCh + monitor := nftables.NewMonitor( + nftables.WithMonitorAction(nftables.MonitorActionAny), + nftables.WithMonitorObject(nftables.MonitorObjectRuleset), + nftables.WithMonitorEventBuffer(16), + ) + nft, err := nftables.New() + if err != nil { + if r.logger != nil { + r.logger.Warn("create nftables monitor connection: ", err) + } + close(doneCh) + r.dockerFirewallDone = nil + return + } + events, err := nft.AddGenerationalMonitor(monitor) + _ = nft.CloseLasting() + if err != nil { + if r.logger != nil { + r.logger.Warn("start nftables monitor: ", err) + } + close(doneCh) + r.dockerFirewallDone = nil + return + } + r.dockerFirewallMonitor = monitor + go r.loopDockerFirewallMonitor(events, doneCh) +} + +func (r *autoRedirect) stopDockerFirewallMonitor() { + if r.dockerFirewallMonitor == nil { + return + } + _ = r.dockerFirewallMonitor.Close() + <-r.dockerFirewallDone + r.dockerFirewallMonitor = nil + r.dockerFirewallDone = nil +} + +func (r *autoRedirect) loopDockerFirewallMonitor(events <-chan *nftables.MonitorEvents, doneCh chan<- struct{}) { + defer close(doneCh) + for monitorEvents := range events { + if monitorEvents != nil && monitorEvents.GeneratedBy != nil && monitorEvents.GeneratedBy.Error != nil { + if r.logger != nil { + r.logger.Warn("nftables monitor closed: ", monitorEvents.GeneratedBy.Error) + } + return + } + if !nftablesDockerFirewallEventsRelevant(monitorEvents) { + continue + } + err := r.configureDockerFirewall(false) + if err != nil && r.logger != nil { + r.logger.Warn("update docker firewall: ", err) + } + } +} + +func (r *autoRedirect) configureDockerFirewall(cleanup bool) error { + nft, err := nftables.New() + if err != nil { + return E.Cause(err, "create nftables connection") + } + defer nft.CloseLasting() + + err = r.configureDockerFirewallWithConn(nft, cleanup) + if err != nil { + return err + } + return nft.Flush() +} + +func (r *autoRedirect) configureDockerFirewallWithConn(nft *nftables.Conn, cleanup bool) error { + var err error + if r.enableIPv4 { + err = E.Errors(err, r.configureDockerFirewallForFamily(nft, nftables.TableFamilyIPv4, cleanup)) + } + if r.enableIPv6 { + err = E.Errors(err, r.configureDockerFirewallForFamily(nft, nftables.TableFamilyIPv6, cleanup)) + } + return err +} + +func (r *autoRedirect) configureDockerFirewallForFamily(nft *nftables.Conn, family nftables.TableFamily, cleanup bool) error { + table, chain, loaded, err := nftablesLoadDockerUserChain(nft, family) + if err != nil || !loaded { + return err + } + err = r.configureDockerFirewallRules(nft, table, chain, cleanup) + return err +} + +func (r *autoRedirect) configureDockerFirewallRules(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain, cleanup bool) error { + rules, err := nft.GetRules(table, chain) + if err != nil { + return E.Cause(err, "list docker user rules") + } + if cleanup { + return r.cleanupDockerFirewallRules(nft, rules) + } + return r.reconcileDockerFirewallRules(nft, table, chain, rules) +} + +func nftablesLoadDockerUserChain(nft *nftables.Conn, family nftables.TableFamily) (*nftables.Table, *nftables.Chain, bool, error) { + table, err := nft.ListTableOfFamily(nftablesDockerFilterTable, family) + if err != nil { + return nil, nil, false, nil + } + chain, err := nft.ListChain(table, nftablesDockerUserChain) + if err != nil { + return nil, nil, false, nil + } + return table, chain, true, nil +} + +func nftablesDockerFirewallEventsRelevant(events *nftables.MonitorEvents) bool { + if events == nil { + return false + } + for _, event := range events.Changes { + if nftablesDockerFirewallEventRelevant(event) { + return true + } + } + return false +} + +func nftablesDockerFirewallEventRelevant(event *nftables.MonitorEvent) bool { + if event == nil || event.Error != nil { + return false + } + switch data := event.Data.(type) { + case *nftables.Table: + return nftablesIsDockerFirewallTable(data) + case *nftables.Chain: + return data.Name == nftablesDockerUserChain && nftablesIsDockerFirewallTable(data.Table) + case *nftables.Rule: + return data.Chain != nil && data.Chain.Name == nftablesDockerUserChain && nftablesIsDockerFirewallTable(data.Table) + default: + return false + } +} + +func nftablesIsDockerFirewallTable(table *nftables.Table) bool { + return table != nil && + table.Name == nftablesDockerFilterTable && + (table.Family == nftables.TableFamilyIPv4 || table.Family == nftables.TableFamilyIPv6) +} + +func (r *autoRedirect) cleanupDockerFirewallRules(nft *nftables.Conn, rules []*nftables.Rule) error { + var deleteErr error + for _, rule := range rules { + if r.nftablesIsDockerCompatibilityRule(rule) { + deleteErr = E.Errors(deleteErr, nft.DelRule(rule)) + } + } + return deleteErr +} + +func (r *autoRedirect) reconcileDockerFirewallRules(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain, rules []*nftables.Rule) error { + outputComment := r.nftablesDockerCompatibilityComment("output to tun") + inputComment := r.nftablesDockerCompatibilityComment("input from tun") + var hasOutputRule bool + var hasInputRule bool + var deleteErr error + for _, rule := range rules { + if nftablesDockerCompatibilityRuleMatches(rule, r.tunOptions.Name, expr.MetaKeyOIFNAME, outputComment) && !hasOutputRule { + hasOutputRule = true + } else if nftablesDockerCompatibilityRuleMatches(rule, r.tunOptions.Name, expr.MetaKeyIIFNAME, inputComment) && !hasInputRule { + hasInputRule = true + } else if r.nftablesIsDockerCompatibilityRule(rule) { + deleteErr = E.Errors(deleteErr, nft.DelRule(rule)) + } + } + if deleteErr != nil { + return deleteErr + } + if !hasOutputRule { + nft.InsertRule(nftablesDockerCompatibilityRule(table, chain, r.tunOptions.Name, expr.MetaKeyOIFNAME, outputComment)) + } + if !hasInputRule { + nft.InsertRule(nftablesDockerCompatibilityRule(table, chain, r.tunOptions.Name, expr.MetaKeyIIFNAME, inputComment)) + } + return nil +} + +func nftablesDockerCompatibilityRule(table *nftables.Table, chain *nftables.Chain, ifName string, ifNameKey expr.MetaKey, comment string) *nftables.Rule { + return &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: ifNameKey, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: nftablesIfname(ifName), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + UserData: userdata.AppendString(nil, userdata.TypeComment, comment), + } +} + +func nftablesDockerCompatibilityRuleMatches(rule *nftables.Rule, ifName string, ifNameKey expr.MetaKey, comment string) bool { + ruleComment, loaded := userdata.GetString(rule.UserData, userdata.TypeComment) + if !loaded || ruleComment != comment || len(rule.Exprs) != 4 { + return false + } + meta, loaded := rule.Exprs[0].(*expr.Meta) + if !loaded || meta.Key != ifNameKey || meta.Register != 1 { + return false + } + cmp, loaded := rule.Exprs[1].(*expr.Cmp) + if !loaded || cmp.Op != expr.CmpOpEq || cmp.Register != 1 || !bytes.Equal(cmp.Data, nftablesIfname(ifName)) { + return false + } + _, loaded = rule.Exprs[2].(*expr.Counter) + if !loaded { + return false + } + verdict, loaded := rule.Exprs[3].(*expr.Verdict) + return loaded && verdict.Kind == expr.VerdictAccept +} + +func (r *autoRedirect) nftablesIsDockerCompatibilityRule(rule *nftables.Rule) bool { + comment, loaded := userdata.GetString(rule.UserData, userdata.TypeComment) + return loaded && strings.HasPrefix(comment, r.nftablesDockerCompatibilityCommentPrefix()) +} + +func (r *autoRedirect) nftablesDockerCompatibilityComment(direction string) string { + return r.nftablesDockerCompatibilityCommentPrefix() + direction +} + +func (r *autoRedirect) nftablesDockerCompatibilityCommentPrefix() string { + return "!" + r.tableName + ": Docker compatibility " +} diff --git a/redirect_nftables_rules.go b/redirect_nftables_rules.go index 1b9af3c2..1ef5c19b 100644 --- a/redirect_nftables_rules.go +++ b/redirect_nftables_rules.go @@ -3,6 +3,7 @@ package tun import ( + "net" "net/netip" _ "unsafe" @@ -11,9 +12,8 @@ import ( "github.com/sagernet/nftables/expr" "github.com/sagernet/nftables/userdata" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/ranges" - E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/ranges" "golang.org/x/exp/slices" "golang.org/x/sys/unix" @@ -377,6 +377,149 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft }) } } + if len(r.tunOptions.IncludeMACAddress) > 0 { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFTYPE, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint16(unix.ARPHRD_ETHER), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + if len(r.tunOptions.IncludeMACAddress) > 1 { + includeMACSet := &nftables.Set{ + Table: table, + Anonymous: true, + Constant: true, + KeyType: nftables.TypeEtherAddr, + } + err := nft.AddSet(includeMACSet, common.Map(r.tunOptions.IncludeMACAddress, func(it net.HardwareAddr) nftables.SetElement { + return nftables.SetElement{ + Key: []byte(it), + } + })) + if err != nil { + return err + } + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + Base: expr.PayloadBaseLLHeader, + Offset: 6, + Len: 6, + }, + &expr.Lookup{ + SourceRegister: 1, + SetID: includeMACSet.ID, + SetName: includeMACSet.Name, + Invert: true, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } else { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + Base: expr.PayloadBaseLLHeader, + Offset: 6, + Len: 6, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte(r.tunOptions.IncludeMACAddress[0]), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } + } + if len(r.tunOptions.ExcludeMACAddress) > 0 { + if len(r.tunOptions.ExcludeMACAddress) > 1 { + excludeMACSet := &nftables.Set{ + Table: table, + Anonymous: true, + Constant: true, + KeyType: nftables.TypeEtherAddr, + } + err := nft.AddSet(excludeMACSet, common.Map(r.tunOptions.ExcludeMACAddress, func(it net.HardwareAddr) nftables.SetElement { + return nftables.SetElement{ + Key: []byte(it), + } + })) + if err != nil { + return err + } + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + Base: expr.PayloadBaseLLHeader, + Offset: 6, + Len: 6, + }, + &expr.Lookup{ + SourceRegister: 1, + SetID: excludeMACSet.ID, + SetName: excludeMACSet.Name, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } else { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + Base: expr.PayloadBaseLLHeader, + Offset: 6, + Len: 6, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte(r.tunOptions.ExcludeMACAddress[0]), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } + } } else { if len(r.tunOptions.IncludeUID) > 0 { if len(r.tunOptions.IncludeUID) > 1 || r.tunOptions.IncludeUID[0].Start != r.tunOptions.IncludeUID[0].End { @@ -532,7 +675,7 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft nftablesCreateExcludeDestinationIPSet(nft, table, chain, inet6RouteExcludeAddress.ID, inet6RouteExcludeAddress.Name, nftables.TableFamilyIPv6, false) } - if !r.tunOptions.EXP_DisableDNSHijack && ((chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeNAT) || + if r.tunOptions.DNSModeOrDefault() == DNSModeHijack && ((chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeNAT) || (r.tunOptions.AutoRedirectMarkMode && chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeNAT)) { if r.enableIPv4 { err := r.nftablesCreateDNSHijackRulesForFamily(nft, table, chain, nftables.TableFamilyIPv4, 5, "inet4_local_address_set") @@ -854,23 +997,19 @@ func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily( if err != nil { return E.Cause(err, "add dns protocol set") } - dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool { - return it.Is4() == (family == nftables.TableFamilyIPv4) - }) - if !dnsServer.IsValid() { - if family == nftables.TableFamilyIPv4 { - if HasNextAddress(r.tunOptions.Inet4Address[0], 1) { - dnsServer = r.tunOptions.Inet4Address[0].Addr().Next() - } - } else { - if HasNextAddress(r.tunOptions.Inet6Address[0], 1) { - dnsServer = r.tunOptions.Inet6Address[0].Addr().Next() - } - } + var dnsServers []netip.Addr + if family == nftables.TableFamilyIPv4 { + dnsServers, err = r.tunOptions.Inet4DNSAddress() + } else { + dnsServers, err = r.tunOptions.Inet6DNSAddress() + } + if err != nil { + return err } - if !dnsServer.IsValid() { + if len(dnsServers) == 0 { return nil } + dnsServer := dnsServers[0] exprs := []expr.Any{ &expr.Meta{ Key: expr.MetaKeyNFPROTO, diff --git a/redirect_route_linux.go b/redirect_route_linux.go index db79cac6..7e0868c6 100644 --- a/redirect_route_linux.go +++ b/redirect_route_linux.go @@ -8,9 +8,9 @@ import ( "net/netip" "github.com/sagernet/netlink" - E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" "golang.org/x/sys/unix" ) diff --git a/stack.go b/stack.go index 7c34c798..5cdc212d 100644 --- a/stack.go +++ b/stack.go @@ -33,6 +33,7 @@ type StackOptions struct { ForwarderBindInterface bool IncludeAllNetworks bool InterfaceFinder control.InterfaceFinder + MaxTracerouteHopLimit uint8 } func NewStack( diff --git a/stack_gvisor.go b/stack_gvisor.go index ca7dcdc5..94de291e 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -27,18 +27,19 @@ const WithGVisor = true const DefaultNIC tcpip.NICID = 1 type GVisor struct { - ctx context.Context - tun GVisorTun - inet4Address netip.Addr - inet6Address netip.Addr - inet4LoopbackAddress []netip.Addr - inet6LoopbackAddress []netip.Addr - udpTimeout time.Duration - broadcastAddr netip.Addr - handler Handler - logger logger.Logger - stack *stack.Stack - endpoint stack.LinkEndpoint + ctx context.Context + tun GVisorTun + inet4Address netip.Addr + inet6Address netip.Addr + inet4LoopbackAddress []netip.Addr + inet6LoopbackAddress []netip.Addr + udpTimeout time.Duration + maxTracerouteHopLimit uint8 + broadcastAddr netip.Addr + handler Handler + logger logger.Logger + stack *stack.Stack + endpoint stack.LinkEndpoint } type GVisorTun interface { @@ -67,16 +68,17 @@ func NewGVisor( } gStack := &GVisor{ - ctx: options.Context, - tun: gTun, - inet4Address: inet4Address, - inet6Address: inet6Address, - inet4LoopbackAddress: options.TunOptions.Inet4LoopbackAddress, - inet6LoopbackAddress: options.TunOptions.Inet6LoopbackAddress, - udpTimeout: options.UDPTimeout, - broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), - handler: options.Handler, - logger: options.Logger, + ctx: options.Context, + tun: gTun, + inet4Address: inet4Address, + inet6Address: inet6Address, + inet4LoopbackAddress: options.TunOptions.Inet4LoopbackAddress, + inet6LoopbackAddress: options.TunOptions.Inet6LoopbackAddress, + udpTimeout: options.UDPTimeout, + maxTracerouteHopLimit: options.MaxTracerouteHopLimit, + broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), + handler: options.Handler, + logger: options.Logger, } return gStack, nil } @@ -91,10 +93,12 @@ func (t *GVisor) Start() error { if err != nil { return err } - ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarderWithLoopback(t.ctx, ipStack, t.handler, t.inet4LoopbackAddress, t.inet6LoopbackAddress, t.tun).HandlePacket) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) icmpForwarder := NewICMPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout) icmpForwarder.SetLocalAddresses(t.inet4Address, t.inet6Address) + tcpHandler := NewTCPForwarderWithLoopback(t.ctx, ipStack, t.handler, t.inet4LoopbackAddress, t.inet6LoopbackAddress, t.tun).HandlePacket + udpHandler := NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket + ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, WrapTCPHandlerWithDirectRoute(ipStack, t.handler, icmpForwarder, t.udpTimeout, t.maxTracerouteHopLimit, netip.Addr{}, netip.Addr{}, tcpHandler)) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, WrapUDPHandlerWithDirectRoute(ipStack, t.handler, icmpForwarder, t.udpTimeout, t.maxTracerouteHopLimit, netip.Addr{}, netip.Addr{}, udpHandler)) ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket) ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket) t.stack = ipStack diff --git a/stack_gvisor_direct_route_wrapper.go b/stack_gvisor_direct_route_wrapper.go new file mode 100644 index 00000000..f0f31599 --- /dev/null +++ b/stack_gvisor_direct_route_wrapper.go @@ -0,0 +1,367 @@ +//go:build with_gvisor + +package tun + +import ( + "net/netip" + "time" + + "github.com/sagernet/gvisor/pkg/buffer" + "github.com/sagernet/gvisor/pkg/tcpip" + "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" + "github.com/sagernet/gvisor/pkg/tcpip/header" + "github.com/sagernet/gvisor/pkg/tcpip/stack" + buf "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +// WrapTCPHandlerWithDirectRoute wraps an existing gVisor TCP transport handler +// with DirectRoute support for traceroute. Low-TTL SYN packets are handled via +// DirectRoute (sending ICMP Time Exceeded), while all other packets are passed +// to the original handler. +// +// localAddr4/localAddr6 are the node's own addresses used as the source IP +// in ICMP Time Exceeded messages. If zero, TTL=1 packets fall through to +// the original handler. +func WrapTCPHandlerWithDirectRoute( + ipStack *stack.Stack, + handler Handler, + icmpForwarder *ICMPForwarder, + timeout time.Duration, + maxTracerouteHopLimit uint8, + localAddr4 netip.Addr, + localAddr6 netip.Addr, + original func(stack.TransportEndpointID, *stack.PacketBuffer) bool, +) func(stack.TransportEndpointID, *stack.PacketBuffer) bool { + if maxTracerouteHopLimit == 0 { + maxTracerouteHopLimit = defaultMaxTracerouteHopLimit + } + w := &directRouteTCPWrapper{ + stack: ipStack, + handler: handler, + icmpForwarder: icmpForwarder, + directRouteMapping: NewDirectRouteMapping(timeout), + maxTracerouteHopLimit: maxTracerouteHopLimit, + localAddr4: localAddr4, + localAddr6: localAddr6, + original: original, + } + return w.HandlePacket +} + +type directRouteTCPWrapper struct { + stack *stack.Stack + handler Handler + icmpForwarder *ICMPForwarder + directRouteMapping *DirectRouteMapping + maxTracerouteHopLimit uint8 + localAddr4 netip.Addr + localAddr6 netip.Addr + original func(stack.TransportEndpointID, *stack.PacketBuffer) bool +} + +func (w *directRouteTCPWrapper) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + // Only intercept SYN packets (no ACK) with low TTL/HopLimit. + tcpHdr := header.TCP(pkt.TransportHeader().Slice()) + if tcpHdr.Flags()&header.TCPFlagSyn == 0 || tcpHdr.Flags()&header.TCPFlagAck != 0 { + return w.original(id, pkt) + } + + ttlAct := checkTracerouteTTL(pkt, w.maxTracerouteHopLimit, w.localAddr4, w.localAddr6) + if ttlAct == ttlActionPass { + return w.original(id, pkt) + } + if ttlAct == ttlActionTLE { + _ = gWriteTimeExceeded(w.stack, pkt, w.localAddr4, w.localAddr6) + return true + } + + // ttlActionDecrement or ttlActionForward: forward via DirectRoute + source := M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort) + destination := M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort) + + var sourceNetwork tcpip.NetworkProtocolNumber + if source.Addr.Is4() { + sourceNetwork = header.IPv4ProtocolNumber + } else { + sourceNetwork = header.IPv6ProtocolNumber + } + backWriter := &ICMPBackWriter{ + stack: w.stack, + packet: pkt, + source: AddressFromAddr(source.Addr), + sourceNetwork: sourceNetwork, + } + if action, err := w.directRouteMapping.Lookup( + DirectRouteSession{Source: source.Addr, Destination: destination.Addr}, + func(timeout time.Duration) (DirectRouteDestination, error) { + return w.handler.PrepareConnection(N.NetworkTCP, source, destination, backWriter, timeout) + }, + ); err == nil && action != nil { + if w.icmpForwarder != nil { + w.icmpForwarder.registerSession(uint8(header.TCPProtocolNumber), destination.Addr, source.Port, source.Addr, backWriter) + } + if ttlAct == ttlActionDecrement { + _ = directRouteWritePacketWithDecrementedTTL(action, pkt) + } else { + _ = directRouteWritePacket(action, pkt) + } + return true + } + + return w.original(id, pkt) +} + +// WrapUDPHandlerWithDirectRoute wraps an existing gVisor UDP transport handler +// with DirectRoute support for traceroute. Low-TTL packets are handled via +// DirectRoute (sending ICMP Time Exceeded), while all other packets are passed +// to the original handler. +func WrapUDPHandlerWithDirectRoute( + ipStack *stack.Stack, + handler Handler, + icmpForwarder *ICMPForwarder, + timeout time.Duration, + maxTracerouteHopLimit uint8, + localAddr4 netip.Addr, + localAddr6 netip.Addr, + original func(stack.TransportEndpointID, *stack.PacketBuffer) bool, +) func(stack.TransportEndpointID, *stack.PacketBuffer) bool { + if maxTracerouteHopLimit == 0 { + maxTracerouteHopLimit = defaultMaxTracerouteHopLimit + } + w := &directRouteUDPWrapper{ + stack: ipStack, + handler: handler, + icmpForwarder: icmpForwarder, + directRouteMapping: NewDirectRouteMapping(timeout), + maxTracerouteHopLimit: maxTracerouteHopLimit, + localAddr4: localAddr4, + localAddr6: localAddr6, + original: original, + } + return w.HandlePacket +} + +type directRouteUDPWrapper struct { + stack *stack.Stack + handler Handler + icmpForwarder *ICMPForwarder + directRouteMapping *DirectRouteMapping + maxTracerouteHopLimit uint8 + localAddr4 netip.Addr + localAddr6 netip.Addr + original func(stack.TransportEndpointID, *stack.PacketBuffer) bool +} + +func (w *directRouteUDPWrapper) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + ttlAct := checkTracerouteTTL(pkt, w.maxTracerouteHopLimit, w.localAddr4, w.localAddr6) + if ttlAct == ttlActionPass { + return w.original(id, pkt) + } + if ttlAct == ttlActionTLE { + _ = gWriteTimeExceeded(w.stack, pkt, w.localAddr4, w.localAddr6) + return true + } + + // ttlActionDecrement or ttlActionForward: forward via DirectRoute + source := M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort) + destination := M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort) + + var sourceNetwork tcpip.NetworkProtocolNumber + if source.Addr.Is4() { + sourceNetwork = header.IPv4ProtocolNumber + } else { + sourceNetwork = header.IPv6ProtocolNumber + } + backWriter := &ICMPBackWriter{ + stack: w.stack, + packet: pkt, + source: AddressFromAddr(source.Addr), + sourceNetwork: sourceNetwork, + } + if action, err := w.directRouteMapping.Lookup( + DirectRouteSession{Source: source.Addr, Destination: destination.Addr}, + func(timeout time.Duration) (DirectRouteDestination, error) { + return w.handler.PrepareConnection(N.NetworkUDP, source, destination, backWriter, timeout) + }, + ); err == nil && action != nil { + if w.icmpForwarder != nil { + w.icmpForwarder.registerSession(uint8(header.UDPProtocolNumber), destination.Addr, source.Port, source.Addr, backWriter) + } + if ttlAct == ttlActionDecrement { + _ = directRouteWritePacketWithDecrementedTTL(action, pkt) + } else { + _ = directRouteWritePacket(action, pkt) + } + return true + } + + return w.original(id, pkt) +} + +// ttlAction describes what to do with a packet based on its TTL/HopLimit. +type ttlAction int + +const ( + ttlActionPass ttlAction = iota // Not traceroute (TTL=0 or >= maxHopLimit) + ttlActionForward // Forward via DirectRoute with original TTL (no local address) + ttlActionTLE // Generate ICMP Time Exceeded (TTL=1, has local address) + ttlActionDecrement // Forward via DirectRoute with TTL-1 (TTL>1, has local address) +) + +// checkTracerouteTTL inspects a packet's TTL/HopLimit and determines +// the appropriate action for traceroute support. +func checkTracerouteTTL(pkt *stack.PacketBuffer, maxHopLimit uint8, tleAddr4, tleAddr6 netip.Addr) ttlAction { + var hopLimit uint8 + var isIPv4 bool + switch ipHdr := pkt.Network().(type) { + case header.IPv4: + hopLimit = ipHdr.TTL() + isIPv4 = true + case header.IPv6: + hopLimit = ipHdr.HopLimit() + } + if hopLimit == 0 || hopLimit >= maxHopLimit { + return ttlActionPass + } + hasTLEAddr := isIPv4 && tleAddr4.IsValid() || !isIPv4 && tleAddr6.IsValid() + if !hasTLEAddr { + return ttlActionForward + } + if hopLimit == 1 { + return ttlActionTLE + } + return ttlActionDecrement +} + +// directRouteWritePacketWithDecrementedTTL copies the packet, decrements +// TTL/HopLimit by 1, updates the IPv4 header checksum, and forwards via +// DirectRoute. +func directRouteWritePacketWithDecrementedTTL(action DirectRouteDestination, packetBuffer *stack.PacketBuffer) error { + networkHdr := make([]byte, len(packetBuffer.NetworkHeader().Slice())) + copy(networkHdr, packetBuffer.NetworkHeader().Slice()) + + switch packetBuffer.Network().(type) { + case header.IPv4: + ipHdr := header.IPv4(networkHdr) + ipHdr.SetTTL(ipHdr.TTL() - 1) + ipHdr.SetChecksum(0) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + case header.IPv6: + ipHdr := header.IPv6(networkHdr) + ipHdr.SetHopLimit(ipHdr.HopLimit() - 1) + } + + packetSlice := networkHdr + packetSlice = append(packetSlice, packetBuffer.TransportHeader().Slice()...) + packetSlice = append(packetSlice, packetBuffer.Data().AsRange().ToSlice()...) + return action.WritePacket(buf.As(packetSlice).ToOwned()) +} + +// gWriteTimeExceeded constructs and injects an ICMP Time Exceeded message +// back into the gVisor stack, making this node appear as a hop in traceroute. +func gWriteTimeExceeded(ipStack *stack.Stack, pkt *stack.PacketBuffer, localAddr4, localAddr6 netip.Addr) error { + origNetwork := pkt.NetworkHeader().Slice() + origTransport := pkt.TransportHeader().Slice() + origData := pkt.Data().AsRange().ToSlice() + + switch pkt.Network().(type) { + case header.IPv4: + return gWriteTimeExceeded4(ipStack, origNetwork, origTransport, origData, AddressFromAddr(localAddr4)) + case header.IPv6: + return gWriteTimeExceeded6(ipStack, origNetwork, origTransport, origData, AddressFromAddr(localAddr6)) + } + return nil +} + +func gWriteTimeExceeded4(ipStack *stack.Stack, origNetwork, origTransport, origData []byte, localAddr tcpip.Address) error { + clientAddr := header.IPv4(origNetwork).SourceAddress() + + // RFC 1812: include as much of original packet as possible, up to 576 bytes total + maxPayload := 576 - header.IPv4MinimumSize - header.ICMPv4MinimumSize + payload := buildICMPErrorPayload(origNetwork, origTransport, origData, maxPayload) + + route, gErr := ipStack.FindRoute(DefaultNIC, localAddr, clientAddr, header.IPv4ProtocolNumber, false) + if gErr != nil { + return gonet.TranslateNetstackError(gErr) + } + defer route.Release() + + // Build ICMP packet using gVisor's PacketBuffer API (same as gVisor's internal ICMP sending) + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize, + Payload: buffer.MakeWithData(payload), + }) + defer icmpPkt.DecRef() + + icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber + icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + icmpHdr.SetType(header.ICMPv4TimeExceeded) + icmpHdr.SetCode(header.ICMPv4TTLExceeded) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().Checksum())) + + return gonet.TranslateNetstackError(route.WritePacket( + stack.NetworkHeaderParams{ + Protocol: header.ICMPv4ProtocolNumber, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, + icmpPkt, + )) +} + +func gWriteTimeExceeded6(ipStack *stack.Stack, origNetwork, origTransport, origData []byte, localAddr tcpip.Address) error { + clientAddr := header.IPv6(origNetwork).SourceAddress() + + // RFC 4443: include as much of invoking packet as possible, up to minimum IPv6 MTU + maxPayload := 1280 - header.IPv6MinimumSize - header.ICMPv6MinimumSize + payload := buildICMPErrorPayload(origNetwork, origTransport, origData, maxPayload) + + route, gErr := ipStack.FindRoute(DefaultNIC, localAddr, clientAddr, header.IPv6ProtocolNumber, false) + if gErr != nil { + return gonet.TranslateNetstackError(gErr) + } + defer route.Release() + + // Build ICMPv6 packet using gVisor's PacketBuffer API + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv6MinimumSize, + Payload: buffer.MakeWithData(payload), + }) + defer icmpPkt.DecRef() + + icmpPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber + icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6MinimumSize)) + icmpHdr.SetType(header.ICMPv6TimeExceeded) + icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) + + pktData := icmpPkt.Data() + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: route.LocalAddress(), + Dst: route.RemoteAddress(), + PayloadCsum: pktData.Checksum(), + PayloadLen: pktData.Size(), + })) + + return gonet.TranslateNetstackError(route.WritePacket( + stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, + icmpPkt, + )) +} + +func buildICMPErrorPayload(origNetwork, origTransport, origData []byte, maxLen int) []byte { + var payload []byte + payload = append(payload, origNetwork...) + payload = append(payload, origTransport...) + payload = append(payload, origData...) + if len(payload) > maxLen { + payload = payload[:maxLen] + } + return payload +} diff --git a/stack_gvisor_icmp.go b/stack_gvisor_icmp.go index da5549b6..40cb8361 100644 --- a/stack_gvisor_icmp.go +++ b/stack_gvisor_icmp.go @@ -17,18 +17,40 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4" "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6" "github.com/sagernet/gvisor/pkg/tcpip/stack" - "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/contrab/freelru" + "github.com/sagernet/sing/contrab/maphash" ) type ICMPForwarder struct { - ctx context.Context - stack *stack.Stack - inet4Address netip.Addr - inet6Address netip.Addr - handler Handler - mapping *DirectRouteMapping + ctx context.Context + stack *stack.Stack + inet4Address netip.Addr + inet6Address netip.Addr + tleAddr4 netip.Addr + tleAddr6 netip.Addr + maxTracerouteHopLimit uint8 + handler Handler + mapping *DirectRouteMapping + // reverseMapping maps {protocol, port/ident, destination} → original client + // address for delivering ICMP error responses (TimeExceeded/DstUnreachable) + // back to the correct TUN client. Covers ICMP echo, UDP, and TCP inner packets. + reverseMapping freelru.Cache[reverseKey, icmpReverseEntry] +} + +// reverseKey identifies an outgoing session by protocol, port (or ICMP ident), +// and destination, for reverse lookup when delivering ICMP errors. +type reverseKey struct { + Protocol uint8 + Port uint16 // ICMP ident, or TCP/UDP source port + Destination netip.Addr +} + +type icmpReverseEntry struct { + ClientAddr netip.Addr + BackWriter *ICMPBackWriter } func NewICMPForwarder( @@ -37,11 +59,16 @@ func NewICMPForwarder( handler Handler, timeout time.Duration, ) *ICMPForwarder { + reverseMapping := common.Must1(freelru.NewSynced[reverseKey, icmpReverseEntry]( + 4096, maphash.NewHasher[reverseKey]().Hash32, + )) + reverseMapping.SetLifetime(30 * time.Second) return &ICMPForwarder{ - ctx: ctx, - stack: stack, - handler: handler, - mapping: NewDirectRouteMapping(timeout), + ctx: ctx, + stack: stack, + handler: handler, + mapping: NewDirectRouteMapping(timeout), + reverseMapping: reverseMapping, } } @@ -50,29 +77,73 @@ func (f *ICMPForwarder) SetLocalAddresses(inet4Address, inet6Address netip.Addr) f.inet6Address = inet6Address } +func (f *ICMPForwarder) SetTTLDecrement(addr4, addr6 netip.Addr, maxHopLimit uint8) { + f.tleAddr4 = addr4 + f.tleAddr6 = addr6 + if maxHopLimit == 0 { + maxHopLimit = defaultMaxTracerouteHopLimit + } + f.maxTracerouteHopLimit = maxHopLimit +} + +func (f *ICMPForwarder) registerSession(protocol uint8, destination netip.Addr, srcPort uint16, clientAddr netip.Addr, backWriter *ICMPBackWriter) { + f.reverseMapping.Add(reverseKey{ + Protocol: protocol, + Port: srcPort, + Destination: destination, + }, icmpReverseEntry{ + ClientAddr: clientAddr, + BackWriter: backWriter, + }) +} + func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { ipHdr := header.IPv4(pkt.NetworkHeader().Slice()) icmpHdr := header.ICMPv4(pkt.TransportHeader().Slice()) - if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != 0 { + switch icmpHdr.Type() { + case header.ICMPv4TimeExceeded, header.ICMPv4DstUnreachable: + return f.handleICMPError4(pkt) + case header.ICMPv4Echo: + default: + return false + } + if icmpHdr.Code() != 0 { return false } sourceAddr := M.AddrFromIP(ipHdr.SourceAddressSlice()) destinationAddr := M.AddrFromIP(ipHdr.DestinationAddressSlice()) if destinationAddr != f.inet4Address { + ttlAct := checkTracerouteTTL(pkt, f.maxTracerouteHopLimit, f.tleAddr4, f.tleAddr6) + if ttlAct == ttlActionTLE { + _ = gWriteTimeExceeded(f.stack, pkt, f.tleAddr4, f.tleAddr6) + return true + } + backWriter := &ICMPBackWriter{ + stack: f.stack, + packet: pkt, + source: ipHdr.SourceAddress(), + sourceNetwork: header.IPv4ProtocolNumber, + } action, err := f.mapping.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) { - return f.handler.PrepareConnection( + dest, prepErr := f.handler.PrepareConnection( N.NetworkICMP, M.SocksaddrFrom(sourceAddr, 0), M.SocksaddrFrom(destinationAddr, 0), - &ICMPBackWriter{ - stack: f.stack, - packet: pkt, - source: ipHdr.SourceAddress(), - sourceNetwork: header.IPv4ProtocolNumber, - }, + backWriter, timeout, ) + if prepErr == nil && dest != nil { + f.reverseMapping.Add(reverseKey{ + Protocol: uint8(header.ICMPv4ProtocolNumber), + Port: icmpHdr.Ident(), + Destination: destinationAddr, + }, icmpReverseEntry{ + ClientAddr: sourceAddr, + BackWriter: backWriter, + }) + } + return dest, prepErr }) if errors.Is(err, ErrReset) { gWriteUnreachable(f.stack, pkt) @@ -81,8 +152,11 @@ func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pa return true } if action != nil { - // TODO: handle error - _ = icmpWritePacketBuffer(action, pkt) + if ttlAct == ttlActionDecrement { + _ = directRouteWritePacketWithDecrementedTTL(action, pkt) + } else { + _ = directRouteWritePacket(action, pkt) + } return true } } @@ -102,7 +176,7 @@ func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pa DefaultNIC, id.LocalAddress, id.RemoteAddress, - header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, false, ) if gErr != nil { @@ -115,25 +189,49 @@ func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pa } else { ipHdr := header.IPv6(pkt.NetworkHeader().Slice()) icmpHdr := header.ICMPv6(pkt.TransportHeader().Slice()) - if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 { + switch icmpHdr.Type() { + case header.ICMPv6TimeExceeded, header.ICMPv6DstUnreachable: + return f.handleICMPError6(pkt) + case header.ICMPv6EchoRequest: + default: + return false + } + if icmpHdr.Code() != 0 { return false } sourceAddr := M.AddrFromIP(ipHdr.SourceAddressSlice()) destinationAddr := M.AddrFromIP(ipHdr.DestinationAddressSlice()) if destinationAddr != f.inet6Address { + ttlAct := checkTracerouteTTL(pkt, f.maxTracerouteHopLimit, f.tleAddr4, f.tleAddr6) + if ttlAct == ttlActionTLE { + _ = gWriteTimeExceeded(f.stack, pkt, f.tleAddr4, f.tleAddr6) + return true + } + backWriter := &ICMPBackWriter{ + stack: f.stack, + packet: pkt, + source: ipHdr.SourceAddress(), + sourceNetwork: header.IPv6ProtocolNumber, + } action, err := f.mapping.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) { - return f.handler.PrepareConnection( + dest, prepErr := f.handler.PrepareConnection( N.NetworkICMP, M.SocksaddrFrom(sourceAddr, 0), M.SocksaddrFrom(destinationAddr, 0), - &ICMPBackWriter{ - stack: f.stack, - packet: pkt, - source: ipHdr.SourceAddress(), - sourceNetwork: header.IPv6ProtocolNumber, - }, + backWriter, timeout, ) + if prepErr == nil && dest != nil { + f.reverseMapping.Add(reverseKey{ + Protocol: uint8(header.ICMPv6ProtocolNumber), + Port: icmpHdr.Ident(), + Destination: destinationAddr, + }, icmpReverseEntry{ + ClientAddr: sourceAddr, + BackWriter: backWriter, + }) + } + return dest, prepErr }) if errors.Is(err, ErrReset) { gWriteUnreachable(f.stack, pkt) @@ -142,9 +240,12 @@ func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pa return true } if action != nil { - // TODO: handle error - pkt.IncRef() - _ = icmpWritePacketBuffer(action, pkt) + if ttlAct == ttlActionDecrement { + _ = directRouteWritePacketWithDecrementedTTL(action, pkt) + } else { + pkt.IncRef() + _ = directRouteWritePacket(action, pkt) + } return true } } @@ -159,7 +260,7 @@ func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pa PayloadCsum: pkt.Data().Checksum(), PayloadLen: pkt.Data().Size(), })) - outgoingEP, gErr := f.stack.GetNetworkEndpoint(DefaultNIC, header.IPv4ProtocolNumber) + outgoingEP, gErr := f.stack.GetNetworkEndpoint(DefaultNIC, header.IPv6ProtocolNumber) if gErr != nil { // TODO: log error return true @@ -190,55 +291,165 @@ type ICMPBackWriter struct { } func (w *ICMPBackWriter) WritePacket(p []byte) error { + var srcAddr tcpip.Address + if w.sourceNetwork == header.IPv4ProtocolNumber { + srcAddr = header.IPv4(p).SourceAddress() + } else { + srcAddr = header.IPv6(p).SourceAddress() + } + route, err := w.stack.FindRoute( + DefaultNIC, + srcAddr, + w.source, + w.sourceNetwork, + false, + ) + if err != nil { + return gonet.TranslateNetstackError(err) + } + defer route.Release() + packet := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(p), + }) + defer packet.DecRef() if w.sourceNetwork == header.IPv4ProtocolNumber { - route, err := w.stack.FindRoute( - DefaultNIC, - header.IPv4(p).SourceAddress(), - w.source, - w.sourceNetwork, - false, - ) - if err != nil { - return gonet.TranslateNetstackError(err) - } - defer route.Release() - packet := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(p), - }) - defer packet.DecRef() parse.IPv4(packet) - err = route.WritePacketDirect(packet) - if err != nil { - return gonet.TranslateNetstackError(err) - } } else { - route, err := w.stack.FindRoute( - DefaultNIC, - header.IPv6(p).SourceAddress(), - w.source, - w.sourceNetwork, - false, - ) - if err != nil { - return gonet.TranslateNetstackError(err) - } - defer route.Release() - packet := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(p), - }) parse.IPv6(packet) - defer packet.DecRef() - err = route.WritePacketDirect(packet) - if err != nil { - return gonet.TranslateNetstackError(err) - } } - return nil + return gonet.TranslateNetstackError(route.WritePacketDirect(packet)) +} + +// resolveInnerProtocol looks up the reverse mapping for the inner packet's +// protocol, port/ident, and destination. icmpProto is the ICMP protocol +// number for the address family (ICMPv4 or ICMPv6). +func (f *ICMPForwarder) resolveInnerProtocol( + icmpProto uint8, + innerProto uint8, + innerPayloadLen uint16, + innerPayload []byte, + innerDst netip.Addr, +) (icmpReverseEntry, bool) { + var minSize int + switch innerProto { + case icmpProto: + minSize = header.ICMPv4MinimumSize + case uint8(header.UDPProtocolNumber): + minSize = header.UDPMinimumSize + case uint8(header.TCPProtocolNumber): + minSize = header.TCPMinimumSize + default: + return icmpReverseEntry{}, false + } + if innerPayloadLen < uint16(minSize) { + return icmpReverseEntry{}, false + } + var port uint16 + if innerProto == icmpProto { + port = header.ICMPv4(innerPayload).Ident() // offset 4 + } else { + port = header.UDP(innerPayload).SourcePort() // offset 0 (same for TCP) + } + return f.reverseMapping.Get(reverseKey{ + Protocol: innerProto, + Port: port, + Destination: innerDst, + }) } -func icmpWritePacketBuffer(action DirectRouteDestination, packetBuffer *stack.PacketBuffer) error { - packetSlice := packetBuffer.NetworkHeader().Slice() - packetSlice = append(packetSlice, packetBuffer.TransportHeader().Slice()...) - packetSlice = append(packetSlice, packetBuffer.Data().AsRange().ToSlice()...) - return action.WritePacket(buf.As(packetSlice).ToOwned()) +func (f *ICMPForwarder) handleICMPError4(pkt *stack.PacketBuffer) bool { + transportHdr := pkt.TransportHeader().Slice() + dataSlice := pkt.Data().AsRange().ToSlice() + payload := make([]byte, len(transportHdr)+len(dataSlice)) + copy(payload, transportHdr) + copy(payload[len(transportHdr):], dataSlice) + if len(payload) < header.ICMPv4MinimumSize+header.IPv4MinimumSize { + return false + } + innerIPHdr := header.IPv4(payload[header.ICMPv4MinimumSize:]) + if !innerIPHdr.IsValid(len(payload) - header.ICMPv4MinimumSize) { + return false + } + innerDst := M.AddrFromIP(innerIPHdr.DestinationAddressSlice()) + entry, found := f.resolveInnerProtocol( + uint8(header.ICMPv4ProtocolNumber), + innerIPHdr.Protocol(), + innerIPHdr.PayloadLength(), + innerIPHdr.Payload(), + innerDst, + ) + if !found { + return false + } + networkHdr := pkt.NetworkHeader().Slice() + errPacket := make([]byte, len(networkHdr)+len(payload)) + copy(errPacket, networkHdr) + copy(errPacket[len(networkHdr):], payload) + outerIPHdr := header.IPv4(errPacket) + outerIPHdr.SetDestinationAddress(tcpip.AddrFrom4(entry.ClientAddr.As4())) + innerOffset := len(networkHdr) + header.ICMPv4MinimumSize + innerIP := header.IPv4(errPacket[innerOffset:]) + innerIP.SetSourceAddress(tcpip.AddrFrom4(entry.ClientAddr.As4())) + innerIP.SetChecksum(0) + innerIP.SetChecksum(^innerIP.CalculateChecksum()) + outerIPHdr.SetChecksum(0) + outerIPHdr.SetChecksum(^outerIPHdr.CalculateChecksum()) + outerICMP := header.ICMPv4(errPacket[outerIPHdr.HeaderLength():]) + outerICMP.SetChecksum(0) + outerICMP.SetChecksum(header.ICMPv4Checksum(outerICMP, 0)) + return entry.BackWriter.WritePacket(errPacket) == nil +} + +func (f *ICMPForwarder) handleICMPError6(pkt *stack.PacketBuffer) bool { + transportHdr := pkt.TransportHeader().Slice() + dataSlice := pkt.Data().AsRange().ToSlice() + payload := make([]byte, len(transportHdr)+len(dataSlice)) + copy(payload, transportHdr) + copy(payload[len(transportHdr):], dataSlice) + if len(payload) < header.ICMPv6MinimumSize+header.IPv6MinimumSize { + return false + } + innerIPHdr := header.IPv6(payload[header.ICMPv6MinimumSize:]) + if !innerIPHdr.IsValid(len(payload) - header.ICMPv6MinimumSize) { + return false + } + innerDst := M.AddrFromIP(innerIPHdr.DestinationAddressSlice()) + entry, found := f.resolveInnerProtocol( + uint8(header.ICMPv6ProtocolNumber), + uint8(innerIPHdr.TransportProtocol()), + innerIPHdr.PayloadLength(), + innerIPHdr.Payload(), + innerDst, + ) + if !found { + return false + } + networkHdr := pkt.NetworkHeader().Slice() + errPacket := make([]byte, len(networkHdr)+len(payload)) + copy(errPacket, networkHdr) + copy(errPacket[len(networkHdr):], payload) + outerIPHdr := header.IPv6(errPacket) + clientAddr16 := entry.ClientAddr.As16() + outerIPHdr.SetDestinationAddress(tcpip.AddrFrom16(clientAddr16)) + innerOffset := len(networkHdr) + header.ICMPv6MinimumSize + innerIP := header.IPv6(errPacket[innerOffset:]) + innerIP.SetSourceAddress(tcpip.AddrFrom16(clientAddr16)) + // Recalculate inner transport checksum if it's ICMPv6 + if innerIP.TransportProtocol() == header.ICMPv6ProtocolNumber && innerIP.PayloadLength() >= header.ICMPv6MinimumSize { + innerICMP := header.ICMPv6(innerIP.Payload()) + innerICMP.SetChecksum(0) + innerICMP.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: innerICMP, + Src: innerIP.SourceAddress(), + Dst: innerIP.DestinationAddress(), + })) + } + outerICMP := header.ICMPv6(errPacket[header.IPv6MinimumSize:]) + outerICMP.SetChecksum(0) + outerICMP.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: outerICMP, + Src: outerIPHdr.SourceAddress(), + Dst: outerIPHdr.DestinationAddress(), + })) + return entry.BackWriter.WritePacket(errPacket) == nil } diff --git a/stack_gvisor_lazy.go b/stack_gvisor_lazy.go index f5e2e6e6..dcbcafb9 100644 --- a/stack_gvisor_lazy.go +++ b/stack_gvisor_lazy.go @@ -19,7 +19,7 @@ import ( ) type gLazyConn struct { - tcpConn *gonet.TCPConn + tcpConn *gTCPConn parentCtx context.Context stack *stack.Stack request *tcp.ForwarderRequest @@ -31,9 +31,6 @@ type gLazyConn struct { } func (c *gLazyConn) HandshakeContext(ctx context.Context) error { - if c.handshakeDone { - return c.handshakeErr - } c.handshakeAccess.Lock() defer c.handshakeAccess.Unlock() if c.handshakeDone { @@ -66,15 +63,12 @@ func (c *gLazyConn) HandshakeContext(ctx context.Context) error { endpoint.SocketOptions().SetKeepAlive(true) endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIdleOption(15 * time.Second))) endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIntervalOption(15 * time.Second))) - tcpConn := gonet.NewTCPConn(&wq, endpoint) + tcpConn := newGTCPConn(&wq, endpoint, c.localAddr, c.remoteAddr) c.tcpConn = tcpConn return nil } func (c *gLazyConn) HandshakeFailure(err error) error { - if c.handshakeDone { - return os.ErrInvalid - } c.handshakeAccess.Lock() defer c.handshakeAccess.Unlock() if c.handshakeDone { @@ -90,6 +84,18 @@ func (c *gLazyConn) HandshakeSuccess() error { return c.HandshakeContext(context.Background()) } +func (c *gLazyConn) NeedHandshakeForRead() bool { + c.handshakeAccess.Lock() + defer c.handshakeAccess.Unlock() + return !c.handshakeDone +} + +func (c *gLazyConn) NeedHandshakeForWrite() bool { + c.handshakeAccess.Lock() + defer c.handshakeAccess.Unlock() + return !c.handshakeDone +} + func (c *gLazyConn) Read(b []byte) (n int, err error) { err = c.HandshakeContext(context.Background()) if err != nil { @@ -139,57 +145,38 @@ func (c *gLazyConn) SetWriteDeadline(t time.Time) error { } func (c *gLazyConn) Close() error { - if !c.handshakeDone { - c.handshakeAccess.Lock() - if !c.handshakeDone { - c.request.Complete(true) - c.handshakeErr = net.ErrClosed - c.handshakeDone = true - return nil - } else if c.handshakeErr != nil { - return nil - } - c.handshakeAccess.Unlock() - } else if c.handshakeErr != nil { + if c.closeBeforeHandshake() { return nil } return c.tcpConn.Close() } func (c *gLazyConn) CloseRead() error { - if !c.handshakeDone { - c.handshakeAccess.Lock() - if !c.handshakeDone { - c.request.Complete(true) - c.handshakeErr = net.ErrClosed - c.handshakeDone = true - return nil - } else if c.handshakeErr != nil { - return nil - } - c.handshakeAccess.Unlock() - } else if c.handshakeErr != nil { + if c.closeBeforeHandshake() { return nil } return c.tcpConn.CloseRead() } func (c *gLazyConn) CloseWrite() error { + if c.closeBeforeHandshake() { + return nil + } + return c.tcpConn.CloseWrite() +} + +func (c *gLazyConn) closeBeforeHandshake() bool { + c.handshakeAccess.Lock() + defer c.handshakeAccess.Unlock() if !c.handshakeDone { - c.handshakeAccess.Lock() - if !c.handshakeDone { + if c.request != nil { c.request.Complete(true) - c.handshakeErr = net.ErrClosed - c.handshakeDone = true - return nil - } else if c.handshakeErr != nil { - return nil } - c.handshakeAccess.Unlock() - } else if c.handshakeErr != nil { - return nil + c.handshakeErr = net.ErrClosed + c.handshakeDone = true + return true } - return c.tcpConn.CloseRead() + return c.handshakeErr != nil } func (c *gLazyConn) ReaderReplaceable() bool { diff --git a/stack_gvisor_tcp.go b/stack_gvisor_tcp.go index 0c63ee11..e8cebc9f 100644 --- a/stack_gvisor_tcp.go +++ b/stack_gvisor_tcp.go @@ -11,7 +11,7 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/header" "github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/checksum" "github.com/sagernet/sing/common" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -27,20 +27,20 @@ type TCPForwarder struct { forwarder *tcp.Forwarder } -func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder { - return NewTCPForwarderWithLoopback(ctx, stack, handler, nil, nil, nil) +func NewTCPForwarder(ctx context.Context, ipStack *stack.Stack, handler Handler) *TCPForwarder { + return NewTCPForwarderWithLoopback(ctx, ipStack, handler, nil, nil, nil) } -func NewTCPForwarderWithLoopback(ctx context.Context, stack *stack.Stack, handler Handler, inet4LoopbackAddress []netip.Addr, inet6LoopbackAddress []netip.Addr, tun GVisorTun) *TCPForwarder { +func NewTCPForwarderWithLoopback(ctx context.Context, ipStack *stack.Stack, handler Handler, inet4LoopbackAddress []netip.Addr, inet6LoopbackAddress []netip.Addr, tun GVisorTun) *TCPForwarder { forwarder := &TCPForwarder{ ctx: ctx, - stack: stack, + stack: ipStack, handler: handler, inet4LoopbackAddress: common.Map(inet4LoopbackAddress, AddressFromAddr), inet6LoopbackAddress: common.Map(inet6LoopbackAddress, AddressFromAddr), tun: tun, } - forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward) + forwarder.forwarder = tcp.NewForwarder(ipStack, 0, 1024, forwarder.Forward) return forwarder } diff --git a/stack_gvisor_tcp_conn.go b/stack_gvisor_tcp_conn.go new file mode 100644 index 00000000..ad48d42f --- /dev/null +++ b/stack_gvisor_tcp_conn.go @@ -0,0 +1,325 @@ +//go:build with_gvisor + +package tun + +import ( + "bytes" + "errors" + "io" + "net" + "os" + "time" + + "github.com/sagernet/gvisor/pkg/sync" + "github.com/sagernet/gvisor/pkg/tcpip" + "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" + "github.com/sagernet/gvisor/pkg/waiter" + "github.com/sagernet/sing/common/buf" + N "github.com/sagernet/sing/common/network" +) + +var ( + _ net.Conn = (*gTCPConn)(nil) + _ N.ReadWaiter = (*gTCPConn)(nil) +) + +type gTCPConn struct { + gTCPDeadline + + wq *waiter.Queue + ep tcpip.Endpoint + + localAddr net.Addr + remoteAddr net.Addr + + readMu sync.Mutex + readWaitOption N.ReadWaitOptions +} + +func newGTCPConn(wq *waiter.Queue, ep tcpip.Endpoint, localAddr net.Addr, remoteAddr net.Addr) *gTCPConn { + conn := &gTCPConn{ + wq: wq, + ep: ep, + localAddr: localAddr, + remoteAddr: remoteAddr, + } + conn.gTCPDeadline.init() + return conn +} + +func (c *gTCPConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOption = options + return false +} + +func (c *gTCPConn) WaitReadBuffer() (*buf.Buffer, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + deadline := c.readCancel() + for { + if err := c.waitReadable(deadline); err != nil { + return nil, err + } + buffer := c.readWaitOption.NewBuffer() + writer := tcpip.SliceWriter(buffer.FreeBytes()) + result, err := c.ep.Read(&writer, tcpip.ReadOptions{}) + if _, wouldBlock := err.(*tcpip.ErrWouldBlock); wouldBlock { + buffer.Release() + continue + } + if err != nil { + buffer.Release() + return nil, c.translateReadError(err) + } + if result.Count == 0 { + buffer.Release() + continue + } + buffer.Truncate(result.Count) + c.readWaitOption.PostReturn(buffer) + c.ep.ModerateRecvBuf(result.Count) + return buffer, nil + } +} + +func (c *gTCPConn) Read(b []byte) (int, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + writer := tcpip.SliceWriter(b) + n, err := c.readTo(&writer, c.readCancel()) + if n != 0 { + c.ep.ModerateRecvBuf(n) + } + return n, err +} + +func (c *gTCPConn) readTo(writer io.Writer, deadline <-chan struct{}) (int, error) { + select { + case <-deadline: + return 0, c.newOpError("read", os.ErrDeadlineExceeded) + default: + } + + result, err := c.ep.Read(writer, tcpip.ReadOptions{}) + if _, wouldBlock := err.(*tcpip.ErrWouldBlock); wouldBlock { + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents) + c.wq.EventRegister(&waitEntry) + defer c.wq.EventUnregister(&waitEntry) + for { + result, err = c.ep.Read(writer, tcpip.ReadOptions{}) + if _, wouldBlock = err.(*tcpip.ErrWouldBlock); !wouldBlock { + break + } + select { + case <-deadline: + return 0, c.newOpError("read", os.ErrDeadlineExceeded) + case <-notifyCh: + } + } + } + + if err != nil { + return 0, c.translateReadError(err) + } + return result.Count, nil +} + +func (c *gTCPConn) waitReadable(deadline <-chan struct{}) error { + select { + case <-deadline: + return c.newOpError("read", os.ErrDeadlineExceeded) + default: + } + if c.ep.Readiness(waiter.ReadableEvents)&waiter.ReadableEvents != 0 { + return nil + } + + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents) + c.wq.EventRegister(&waitEntry) + defer c.wq.EventUnregister(&waitEntry) + for c.ep.Readiness(waiter.ReadableEvents)&waiter.ReadableEvents == 0 { + select { + case <-deadline: + return c.newOpError("read", os.ErrDeadlineExceeded) + case <-notifyCh: + } + } + return nil +} + +func (c *gTCPConn) translateReadError(err tcpip.Error) error { + if _, closed := err.(*tcpip.ErrClosedForReceive); closed { + return io.EOF + } + return c.newOpError("read", gonet.TranslateNetstackError(err)) +} + +func (c *gTCPConn) Write(b []byte) (int, error) { + deadline := c.writeCancel() + + select { + case <-deadline: + return 0, c.newOpError("write", os.ErrDeadlineExceeded) + default: + } + + var ( + reader bytes.Reader + nBytes int + entry waiter.Entry + ch <-chan struct{} + ) + for nBytes != len(b) { + reader.Reset(b[nBytes:]) + n, err := c.ep.Write(&reader, tcpip.WriteOptions{}) + nBytes += int(n) + switch err.(type) { + case nil: + case *tcpip.ErrWouldBlock: + if ch == nil { + entry, ch = waiter.NewChannelEntry(waiter.WritableEvents) + c.wq.EventRegister(&entry) + defer c.wq.EventUnregister(&entry) + } else { + select { + case <-deadline: + return nBytes, c.newOpError("write", os.ErrDeadlineExceeded) + case <-ch: + continue + } + } + default: + return nBytes, c.newOpError("write", gonet.TranslateNetstackError(err)) + } + } + return nBytes, nil +} + +func (c *gTCPConn) Close() error { + c.ep.Close() + return nil +} + +func (c *gTCPConn) CloseRead() error { + if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { + return c.newOpError("close", errors.New(err.String())) + } + return nil +} + +func (c *gTCPConn) CloseWrite() error { + if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { + return c.newOpError("close", errors.New(err.String())) + } + return nil +} + +func (c *gTCPConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *gTCPConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *gTCPConn) SetDeadline(t time.Time) error { + return c.gTCPDeadline.SetDeadline(t) +} + +func (c *gTCPConn) SetReadDeadline(t time.Time) error { + return c.gTCPDeadline.SetReadDeadline(t) +} + +func (c *gTCPConn) SetWriteDeadline(t time.Time) error { + return c.gTCPDeadline.SetWriteDeadline(t) +} + +func (c *gTCPConn) newOpError(op string, err error) *net.OpError { + return &net.OpError{ + Op: op, + Net: "tcp", + Source: c.localAddr, + Addr: c.remoteAddr, + Err: err, + } +} + +type gTCPDeadline struct { + mu sync.Mutex + + readTimer *time.Timer + readCancelCh chan struct{} + writeTimer *time.Timer + writeCancelCh chan struct{} +} + +func (d *gTCPDeadline) init() { + d.readCancelCh = make(chan struct{}) + d.writeCancelCh = make(chan struct{}) +} + +func (d *gTCPDeadline) readCancel() <-chan struct{} { + d.mu.Lock() + cancelCh := d.readCancelCh + d.mu.Unlock() + return cancelCh +} + +func (d *gTCPDeadline) writeCancel() <-chan struct{} { + d.mu.Lock() + cancelCh := d.writeCancelCh + d.mu.Unlock() + return cancelCh +} + +func (d *gTCPDeadline) SetDeadline(t time.Time) error { + d.mu.Lock() + d.setDeadline(&d.readCancelCh, &d.readTimer, t) + d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) + d.mu.Unlock() + return nil +} + +func (d *gTCPDeadline) SetReadDeadline(t time.Time) error { + d.mu.Lock() + d.setDeadline(&d.readCancelCh, &d.readTimer, t) + d.mu.Unlock() + return nil +} + +func (d *gTCPDeadline) SetWriteDeadline(t time.Time) error { + d.mu.Lock() + d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) + d.mu.Unlock() + return nil +} + +func (d *gTCPDeadline) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) { + if *timer != nil && !(*timer).Stop() { + *cancelCh = make(chan struct{}) + } + + select { + case <-*cancelCh: + *cancelCh = make(chan struct{}) + default: + } + + if t.IsZero() { + *timer = nil + return + } + + timeout := time.Until(t) + if timeout <= 0 { + close(*cancelCh) + return + } + + ch := *cancelCh + *timer = time.AfterFunc(timeout, func() { + close(ch) + }) +} diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index f91a2b3e..785765b9 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -22,9 +22,15 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/udpnat2" + udpnat "github.com/sagernet/sing/common/udpnat2" ) +// defaultMaxTracerouteHopLimit is the default TTL/HopLimit threshold below +// which packets are treated as traceroute probes and sent via DirectRoute +// rather than the normal proxy path. Set to 60 to allow up to 4 layers of +// network devices between the originating host (TTL=64) and this TUN stack. +const defaultMaxTracerouteHopLimit = 60 + type UDPForwarder struct { ctx context.Context stack *stack.Stack @@ -32,10 +38,10 @@ type UDPForwarder struct { udpNat *udpnat.Service } -func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, timeout time.Duration) *UDPForwarder { +func NewUDPForwarder(ctx context.Context, ipStack *stack.Stack, handler Handler, timeout time.Duration) *UDPForwarder { forwarder := &UDPForwarder{ ctx: ctx, - stack: stack, + stack: ipStack, handler: handler, } forwarder.udpNat = udpnat.New(handler, forwarder.PreparePacketConnection, timeout, true) @@ -45,6 +51,7 @@ func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, t func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { source := M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort) destination := M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort) + bufferRange := pkt.Data().AsRange() var bufferSlices [][]byte rangeIterate(bufferRange, func(view *buffer.View) { @@ -182,3 +189,10 @@ func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer) error { return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, stack.RejectIPv6WithICMPPortUnreachable, true)) } } + +func directRouteWritePacket(action DirectRouteDestination, packetBuffer *stack.PacketBuffer) error { + packetSlice := packetBuffer.NetworkHeader().Slice() + packetSlice = append(packetSlice, packetBuffer.TransportHeader().Slice()...) + packetSlice = append(packetSlice, packetBuffer.Data().AsRange().ToSlice()...) + return action.WritePacket(buf.As(packetSlice).ToOwned()) +} diff --git a/stack_mixed.go b/stack_mixed.go index 8836d6ba..33284053 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -12,7 +12,7 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/link/channel" "github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" ) diff --git a/stack_system.go b/stack_system.go index ef8b709a..f475767a 100644 --- a/stack_system.go +++ b/stack_system.go @@ -8,8 +8,8 @@ import ( "syscall" "time" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" diff --git a/stack_system_packet.go b/stack_system_packet.go index 34fe51e4..a8f8076e 100644 --- a/stack_system_packet.go +++ b/stack_system_packet.go @@ -4,7 +4,7 @@ import ( "net/netip" "syscall" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common" ) diff --git a/tun.go b/tun.go index 35cd0956..abfe67fa 100644 --- a/tun.go +++ b/tun.go @@ -9,8 +9,10 @@ import ( "strings" "time" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -68,6 +70,12 @@ const ( DefaultIPRoute2AutoRedirectFallbackRuleIndex = 32768 ) +const ( + DNSModeDisabled = "disabled" + DNSModeNative = "native" + DNSModeHijack = "hijack" +) + type Options struct { Name string Inet4Address []netip.Prefix @@ -78,7 +86,8 @@ type Options struct { InterfaceScope bool Inet4Gateway netip.Addr Inet6Gateway netip.Addr - DNSServers []netip.Addr + DNSMode string + DNSAddress []netip.Addr IPRoute2TableIndex int IPRoute2RuleIndex int IPRoute2AutoRedirectFallbackRuleIndex int @@ -102,6 +111,8 @@ type Options struct { IncludeAndroidUser []int IncludePackage []string ExcludePackage []string + IncludeMACAddress []net.HardwareAddr + ExcludeMACAddress []net.HardwareAddr InterfaceFinder control.InterfaceFinder InterfaceMonitor DefaultInterfaceMonitor FileDescriptor int @@ -122,6 +133,57 @@ type Options struct { EXP_SendMsgX bool } +func (o *Options) DNSModeOrDefault() string { + if o.DNSMode == "" { + return DNSModeHijack + } + return o.DNSMode +} + +func (o *Options) DNSServerAddress() ([]netip.Addr, error) { + inet4DNS, err := o.Inet4DNSAddress() + if err != nil { + return nil, err + } + inet6DNS, err := o.Inet6DNSAddress() + if err != nil { + return nil, err + } + return append(inet4DNS, inet6DNS...), nil +} + +func (o *Options) Inet4DNSAddress() ([]netip.Addr, error) { + if len(o.Inet4Address) == 0 { + return nil, nil + } + if len(o.DNSAddress) > 0 { + return common.Filter(o.DNSAddress, netip.Addr.Is4), nil + } + if HasNextAddress(o.Inet4Address[0], 1) { + return []netip.Addr{o.Inet4Address[0].Addr().Next()}, nil + } + if !(len(o.Inet6Address) > 0 && HasNextAddress(o.Inet6Address[0], 1)) { + return nil, E.New("no IPv4 server configured and no usable next address in ", o.Inet6Address[0], " for DNS") + } + return nil, nil +} + +func (o *Options) Inet6DNSAddress() ([]netip.Addr, error) { + if len(o.Inet6Address) == 0 { + return nil, nil + } + if len(o.DNSAddress) > 0 { + return common.Filter(o.DNSAddress, netip.Addr.Is6), nil + } + if HasNextAddress(o.Inet6Address[0], 1) { + return []netip.Addr{o.Inet6Address[0].Addr().Next()}, nil + } + if !(len(o.Inet4Address) > 0 && HasNextAddress(o.Inet4Address[0], 1)) { + return nil, E.New("no IPv6 server configured and no usable next address in ", o.Inet6Address[0], " for DNS") + } + return nil, nil +} + func (o *Options) Inet4GatewayAddr() netip.Addr { if o.Inet4Gateway.IsValid() { return o.Inet4Gateway diff --git a/tun_darwin.go b/tun_darwin.go index 8aa6923f..f4ca0edd 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -9,7 +9,7 @@ import ( "syscall" "unsafe" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing-tun/internal/rawfile_darwin" "github.com/sagernet/sing-tun/internal/stopfd_darwin" "github.com/sagernet/sing/common" diff --git a/tun_linux.go b/tun_linux.go index 20fdce23..dc1a02b7 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -14,8 +14,8 @@ import ( "unsafe" "github.com/sagernet/netlink" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" @@ -317,7 +317,12 @@ func (t *NativeTun) Start() error { return E.Cause(err, "set rules") } - t.setSearchDomainForSystemdResolved() + if t.options.DNSMode != DNSModeDisabled { + err = t.setSearchDomainForSystemdResolved() + if err != nil { + return E.Cause(err, "set search domain") + } + } if t.options.AutoRoute && runtime.GOOS == "android" { t.interfaceCallback = t.options.InterfaceMonitor.RegisterCallback(t.routeUpdate) @@ -332,7 +337,9 @@ func (t *NativeTun) Close() error { if t.options.EXP_ExternalConfiguration { return common.Close(common.PtrOrNil(t.tunFile)) } - t.unsetSearchDomainForSystemdResolved() + if t.options.DNSMode != DNSModeDisabled { + t.unsetSearchDomainForSystemdResolved() + } t.unsetAddresses() return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile))) } @@ -1073,37 +1080,24 @@ func (t *NativeTun) routeUpdate(_ *control.Interface, flags int) { } } -func (t *NativeTun) setSearchDomainForSystemdResolved() { - if t.options.EXP_DisableDNSHijack { - return - } +func (t *NativeTun) setSearchDomainForSystemdResolved() error { ctlPath, err := exec.LookPath("resolvectl") if err != nil { - return - } - dnsServer := t.options.DNSServers - if len(dnsServer) == 0 { - if len(t.options.Inet4Address) > 0 && HasNextAddress(t.options.Inet4Address[0], 1) { - dnsServer = append(dnsServer, t.options.Inet4Address[0].Addr().Next()) - } - if len(t.options.Inet6Address) > 0 && HasNextAddress(t.options.Inet6Address[0], 1) { - dnsServer = append(dnsServer, t.options.Inet6Address[0].Addr().Next()) - } + return nil } - if len(dnsServer) == 0 { - return + dnsAddress, err := t.options.DNSServerAddress() + if err != nil { + return err } go func() { _ = shell.Exec(ctlPath, "domain", t.options.Name, "~.").Run() _ = shell.Exec(ctlPath, "default-route", t.options.Name, "true").Run() - _ = shell.Exec(ctlPath, append([]string{"dns", t.options.Name}, common.Map(dnsServer, netip.Addr.String)...)...).Run() + _ = shell.Exec(ctlPath, append([]string{"dns", t.options.Name}, common.Map(dnsAddress, netip.Addr.String)...)...).Run() }() + return nil } func (t *NativeTun) unsetSearchDomainForSystemdResolved() { - if t.options.EXP_DisableDNSHijack { - return - } ctlPath, err := exec.LookPath("resolvectl") if err != nil { return diff --git a/tun_offload.go b/tun_offload.go index a0eee82f..83c833af 100644 --- a/tun_offload.go +++ b/tun_offload.go @@ -4,9 +4,9 @@ import ( "encoding/binary" "fmt" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/header" ) const ( diff --git a/tun_offload_linux.go b/tun_offload_linux.go index 77337607..a3085304 100644 --- a/tun_offload_linux.go +++ b/tun_offload_linux.go @@ -13,9 +13,9 @@ import ( "io" "unsafe" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/header" "golang.org/x/sys/unix" ) diff --git a/tun_windows.go b/tun_windows.go index 54e73954..c32bffdf 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -16,7 +16,6 @@ import ( "github.com/sagernet/sing-tun/internal/winipcfg" "github.com/sagernet/sing-tun/internal/winsys" "github.com/sagernet/sing-tun/internal/wintun" - "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/windnsapi" @@ -74,16 +73,14 @@ func (t *NativeTun) configure() error { if err != nil { return E.Cause(err, "set ipv4 address") } - if t.options.AutoRoute && !t.options.EXP_DisableDNSHijack { - dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is4) - if len(dnsServers) == 0 && HasNextAddress(t.options.Inet4Address[0], 1) { - dnsServers = []netip.Addr{t.options.Inet4Address[0].Addr().Next()} + if t.options.AutoRoute && t.options.DNSModeOrDefault() != DNSModeDisabled { + dnsServers, err := t.options.Inet4DNSAddress() + if err != nil { + return err } - if len(dnsServers) > 0 { - err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), dnsServers, nil) - if err != nil { - return E.Cause(err, "set ipv4 dns") - } + err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), dnsServers, nil) + if err != nil { + return E.Cause(err, "set ipv4 dns") } } else { err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), nil, nil) @@ -97,16 +94,14 @@ func (t *NativeTun) configure() error { if err != nil { return E.Cause(err, "set ipv6 address") } - if t.options.AutoRoute && !t.options.EXP_DisableDNSHijack { - dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is6) - if len(dnsServers) == 0 && HasNextAddress(t.options.Inet6Address[0], 1) { - dnsServers = []netip.Addr{t.options.Inet6Address[0].Addr().Next()} + if t.options.AutoRoute && t.options.DNSModeOrDefault() != DNSModeDisabled { + dnsServers, err := t.options.Inet6DNSAddress() + if err != nil { + return err } - if len(dnsServers) > 0 { - err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), dnsServers, nil) - if err != nil { - return E.Cause(err, "set ipv6 dns") - } + err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), dnsServers, nil) + if err != nil { + return E.Cause(err, "set ipv6 dns") } } else { err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), nil, nil) @@ -327,7 +322,7 @@ func (t *NativeTun) Start() error { } } - if !t.options.EXP_DisableDNSHijack { + if t.options.DNSModeOrDefault() == DNSModeHijack { blockDNSCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1) blockDNSCondition[0].FieldKey = winsys.FWPM_CONDITION_IP_REMOTE_PORT blockDNSCondition[0].MatchType = winsys.FWP_MATCH_EQUAL