Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions infra/conf/transport_internet.go
Original file line number Diff line number Diff line change
Expand Up @@ -1660,16 +1660,30 @@ func (c *Sudoku) Build() (proto.Message, error) {
}

type Xdns struct {
Domain string `json:"domain"`
Domain json.RawMessage `json:"domain"`

Domains []string `json:"domains"`
Resolvers []string `json:"resolvers"`
}

func (c *Xdns) Build() (proto.Message, error) {
if c.Domain == "" {
return nil, errors.New("empty domain")
if c.Domain != nil {
return nil, errors.PrintRemovedFeatureError("domain", "domains(server) & resolvers(client)")
}

if len(c.Domains) == 0 && len(c.Resolvers) == 0 {
return nil, errors.New("empty domains & empty resolvers")
}

for _, r := range c.Resolvers {
if !strings.Contains(r, "+udp://") {
return nil, errors.New("invalid resolver ", r)
}
}

return &xdns.Config{
Domain: c.Domain,
Domains: c.Domains,
Resolvers: c.Resolvers,
}, nil
}

Expand Down
248 changes: 176 additions & 72 deletions transport/internet/finalmask/xdns/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ import (
go_errors "errors"
"io"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/xtls/xray-core/common"
Expand All @@ -34,10 +37,14 @@ type packet struct {
}

type xdnsConnClient struct {
net.PacketConn
conn net.PacketConn
resolverConns []net.PacketConn
resolverAddrs []*net.UDPAddr
resolverIdx uint32
resolverSend []atomic.Uint32

clientID []byte
domain Name
domains []Name

pollChan chan struct{}
readQueue chan *packet
Expand All @@ -48,16 +55,66 @@ type xdnsConnClient struct {
}

func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
domain, err := ParseName(c.Domain)
if err != nil {
return nil, err
if len(c.Resolvers) == 0 {
return nil, errors.New("empty resolvers")
}

var domains []Name
var servers []string
for _, rs := range c.Resolvers {
parts := strings.Split(rs, "+udp://")
if len(parts) != 2 {
return nil, errors.New("invalid resolvers")
}
domain, err := ParseName(parts[0])
if err != nil {
return nil, err
}
domains = append(domains, domain)
servers = append(servers, parts[1])
}

var resolverConns []net.PacketConn
var resolverAddrs []*net.UDPAddr
var resolverSend []atomic.Uint32
for _, rs := range servers {
h, p, err := net.SplitHostPort(rs)
if err != nil {
return nil, err
}
ip := net.ParseIP(h)
if ip == nil {
return nil, errors.New("invalid ip address")
}
port, _ := strconv.Atoi(p)
if port == 0 {
return nil, errors.New("invalid port")
}
var uc net.PacketConn
if ip.To4() != nil {
uc, err = net.ListenPacket("udp4", ":0")
} else {
uc, err = net.ListenPacket("udp6", ":0")
}
if err != nil {
for _, rc := range resolverConns {
rc.Close()
}
return nil, errors.New("failed to create resolver socket: ", err)
}
resolverConns = append(resolverConns, uc)
resolverAddrs = append(resolverAddrs, &net.UDPAddr{IP: ip, Port: port})
}
resolverSend = make([]atomic.Uint32, len(resolverConns))

conn := &xdnsConnClient{
PacketConn: raw,
conn: raw,
resolverConns: resolverConns,
resolverAddrs: resolverAddrs,
resolverSend: resolverSend,

clientID: make([]byte, 8),
domain: domain,
domains: domains,

pollChan: make(chan struct{}, pollLimit),
readQueue: make(chan *packet, 256),
Expand All @@ -73,58 +130,70 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
}

func (c *xdnsConnClient) recvLoop() {
var buf [finalmask.UDPSize]byte

for {
if c.closed {
break
}

n, addr, err := c.PacketConn.ReadFrom(buf[:])
if err != nil || n == 0 {
if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.EOF) {
break
}
continue
}

resp, err := MessageFromWireFormat(buf[:n])
if err != nil {
errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err)
continue
}

payload := dnsResponsePayload(&resp, c.domain)

r := bytes.NewReader(payload)
anyPacket := false
for {
p, err := nextPacket(r)
if err != nil {
break
var wg sync.WaitGroup

for i, rc := range c.resolverConns {
wg.Add(1)
go func() {
defer wg.Done()

var buf [finalmask.UDPSize]byte

for {
if c.closed {
break
}

n, addr, err := rc.ReadFrom(buf[:])
if err != nil {
if go_errors.Is(err, net.ErrClosed) {
break
}
continue
}

resp, err := MessageFromWireFormat(buf[:n])
if err != nil {
errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err)
continue
}

payload := dnsResponsePayload(&resp, c.domains)

r := bytes.NewReader(payload)
anyPacket := false
for {
p, err := nextPacket(r)
if err != nil {
break
}
anyPacket = true

buf := make([]byte, len(p))
copy(buf, p)
select {
case c.readQueue <- &packet{
p: buf,
addr: addr,
}:
default:
errors.LogDebug(context.Background(), addr, " mask read err queue full")
}
}

if anyPacket {
c.resolverSend[i].Store(0)
select {
case c.pollChan <- struct{}{}:
default:
}
}
}
anyPacket = true

buf := make([]byte, len(p))
copy(buf, p)
select {
case c.readQueue <- &packet{
p: buf,
addr: addr,
}:
default:
errors.LogDebug(context.Background(), addr, " mask read err queue full")
}
}

if anyPacket {
select {
case c.pollChan <- struct{}{}:
default:
}
}
}()
}

wg.Wait()

errors.LogDebug(context.Background(), "xdns closed")

close(c.pollChan)
Expand All @@ -138,8 +207,6 @@ func (c *xdnsConnClient) recvLoop() {
}

func (c *xdnsConnClient) sendLoop() {
var addr net.Addr

pollDelay := initPollDelay
pollTimer := time.NewTimer(pollDelay)
for {
Expand All @@ -158,17 +225,14 @@ func (c *xdnsConnClient) sendLoop() {
}

if p != nil {
addr = p.addr

select {
case <-c.pollChan:
default:
}
} else if addr != nil {
encoded, _ := encode(nil, c.clientID, c.domain)
} else {
encoded, _ := encode(nil, c.clientID, c.domains[c.resolverIdx])
p = &packet{
p: encoded,
addr: addr,
p: encoded,
}
}

Expand All @@ -189,10 +253,16 @@ func (c *xdnsConnClient) sendLoop() {
return
}

if p != nil {
_, err := c.PacketConn.WriteTo(p.p, p.addr)
if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.ErrClosedPipe) {
c.closed = true
cur := c.resolverIdx
curSend := c.resolverSend[c.resolverIdx].Add(1)
_, _ = c.resolverConns[c.resolverIdx].WriteTo(p.p, c.resolverAddrs[c.resolverIdx])
for {
c.resolverIdx += 1
c.resolverIdx %= uint32(len(c.resolverConns))
if c.resolverIdx == cur {
break
}
if c.resolverSend[c.resolverIdx].Load() < curSend {
break
}
}
Expand Down Expand Up @@ -220,7 +290,7 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return 0, io.ErrClosedPipe
}

encoded, err := encode(p, c.clientID, c.domain)
encoded, err := encode(p, c.clientID, c.domains[c.resolverIdx%uint32(len(c.resolverConns))])
if err != nil {
errors.LogDebug(context.Background(), addr, " xdns wireformat err ", err, " ", len(p))
return 0, nil
Expand All @@ -240,7 +310,35 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {

func (c *xdnsConnClient) Close() error {
c.closed = true
return c.PacketConn.Close()
for _, rc := range c.resolverConns {
rc.Close()
}
return c.conn.Close()
}

func (c *xdnsConnClient) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}

func (c *xdnsConnClient) SetDeadline(t time.Time) error {
for _, rc := range c.resolverConns {
rc.SetDeadline(t)
}
return c.conn.SetDeadline(t)
}

func (c *xdnsConnClient) SetReadDeadline(t time.Time) error {
for _, rc := range c.resolverConns {
rc.SetReadDeadline(t)
}
return c.conn.SetReadDeadline(t)
}

func (c *xdnsConnClient) SetWriteDeadline(t time.Time) error {
for _, rc := range c.resolverConns {
rc.SetWriteDeadline(t)
}
return c.conn.SetWriteDeadline(t)
}

func encode(p []byte, clientID []byte, domain Name) ([]byte, error) {
Expand Down Expand Up @@ -332,7 +430,7 @@ func nextPacket(r *bytes.Reader) ([]byte, error) {
return p, err
}

func dnsResponsePayload(resp *Message, domain Name) []byte {
func dnsResponsePayload(resp *Message, domains []Name) []byte {
if resp.Flags&0x8000 != 0x8000 {
return nil
}
Expand All @@ -345,7 +443,13 @@ func dnsResponsePayload(resp *Message, domain Name) []byte {
}
answer := resp.Answer[0]

_, ok := answer.Name.TrimSuffix(domain)
var ok bool
for _, domain := range domains {
_, ok = answer.Name.TrimSuffix(domain)
if ok {
break
}
}
if !ok {
return nil
}
Expand Down
Loading