XDNS finalmask: Use single UDP socket for multiple resolvers for now (#5982)

https://github.com/XTLS/Xray-core/pull/5982#issuecomment-4302271929

Closes https://github.com/XTLS/Xray-core/pull/5976#issuecomment-4320460288
This commit is contained in:
LjhAUMEM
2026-04-26 04:14:03 +08:00
committed by RPRX
parent 85a8bf5f39
commit fa07b34956
4 changed files with 90 additions and 132 deletions

View File

@@ -80,13 +80,17 @@ func New() *Client {
d := &net.Dialer{ d := &net.Dialer{
Timeout: time.Second * 16, Timeout: time.Second * 16,
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
var errs []error
for _, ctl := range internet.Controllers { for _, ctl := range internet.Controllers {
if err := ctl(network, address, c); err != nil { if err := ctl(network, address, c); err != nil {
errors.LogInfoInner(context.Background(), err, "failed to apply external controller") errs = append(errs, err)
return err
} }
} }
return nil err := errors.Combine(errs...)
if err != nil {
errors.LogInfoInner(context.Background(), err, "failed to apply external controller")
}
return err
}, },
} }

View File

@@ -78,9 +78,9 @@ func (d *deviceNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, erro
var conn net.PacketConn var conn net.PacketConn
var err error var err error
if raddr.Addr().Is4() { if raddr.Addr().Is4() {
conn, err = d.lc.ListenPacket(context.Background(), "udp4", ":0") conn, err = d.lc.ListenPacket(context.Background(), "udp", "0.0.0.0:0")
} else { } else {
conn, err = d.lc.ListenPacket(context.Background(), "udp6", ":0") conn, err = d.lc.ListenPacket(context.Background(), "udp", "[::]:0")
} }
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -37,11 +37,11 @@ type packet struct {
} }
type xdnsConnClient struct { type xdnsConnClient struct {
conn net.PacketConn net.PacketConn
resolverConns []net.PacketConn
resolverAddrs []*net.UDPAddr resolverAddrs []*net.UDPAddr
resolverIdx uint32 resolverIdx uint32
resolverSend []atomic.Uint32 resolverSend map[string]*atomic.Uint32
clientID []byte clientID []byte
domains []Name domains []Name
@@ -74,9 +74,8 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
servers = append(servers, parts[1]) servers = append(servers, parts[1])
} }
var resolverConns []net.PacketConn
var resolverAddrs []*net.UDPAddr var resolverAddrs []*net.UDPAddr
var resolverSend []atomic.Uint32 var resolverSend = make(map[string]*atomic.Uint32)
for _, rs := range servers { for _, rs := range servers {
h, p, err := net.SplitHostPort(rs) h, p, err := net.SplitHostPort(rs)
if err != nil { if err != nil {
@@ -90,27 +89,16 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
if port == 0 { if port == 0 {
return nil, errors.New("invalid port") return nil, errors.New("invalid port")
} }
var uc net.PacketConn addr := &net.UDPAddr{IP: ip, Port: port}
if ip.To4() != nil { resolverAddrs = append(resolverAddrs, addr)
uc, err = net.ListenPacket("udp4", ":0") resolverSend[addr.String()] = &atomic.Uint32{}
} else {
uc, err = net.ListenPacket("udp6", ":0")
}
if err != nil {
for _, rc := range resolverConns {
rc.Close()
}
return nil, errors.New("failed to create resolver socket: ", err)
}
resolverConns = append(resolverConns, uc)
resolverAddrs = append(resolverAddrs, &net.UDPAddr{IP: ip, Port: port})
} }
resolverSend = make([]atomic.Uint32, len(resolverConns))
conn := &xdnsConnClient{ conn := &xdnsConnClient{
conn: raw, PacketConn: raw,
resolverConns: resolverConns,
resolverAddrs: resolverAddrs, resolverAddrs: resolverAddrs,
resolverIdx: 0,
resolverSend: resolverSend, resolverSend: resolverSend,
clientID: make([]byte, 8), clientID: make([]byte, 8),
@@ -130,69 +118,67 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
} }
func (c *xdnsConnClient) recvLoop() { func (c *xdnsConnClient) recvLoop() {
var wg sync.WaitGroup var buf [finalmask.UDPSize]byte
for i, rc := range c.resolverConns { for {
wg.Add(1) if c.closed {
go func() { break
defer wg.Done() }
var buf [finalmask.UDPSize]byte n, addr, err := c.PacketConn.ReadFrom(buf[:])
if err != nil {
for { if go_errors.Is(err, net.ErrClosed) {
if c.closed { break
break
}
n, addr, err := rc.ReadFrom(buf[:])
if err != nil {
if go_errors.Is(err, net.ErrClosed) {
break
}
continue
}
resp, err := MessageFromWireFormat(buf[:n])
if err != nil {
errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err)
continue
}
payload := dnsResponsePayload(&resp, c.domains)
r := bytes.NewReader(payload)
anyPacket := false
for {
p, err := nextPacket(r)
if err != nil {
break
}
anyPacket = true
buf := make([]byte, len(p))
copy(buf, p)
select {
case c.readQueue <- &packet{
p: buf,
addr: addr,
}:
default:
errors.LogDebug(context.Background(), addr, " mask read err queue full")
}
}
if anyPacket {
c.resolverSend[i].Store(0)
select {
case c.pollChan <- struct{}{}:
default:
}
}
} }
}() continue
} }
wg.Wait() if addr == nil {
continue
}
send := c.resolverSend[addr.String()]
if send == nil {
continue
}
resp, err := MessageFromWireFormat(buf[:n])
if err != nil {
errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err)
continue
}
payload := dnsResponsePayload(&resp, c.domains)
r := bytes.NewReader(payload)
anyPacket := false
for {
p, err := nextPacket(r)
if err != nil {
break
}
anyPacket = true
buf := make([]byte, len(p))
copy(buf, p)
select {
case c.readQueue <- &packet{
p: buf,
addr: addr,
}:
default:
errors.LogDebug(context.Background(), addr, " mask read err queue full")
}
}
if anyPacket {
send.Store(0)
select {
case c.pollChan <- struct{}{}:
default:
}
}
}
errors.LogDebug(context.Background(), "xdns closed") errors.LogDebug(context.Background(), "xdns closed")
@@ -254,15 +240,15 @@ func (c *xdnsConnClient) sendLoop() {
} }
cur := c.resolverIdx cur := c.resolverIdx
curSend := c.resolverSend[c.resolverIdx].Add(1) curSend := c.resolverSend[c.resolverAddrs[cur].String()].Add(1)
_, _ = c.resolverConns[c.resolverIdx].WriteTo(p.p, c.resolverAddrs[c.resolverIdx]) _, _ = c.PacketConn.WriteTo(p.p, c.resolverAddrs[cur])
for { for {
c.resolverIdx += 1 c.resolverIdx += 1
c.resolverIdx %= uint32(len(c.resolverConns)) c.resolverIdx %= uint32(len(c.resolverAddrs))
if c.resolverIdx == cur { if c.resolverIdx == cur {
break break
} }
if c.resolverSend[c.resolverIdx].Load() < curSend { if c.resolverSend[c.resolverAddrs[c.resolverIdx].String()].Load() < curSend {
break break
} }
} }
@@ -290,7 +276,7 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
} }
encoded, err := encode(p, c.clientID, c.domains[c.resolverIdx%uint32(len(c.resolverConns))]) encoded, err := encode(p, c.clientID, c.domains[c.resolverIdx%uint32(len(c.resolverAddrs))])
if err != nil { if err != nil {
errors.LogDebug(context.Background(), addr, " xdns wireformat err ", err, " ", len(p)) errors.LogDebug(context.Background(), addr, " xdns wireformat err ", err, " ", len(p))
return 0, nil return 0, nil
@@ -310,35 +296,7 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
func (c *xdnsConnClient) Close() error { func (c *xdnsConnClient) Close() error {
c.closed = true c.closed = true
for _, rc := range c.resolverConns { return c.PacketConn.Close()
rc.Close()
}
return c.conn.Close()
}
func (c *xdnsConnClient) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *xdnsConnClient) SetDeadline(t time.Time) error {
for _, rc := range c.resolverConns {
rc.SetDeadline(t)
}
return c.conn.SetDeadline(t)
}
func (c *xdnsConnClient) SetReadDeadline(t time.Time) error {
for _, rc := range c.resolverConns {
rc.SetReadDeadline(t)
}
return c.conn.SetReadDeadline(t)
}
func (c *xdnsConnClient) SetWriteDeadline(t time.Time) error {
for _, rc := range c.resolverConns {
rc.SetWriteDeadline(t)
}
return c.conn.SetWriteDeadline(t)
} }
func encode(p []byte, clientID []byte, domain Name) ([]byte, error) { func encode(p []byte, clientID []byte, domain Name) ([]byte, error) {

View File

@@ -2,27 +2,23 @@ package xdns
import ( import (
"net" "net"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/hysteria/udphop"
) )
func (c *Config) UDP() { func (c *Config) UDP() {
} }
func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) {
_, ok1 := raw.(*internet.FakePacketConn) // _, ok1 := raw.(*internet.FakePacketConn)
_, ok2 := raw.(*udphop.UdpHopPacketConn) // _, ok2 := raw.(*udphop.UdpHopPacketConn)
if level != 0 || ok1 || ok2 { // if level != 0 || ok1 || ok2 {
return nil, errors.New("xdns requires being at the outermost level") // return nil, errors.New("xdns requires being at the outermost level")
} // }
return NewConnClient(c, raw) return NewConnClient(c, raw)
} }
func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) {
if level != 0 { // if level != 0 {
return nil, errors.New("xdns requires being at the outermost level") // return nil, errors.New("xdns requires being at the outermost level")
} // }
return NewConnServer(c, raw) return NewConnServer(c, raw)
} }