WireGuard: Implement UDP FullCone NAT (#5833)

Fixes https://github.com/XTLS/Xray-core/issues/5601

---------

Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
This commit is contained in:
LjhAUMEM
2026-03-23 01:42:40 +08:00
committed by GitHub
parent ce66db7032
commit 67a71adad1
9 changed files with 317 additions and 119 deletions

View File

@@ -53,7 +53,7 @@ func GetGlobalID(ctx context.Context) (globalID [8]byte) {
return
}
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP &&
(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks" || inbound.Name == "tun") {
(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks" || inbound.Name == "tun" || inbound.Name == "wireguard") {
h := blake3.New(8, BaseKey)
h.Write([]byte(inbound.Source.String()))
copy(globalID[:], h.Sum(nil))

2
go.mod
View File

@@ -26,7 +26,7 @@ require (
golang.org/x/sync v0.20.0
golang.org/x/sys v0.42.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
google.golang.org/grpc v1.79.3
google.golang.org/protobuf v1.36.11
gvisor.dev/gvisor v0.0.0-20260122175437-89a5d21be8f0

2
go.sum
View File

@@ -131,6 +131,8 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww=

View File

@@ -130,7 +130,7 @@ func ParseWireGuardKey(str string) (string, error) {
return "", errors.New("key must not be empty")
}
if len(str)%2 == 0 {
if len(str) == 64 {
_, err = hex.DecodeString(str)
if err == nil {
return str, nil

View File

@@ -227,6 +227,11 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
}
defer conn.Close()
conn = &udpConnClient{
Conn: conn,
dest: destination,
}
requestFunc = func() error {
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
@@ -336,3 +341,34 @@ func (h *Handler) createIPCRequest() string {
return request.String()[:request.Len()]
}
type udpConnClient struct {
net.Conn
dest net.Destination
}
func (c *udpConnClient) ReadMultiBuffer() (buf.MultiBuffer, error) {
b := buf.New()
b.Resize(0, buf.Size)
n, addr, err := c.Conn.(net.PacketConn).ReadFrom(b.Bytes())
if err != nil {
b.Release()
return nil, err
}
if addr == nil { // should never hit
addr = c.dest.RawNetAddr()
}
b.Resize(0, int32(n))
b.UDP = &net.Destination{
Address: net.IPAddress(addr.(*net.UDPAddr).IP),
Port: net.Port(addr.(*net.UDPAddr).Port),
Network: net.Network_UDP,
}
return buf.MultiBuffer{b}, nil
}
func (c *udpConnClient) Write(p []byte) (int, error) {
return c.Conn.(net.PacketConn).WriteTo(p, c.dest.RawNetAddr())
}

View File

@@ -31,6 +31,7 @@ type netTun struct {
ep *channel.Endpoint
stack *stack.Stack
events chan tun.Event
notifyHandle *channel.NotificationHandle
incomingPacket chan *buffer.View
mtu int
hasV4, hasV6 bool
@@ -48,12 +49,17 @@ func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (t
dev := &netTun{
ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(opts),
events: make(chan tun.Event, 1),
events: make(chan tun.Event, 10),
incomingPacket: make(chan *buffer.View),
mtu: mtu,
}
dev.ep.AddNotify(dev)
tcpipErr := dev.stack.CreateNIC(1, dev.ep)
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
if tcpipErr != nil {
return nil, nil, dev.stack, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
}
dev.notifyHandle = dev.ep.AddNotify(dev)
tcpipErr = dev.stack.CreateNIC(1, dev.ep)
if tcpipErr != nil {
return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
@@ -90,20 +96,10 @@ func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (t
dev.stack.SetSpoofing(1, true)
}
opt := tcpip.CongestionControlOption("cubic")
if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err)
}
dev.events <- tun.EventUp
return dev, (*Net)(dev), dev.stack, nil
}
// BatchSize implements tun.Device
func (tun *netTun) BatchSize() int {
return 1
}
// Name implements tun.Device
func (tun *netTun) Name() (string, error) {
return "go", nil
@@ -120,7 +116,6 @@ func (tun *netTun) Events() <-chan tun.Event {
}
// Read implements tun.Device
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
view, ok := <-tun.incomingPacket
if !ok {
@@ -169,20 +164,16 @@ func (tun *netTun) WriteNotify() {
tun.incomingPacket <- view
}
// Flush implements tun.Device
func (tun *netTun) Flush() error {
return nil
}
// Close implements tun.Device
func (tun *netTun) Close() error {
tun.closeOnce.Do(func() {
tun.stack.RemoveNIC(1)
tun.stack.Close()
tun.ep.RemoveNotify(tun.notifyHandle)
tun.ep.Close()
close(tun.events)
tun.ep.Close()
close(tun.incomingPacket)
})
return nil
@@ -193,6 +184,11 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
// BatchSize implements tun.Device
func (tun *netTun) BatchSize() int {
return 1
}
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
var protoNumber tcpip.NetworkProtocolNumber
if endpoint.Addr().Is4() {
@@ -224,6 +220,7 @@ func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, er
var addr tcpip.FullAddress
addr, pn = convertToFullAddr(raddr)
rfa = &addr
rfa = nil // do not ep connect
}
return gonet.DialUDP(net.stack, lfa, rfa, pn)
}

View File

@@ -5,19 +5,17 @@ import (
goerrors "errors"
"io"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
c "github.com/xtls/xray-core/common/ctx"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet/stat"
)
@@ -31,10 +29,10 @@ type Server struct {
}
type routingInfo struct {
ctx context.Context
dispatcher routing.Dispatcher
inboundTag *session.Inbound
contentTag *session.Content
ctx context.Context
dispatcher routing.Dispatcher
inboundTag *session.Inbound
contentTag *session.Content
}
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
@@ -124,7 +122,6 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
errors.LogError(s.info.ctx, "unexpected: dispatcher == nil")
return
}
defer conn.Close()
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
sid := session.NewID()
@@ -146,9 +143,6 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
}
ctx = session.SubContextFromMuxInbound(ctx)
plcy := s.policyManager.ForLevel(0)
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: nullDestination,
To: dest,
@@ -156,35 +150,15 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
Reason: "",
})
link, err := s.info.dispatcher.Dispatch(ctx, dest)
err := s.info.dispatcher.DispatchLink(ctx, dest, &transport.Link{
Reader: buf.NewReader(conn),
Writer: buf.NewWriter(conn),
})
if err != nil {
errors.LogErrorInner(ctx, err, "dispatch connection")
}
defer cancel()
requestDone := func() error {
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
return errors.New("failed to transport all TCP request").Base(err)
}
return nil
errors.LogInfoInner(ctx, err, "connection ends")
}
responseDone := func() error {
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
return errors.New("failed to transport all TCP response").Base(err)
}
return nil
}
requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
errors.LogDebugInner(ctx, err, "connection ends")
return
}
cancel()
conn.Close()
}

View File

@@ -3,6 +3,7 @@ package wireguard
import (
"context"
"fmt"
"io"
"net/netip"
"runtime"
"strconv"
@@ -10,12 +11,17 @@ import (
"sync"
"time"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/proxy/wireguard/gvisortun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -138,7 +144,7 @@ func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, erro
func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
out := &gvisorNet{}
tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
tun, n, gstack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
if err != nil {
return nil, err
}
@@ -147,60 +153,236 @@ func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
// handler is only used for promiscuous mode
// capture all packets and send to handler
tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) {
tcpForwarder := tcp.NewForwarder(gstack, 0, 65535, func(r *tcp.ForwarderRequest) {
go func(r *tcp.ForwarderRequest) {
var (
wq waiter.Queue
id = r.ID()
)
var wq waiter.Queue
var id = r.ID()
// Perform a TCP three-way handshake.
ep, err := r.CreateEndpoint(&wq)
if err != nil {
errors.LogError(context.Background(), err.String())
r.Complete(true)
return
}
r.Complete(false)
defer ep.Close()
// enable tcp keep-alive to prevent hanging connections
ep.SocketOptions().SetKeepAlive(true)
options := ep.SocketOptions()
options.SetKeepAlive(false)
options.SetReuseAddress(true)
options.SetReusePort(true)
// local address is actually destination
handler(net.TCPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
ep.Close()
r.Complete(false)
}(r)
})
stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
gstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) bool {
go func(r *udp.ForwarderRequest) {
var (
wq waiter.Queue
id = r.ID()
)
ep, err := r.CreateEndpoint(&wq)
if err != nil {
errors.LogError(context.Background(), err.String())
return
}
defer ep.Close()
// prevents hanging connections and ensure timely release
ep.SocketOptions().SetLinger(tcpip.LingerOption{
Enabled: true,
Timeout: 15 * time.Second,
})
handler(net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewUDPConn(&wq, ep))
}(r)
manager := &udpManager{
stack: gstack,
handler: handler,
m: make(map[string]*udpConn),
}
gstack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
data := pkt.Clone().Data().AsRange().ToSlice()
// if len(data) == 0 {
// return false
// }
src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
dst := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
manager.feed(src, dst, data)
return true
})
stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
}
out.tun, out.net = tun, n
return out, nil
}
type udpManager struct {
stack *stack.Stack
handler func(dest net.Destination, conn net.Conn)
m map[string]*udpConn
mutex sync.RWMutex
}
func (m *udpManager) feed(src net.Destination, dst net.Destination, data []byte) {
m.mutex.RLock()
uc, ok := m.m[src.NetAddr()]
if ok {
select {
case uc.ch <- data:
default:
}
m.mutex.RUnlock()
return
}
m.mutex.RUnlock()
m.mutex.Lock()
defer m.mutex.Unlock()
uc, ok = m.m[src.NetAddr()]
if !ok {
uc = &udpConn{
ch: make(chan []byte, 1024),
src: src,
dst: dst,
}
uc.writeFunc = m.writeRawUDPPacket
uc.closeFunc = func() {
m.mutex.Lock()
m.close(uc)
m.mutex.Unlock()
}
m.m[src.NetAddr()] = uc
go m.handler(dst, uc)
}
select {
case uc.ch <- data:
default:
}
}
func (m *udpManager) close(uc *udpConn) {
if !uc.closed {
uc.closed = true
close(uc.ch)
delete(m.m, uc.src.NetAddr())
}
}
func (m *udpManager) writeRawUDPPacket(payload []byte, src net.Destination, dst net.Destination) error {
udpLen := header.UDPMinimumSize + len(payload)
srcIP := tcpip.AddrFromSlice(src.Address.IP())
dstIP := tcpip.AddrFromSlice(dst.Address.IP())
// build packet with appropriate IP header size
isIPv4 := dst.Address.Family().IsIPv4()
ipHdrSize := header.IPv6MinimumSize
ipProtocol := header.IPv6ProtocolNumber
if isIPv4 {
ipHdrSize = header.IPv4MinimumSize
ipProtocol = header.IPv4ProtocolNumber
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
Payload: buffer.MakeWithData(payload),
})
defer pkt.DecRef()
// Build UDP header
udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
udpHdr.Encode(&header.UDPFields{
SrcPort: uint16(src.Port),
DstPort: uint16(dst.Port),
Length: uint16(udpLen),
})
// Calculate and set UDP checksum
xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
// Build IP header
if isIPv4 {
ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
ipHdr.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + udpLen),
TTL: 64,
Protocol: uint8(header.UDPProtocolNumber),
SrcAddr: srcIP,
DstAddr: dstIP,
})
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
} else {
ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ipHdr.Encode(&header.IPv6Fields{
PayloadLength: uint16(udpLen),
TransportProtocol: header.UDPProtocolNumber,
HopLimit: 64,
SrcAddr: srcIP,
DstAddr: dstIP,
})
}
// dispatch the packet
err := m.stack.WriteRawPacket(1, ipProtocol, buffer.MakeWithView(pkt.ToView()))
if err != nil {
return errors.New("failed to write raw udp packet back to stack err ", err)
}
return nil
}
type udpConn struct {
ch chan []byte
src net.Destination
dst net.Destination
writeFunc func(payload []byte, src net.Destination, dst net.Destination) error
closeFunc func()
closed bool
}
func (c *udpConn) Read(p []byte) (int, error) {
b, ok := <-c.ch
if !ok {
return 0, io.EOF
}
n := copy(p, b)
if n != len(b) {
return 0, io.ErrShortBuffer
}
return n, nil
}
func (c *udpConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
for i, b := range mb {
dst := c.dst
if b.UDP != nil {
dst = *b.UDP
}
err := c.writeFunc(b.Bytes(), dst, c.src)
if err != nil {
buf.ReleaseMulti(mb[i:])
return err
}
b.Release()
}
return nil
}
func (c *udpConn) Write(p []byte) (int, error) {
err := c.writeFunc(p, c.dst, c.src)
if err != nil {
return 0, err
}
return len(p), nil
}
func (c *udpConn) Close() error {
c.closeFunc()
return nil
}
func (c *udpConn) LocalAddr() net.Addr {
return c.src.RawNetAddr() // fake
}
func (c *udpConn) RemoteAddr() net.Addr {
return c.src.RawNetAddr() // src
}
func (c *udpConn) SetDeadline(t time.Time) error {
return nil
}
func (c *udpConn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *udpConn) SetWriteDeadline(t time.Time) error {
return nil
}

