diff --git a/features/dns/localdns/client.go b/features/dns/localdns/client.go index 5ba81859..b00febb3 100644 --- a/features/dns/localdns/client.go +++ b/features/dns/localdns/client.go @@ -80,13 +80,17 @@ func New() *Client { d := &net.Dialer{ Timeout: time.Second * 16, Control: func(network, address string, c syscall.RawConn) error { + var errs []error for _, ctl := range internet.Controllers { if err := ctl(network, address, c); err != nil { - errors.LogInfoInner(context.Background(), err, "failed to apply external controller") - return err + errs = append(errs, err) } } - return nil + err := errors.Combine(errs...) + if err != nil { + errors.LogInfoInner(context.Background(), err, "failed to apply external controller") + } + return err }, } diff --git a/proxy/wireguard/tun_linux.go b/proxy/wireguard/tun_linux.go index 068e21ee..b8a742e1 100644 --- a/proxy/wireguard/tun_linux.go +++ b/proxy/wireguard/tun_linux.go @@ -78,9 +78,9 @@ func (d *deviceNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, erro var conn net.PacketConn var err error if raddr.Addr().Is4() { - conn, err = d.lc.ListenPacket(context.Background(), "udp4", ":0") + conn, err = d.lc.ListenPacket(context.Background(), "udp", "0.0.0.0:0") } else { - conn, err = d.lc.ListenPacket(context.Background(), "udp6", ":0") + conn, err = d.lc.ListenPacket(context.Background(), "udp", "[::]:0") } if err != nil { return nil, err diff --git a/transport/internet/finalmask/xdns/client.go b/transport/internet/finalmask/xdns/client.go index 6f8d9737..0c670c14 100644 --- a/transport/internet/finalmask/xdns/client.go +++ b/transport/internet/finalmask/xdns/client.go @@ -37,11 +37,11 @@ type packet struct { } type xdnsConnClient struct { - conn net.PacketConn - resolverConns []net.PacketConn + net.PacketConn + resolverAddrs []*net.UDPAddr resolverIdx uint32 - resolverSend []atomic.Uint32 + resolverSend map[string]*atomic.Uint32 clientID []byte domains []Name @@ -74,9 +74,8 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { servers = append(servers, parts[1]) } - var resolverConns []net.PacketConn var resolverAddrs []*net.UDPAddr - var resolverSend []atomic.Uint32 + var resolverSend = make(map[string]*atomic.Uint32) for _, rs := range servers { h, p, err := net.SplitHostPort(rs) if err != nil { @@ -90,27 +89,16 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { if port == 0 { return nil, errors.New("invalid port") } - var uc net.PacketConn - if ip.To4() != nil { - uc, err = net.ListenPacket("udp4", ":0") - } else { - uc, err = net.ListenPacket("udp6", ":0") - } - if err != nil { - for _, rc := range resolverConns { - rc.Close() - } - return nil, errors.New("failed to create resolver socket: ", err) - } - resolverConns = append(resolverConns, uc) - resolverAddrs = append(resolverAddrs, &net.UDPAddr{IP: ip, Port: port}) + addr := &net.UDPAddr{IP: ip, Port: port} + resolverAddrs = append(resolverAddrs, addr) + resolverSend[addr.String()] = &atomic.Uint32{} } - resolverSend = make([]atomic.Uint32, len(resolverConns)) conn := &xdnsConnClient{ - conn: raw, - resolverConns: resolverConns, + PacketConn: raw, + resolverAddrs: resolverAddrs, + resolverIdx: 0, resolverSend: resolverSend, clientID: make([]byte, 8), @@ -130,69 +118,67 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { } func (c *xdnsConnClient) recvLoop() { - var wg sync.WaitGroup + var buf [finalmask.UDPSize]byte - for i, rc := range c.resolverConns { - wg.Add(1) - go func() { - defer wg.Done() + for { + if c.closed { + break + } - var buf [finalmask.UDPSize]byte - - for { - if c.closed { - break - } - - n, addr, err := rc.ReadFrom(buf[:]) - if err != nil { - if go_errors.Is(err, net.ErrClosed) { - break - } - continue - } - - resp, err := MessageFromWireFormat(buf[:n]) - if err != nil { - errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err) - continue - } - - payload := dnsResponsePayload(&resp, c.domains) - - r := bytes.NewReader(payload) - anyPacket := false - for { - p, err := nextPacket(r) - if err != nil { - break - } - anyPacket = true - - buf := make([]byte, len(p)) - copy(buf, p) - select { - case c.readQueue <- &packet{ - p: buf, - addr: addr, - }: - default: - errors.LogDebug(context.Background(), addr, " mask read err queue full") - } - } - - if anyPacket { - c.resolverSend[i].Store(0) - select { - case c.pollChan <- struct{}{}: - default: - } - } + n, addr, err := c.PacketConn.ReadFrom(buf[:]) + if err != nil { + if go_errors.Is(err, net.ErrClosed) { + break } - }() - } + continue + } - wg.Wait() + if addr == nil { + continue + } + + send := c.resolverSend[addr.String()] + if send == nil { + continue + } + + resp, err := MessageFromWireFormat(buf[:n]) + if err != nil { + errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err) + continue + } + + payload := dnsResponsePayload(&resp, c.domains) + + r := bytes.NewReader(payload) + anyPacket := false + for { + p, err := nextPacket(r) + if err != nil { + break + } + anyPacket = true + + buf := make([]byte, len(p)) + copy(buf, p) + select { + case c.readQueue <- &packet{ + p: buf, + addr: addr, + }: + default: + errors.LogDebug(context.Background(), addr, " mask read err queue full") + } + } + + if anyPacket { + send.Store(0) + select { + case c.pollChan <- struct{}{}: + default: + } + } + } errors.LogDebug(context.Background(), "xdns closed") @@ -254,15 +240,15 @@ func (c *xdnsConnClient) sendLoop() { } cur := c.resolverIdx - curSend := c.resolverSend[c.resolverIdx].Add(1) - _, _ = c.resolverConns[c.resolverIdx].WriteTo(p.p, c.resolverAddrs[c.resolverIdx]) + curSend := c.resolverSend[c.resolverAddrs[cur].String()].Add(1) + _, _ = c.PacketConn.WriteTo(p.p, c.resolverAddrs[cur]) for { c.resolverIdx += 1 - c.resolverIdx %= uint32(len(c.resolverConns)) + c.resolverIdx %= uint32(len(c.resolverAddrs)) if c.resolverIdx == cur { break } - if c.resolverSend[c.resolverIdx].Load() < curSend { + if c.resolverSend[c.resolverAddrs[c.resolverIdx].String()].Load() < curSend { break } } @@ -290,7 +276,7 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { return 0, io.ErrClosedPipe } - encoded, err := encode(p, c.clientID, c.domains[c.resolverIdx%uint32(len(c.resolverConns))]) + encoded, err := encode(p, c.clientID, c.domains[c.resolverIdx%uint32(len(c.resolverAddrs))]) if err != nil { errors.LogDebug(context.Background(), addr, " xdns wireformat err ", err, " ", len(p)) return 0, nil @@ -310,35 +296,7 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *xdnsConnClient) Close() error { c.closed = true - for _, rc := range c.resolverConns { - rc.Close() - } - return c.conn.Close() -} - -func (c *xdnsConnClient) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *xdnsConnClient) SetDeadline(t time.Time) error { - for _, rc := range c.resolverConns { - rc.SetDeadline(t) - } - return c.conn.SetDeadline(t) -} - -func (c *xdnsConnClient) SetReadDeadline(t time.Time) error { - for _, rc := range c.resolverConns { - rc.SetReadDeadline(t) - } - return c.conn.SetReadDeadline(t) -} - -func (c *xdnsConnClient) SetWriteDeadline(t time.Time) error { - for _, rc := range c.resolverConns { - rc.SetWriteDeadline(t) - } - return c.conn.SetWriteDeadline(t) + return c.PacketConn.Close() } func encode(p []byte, clientID []byte, domain Name) ([]byte, error) { diff --git a/transport/internet/finalmask/xdns/config.go b/transport/internet/finalmask/xdns/config.go index dbd78a28..bac0456e 100644 --- a/transport/internet/finalmask/xdns/config.go +++ b/transport/internet/finalmask/xdns/config.go @@ -2,27 +2,23 @@ package xdns import ( "net" - - "github.com/xtls/xray-core/common/errors" - "github.com/xtls/xray-core/transport/internet" - "github.com/xtls/xray-core/transport/internet/hysteria/udphop" ) func (c *Config) UDP() { } func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { - _, ok1 := raw.(*internet.FakePacketConn) - _, ok2 := raw.(*udphop.UdpHopPacketConn) - if level != 0 || ok1 || ok2 { - return nil, errors.New("xdns requires being at the outermost level") - } + // _, ok1 := raw.(*internet.FakePacketConn) + // _, ok2 := raw.(*udphop.UdpHopPacketConn) + // if level != 0 || ok1 || ok2 { + // return nil, errors.New("xdns requires being at the outermost level") + // } return NewConnClient(c, raw) } func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { - if level != 0 { - return nil, errors.New("xdns requires being at the outermost level") - } + // if level != 0 { + // return nil, errors.New("xdns requires being at the outermost level") + // } return NewConnServer(c, raw) }