From ba88aa173cb086429803051ed2a1838a33c271d9 Mon Sep 17 00:00:00 2001 From: LjhAUMEM Date: Sun, 5 Apr 2026 20:57:08 +0800 Subject: [PATCH] WireGuard outbound: Fix UDP FullCone NAT on Linux (#5858) Fixes https://github.com/XTLS/Xray-core/issues/5848 --- proxy/wireguard/client.go | 16 ++++- proxy/wireguard/tun.go | 69 +++++++++++++++++----- proxy/wireguard/tun_linux.go | 51 ++++++++++++---- transport/internet/system_dialer.go | 5 +- transport/internet/system_listener.go | 7 +-- transport/internet/system_listener_test.go | 6 +- 6 files changed, 116 insertions(+), 38 deletions(-) diff --git a/proxy/wireguard/client.go b/proxy/wireguard/client.go index 3030ac78..3ee3a2c5 100644 --- a/proxy/wireguard/client.go +++ b/proxy/wireguard/client.go @@ -370,6 +370,18 @@ func (c *udpConnClient) ReadMultiBuffer() (buf.MultiBuffer, error) { return buf.MultiBuffer{b}, nil } -func (c *udpConnClient) Write(p []byte) (int, error) { - return c.Conn.(net.PacketConn).WriteTo(p, c.dest.RawNetAddr()) +func (c *udpConnClient) WriteMultiBuffer(mb buf.MultiBuffer) error { + for i, b := range mb { + dst := c.dest + if b.UDP != nil { + dst = *b.UDP + } + _, err := c.Conn.(net.PacketConn).WriteTo(b.Bytes(), dst.RawNetAddr()) + if err != nil { + buf.ReleaseMulti(mb[i:]) + return err + } + b.Release() + } + return nil } diff --git a/proxy/wireguard/tun.go b/proxy/wireguard/tun.go index deea7cd6..86ff9f45 100644 --- a/proxy/wireguard/tun.go +++ b/proxy/wireguard/tun.go @@ -189,8 +189,14 @@ func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo // if len(data) == 0 { // return false // } - src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort)) - dst := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)) + srcIP := net.IPAddress(id.RemoteAddress.AsSlice()) + dstIP := net.IPAddress(id.LocalAddress.AsSlice()) + if srcIP == nil || dstIP == nil { + errors.LogDebug(context.Background(), "drop udp with size ", len(data), " > invalid ip address ", id.RemoteAddress.AsSlice(), " ", id.LocalAddress.AsSlice()) + return true + } + src := net.UDPDestination(srcIP, net.Port(id.RemotePort)) + dst := net.UDPDestination(dstIP, net.Port(id.LocalPort)) manager.feed(src, dst, data) return true }) @@ -212,8 +218,12 @@ func (m *udpManager) feed(src net.Destination, dst net.Destination, data []byte) uc, ok := m.m[src.NetAddr()] if ok { select { - case uc.ch <- data: + case uc.queue <- &packet{ + p: data, + dest: &dst, + }: default: + errors.LogDebug(context.Background(), "drop udp with size ", len(data), " to ", dst.NetAddr(), " original ", uc.dst.NetAddr(), " > queue full") } m.mutex.RUnlock() return @@ -226,9 +236,9 @@ func (m *udpManager) feed(src net.Destination, dst net.Destination, data []byte) uc, ok = m.m[src.NetAddr()] if !ok { uc = &udpConn{ - ch: make(chan []byte, 1024), - src: src, - dst: dst, + queue: make(chan *packet, 1024), + src: src, + dst: dst, } uc.writeFunc = m.writeRawUDPPacket uc.closeFunc = func() { @@ -241,15 +251,19 @@ func (m *udpManager) feed(src net.Destination, dst net.Destination, data []byte) } select { - case uc.ch <- data: + case uc.queue <- &packet{ + p: data, + dest: &dst, + }: default: + errors.LogDebug(context.Background(), "drop udp with size ", len(data), " to ", dst.NetAddr(), " original ", uc.dst.NetAddr(), " > queue full") } } func (m *udpManager) close(uc *udpConn) { if !uc.closed { uc.closed = true - close(uc.ch) + close(uc.queue) delete(m.m, uc.src.NetAddr()) } } @@ -317,8 +331,13 @@ func (m *udpManager) writeRawUDPPacket(payload []byte, src net.Destination, dst return nil } +type packet struct { + p []byte + dest *net.Destination +} + type udpConn struct { - ch chan []byte + queue chan *packet src net.Destination dst net.Destination writeFunc func(payload []byte, src net.Destination, dst net.Destination) error @@ -326,13 +345,35 @@ type udpConn struct { closed bool } +func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) { + for { + q, ok := <-c.queue + if !ok { + return nil, io.EOF + } + + b := buf.New() + + _, err := b.Write(q.p) + if err != nil { + errors.LogDebugInner(context.Background(), err, "drop udp with size ", len(q.p), " to ", q.dest.NetAddr(), " original ", c.dst.NetAddr()) + b.Release() + continue + } + + b.UDP = q.dest + + return buf.MultiBuffer{b}, nil + } +} + func (c *udpConn) Read(p []byte) (int, error) { - b, ok := <-c.ch + q, ok := <-c.queue if !ok { return 0, io.EOF } - n := copy(p, b) - if n != len(b) { + n := copy(p, q.p) + if n != len(q.p) { return 0, io.ErrShortBuffer } return n, nil @@ -368,11 +409,11 @@ func (c *udpConn) Close() error { } func (c *udpConn) LocalAddr() net.Addr { - return c.src.RawNetAddr() // fake + return c.dst.RawNetAddr() } func (c *udpConn) RemoteAddr() net.Addr { - return c.src.RawNetAddr() // src + return c.src.RawNetAddr() } func (c *udpConn) SetDeadline(t time.Time) error { diff --git a/proxy/wireguard/tun_linux.go b/proxy/wireguard/tun_linux.go index 7a46138a..068e21ee 100644 --- a/proxy/wireguard/tun_linux.go +++ b/proxy/wireguard/tun_linux.go @@ -10,18 +10,20 @@ import ( "net/netip" "os" "sync" + "syscall" "golang.org/x/sys/unix" - "github.com/sagernet/sing/common/control" "github.com/vishvananda/netlink" "github.com/xtls/xray-core/common/errors" - wgtun "golang.zx2c4.com/wireguard/tun" + "github.com/xtls/xray-core/transport/internet" + "golang.zx2c4.com/wireguard/tun" ) type deviceNet struct { tunnel - dialer net.Dialer + dialer *net.Dialer + lc *net.ListenConfig handle *netlink.Handle linkAddrs []netlink.Addr @@ -47,10 +49,23 @@ func allocateIPv6TableIndex() int { } func newDeviceNet(interfaceName string) *deviceNet { - var dialer net.Dialer - bindControl := control.BindToInterface(control.NewDefaultInterfaceFinder(), interfaceName, -1) - dialer.Control = control.Append(dialer.Control, bindControl) - return &deviceNet{dialer: dialer} + dialer := &net.Dialer{} + dialer.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if err := syscall.BindToDevice(int(fd), interfaceName); err != nil { + errors.LogInfoInner(context.Background(), err, "failed to bind to device") + } + }) + } + lc := &net.ListenConfig{} + lc.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if err := syscall.BindToDevice(int(fd), interfaceName); err != nil { + errors.LogInfoInner(context.Background(), err, "failed to bind to device") + } + }) + } + return &deviceNet{dialer: dialer, lc: lc} } func (d *deviceNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) ( @@ -60,9 +75,23 @@ func (d *deviceNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrP } func (d *deviceNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) { - dialer := d.dialer - dialer.LocalAddr = &net.UDPAddr{IP: laddr.Addr().AsSlice(), Port: int(laddr.Port())} - return dialer.DialContext(context.Background(), "udp", raddr.String()) + var conn net.PacketConn + var err error + if raddr.Addr().Is4() { + conn, err = d.lc.ListenPacket(context.Background(), "udp4", ":0") + } else { + conn, err = d.lc.ListenPacket(context.Background(), "udp6", ":0") + } + if err != nil { + return nil, err + } + return &internet.PacketConnWrapper{ + PacketConn: conn, + Dest: &net.UDPAddr{ + IP: raddr.Addr().AsSlice(), + Port: int(raddr.Port()), + }, + }, nil } func (d *deviceNet) Close() (err error) { @@ -134,7 +163,7 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo } n := CalculateInterfaceName("wg") - wgt, err := wgtun.CreateTUN(n, mtu) + wgt, err := tun.CreateTUN(n, mtu) if err != nil { return nil, err } diff --git a/transport/internet/system_dialer.go b/transport/internet/system_dialer.go index 16b3e9b0..2d604481 100644 --- a/transport/internet/system_dialer.go +++ b/transport/internet/system_dialer.go @@ -5,7 +5,6 @@ import ( "syscall" "time" - "github.com/sagernet/sing/common/control" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/features/dns" @@ -20,7 +19,7 @@ type SystemDialer interface { } type DefaultSystemDialer struct { - controllers []control.Func + controllers []func(network, address string, c syscall.RawConn) error dns dns.Client obm outbound.Manager } @@ -204,7 +203,7 @@ func UseAlternativeSystemDialer(dialer SystemDialer) { // It only works when effective dialer is the default dialer. // // xray:api:beta -func RegisterDialerController(ctl control.Func) error { +func RegisterDialerController(ctl func(network, address string, c syscall.RawConn) error) error { if ctl == nil { return errors.New("nil listener controller") } diff --git a/transport/internet/system_listener.go b/transport/internet/system_listener.go index 2ac28eda..1999953e 100644 --- a/transport/internet/system_listener.go +++ b/transport/internet/system_listener.go @@ -10,7 +10,6 @@ import ( "time" "github.com/pires/go-proxyproto" - "github.com/sagernet/sing/common/control" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" ) @@ -18,10 +17,10 @@ import ( var effectiveListener = DefaultListener{} type DefaultListener struct { - controllers []control.Func + controllers []func(network, address string, c syscall.RawConn) error } -func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []control.Func) func(network, address string, c syscall.RawConn) error { +func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []func(network, address string, c syscall.RawConn) error) func(network, address string, c syscall.RawConn) error { return func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { for _, controller := range controllers { @@ -186,7 +185,7 @@ func (dl *DefaultListener) ListenPacket(ctx context.Context, addr net.Addr, sock // The controller can be used to operate on file descriptors before they are put into use. // // xray:api:beta -func RegisterListenerController(controller control.Func) error { +func RegisterListenerController(controller func(network, address string, c syscall.RawConn) error) error { if controller == nil { return errors.New("nil listener controller") } diff --git a/transport/internet/system_listener_test.go b/transport/internet/system_listener_test.go index 390888e7..b80fdfa8 100644 --- a/transport/internet/system_listener_test.go +++ b/transport/internet/system_listener_test.go @@ -6,7 +6,6 @@ import ( "syscall" "testing" - "github.com/sagernet/sing/common/control" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/transport/internet" ) @@ -14,10 +13,9 @@ import ( func TestRegisterListenerController(t *testing.T) { var gotFd uintptr - common.Must(internet.RegisterListenerController(func(network, address string, conn syscall.RawConn) error { - return control.Raw(conn, func(fd uintptr) error { + common.Must(internet.RegisterListenerController(func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { gotFd = fd - return nil }) }))