View File

@@ -100,32 +100,39 @@ func (m *udpSessionManagerServer) run() {
func (m *udpSessionManagerServer) feed(id uint32, d []byte) {
m.mutex.RLock()
udpConn, ok := m.m[id]
if ok {
select {
case udpConn.ch <- d:
default:
}
m.mutex.RUnlock()
return
}
m.mutex.RUnlock()
m.mutex.Lock()
defer m.mutex.Unlock()
udpConn, ok = m.m[id]
if !ok {
m.mutex.Lock()
udpConn, ok = m.m[id]
if !ok {
udpConn = &InterUdpConn{
conn: m.conn,
local: m.conn.LocalAddr(),
remote: m.conn.RemoteAddr(),
udpConn = &InterUdpConn{
conn: m.conn,
local: m.conn.LocalAddr(),
remote: m.conn.RemoteAddr(),
id: id,
ch: make(chan []byte, udpMessageChanSize),
last: time.Now(),
id: id,
ch: make(chan []byte, udpMessageChanSize),
last: time.Now(),
user: m.user,
}
udpConn.closeFunc = func() {
m.mutex.Lock()
defer m.mutex.Unlock()
m.close(udpConn)
}
m.m[id] = udpConn
m.addConn(udpConn)
user: m.user,
}
m.mutex.Unlock()
udpConn.closeFunc = func() {
m.mutex.Lock()
m.close(udpConn)
m.mutex.Unlock()
}
m.m[id] = udpConn
m.addConn(udpConn)
}
select {