From 8aacdbd71b92c4a85779ae7818bd0d8e8f5793d5 Mon Sep 17 00:00:00 2001 From: LjhAUMEM Date: Sat, 28 Mar 2026 01:30:21 +0800 Subject: [PATCH] WireGuard inbound: Fix multi-peer; Fix potential routing issue (#5843) Fixes https://github.com/XTLS/Xray-core/pull/5554 Fixes https://github.com/XTLS/Xray-core/issues/4760 --- common/log/logger.go | 4 +- proxy/wireguard/bind.go | 85 +++++++++++++++++------------------- proxy/wireguard/client.go | 3 +- proxy/wireguard/server.go | 40 ++++++++++------- proxy/wireguard/wireguard.go | 17 -------- 5 files changed, 68 insertions(+), 81 deletions(-) diff --git a/common/log/logger.go b/common/log/logger.go index 7c100aa1..538eda2d 100644 --- a/common/log/logger.go +++ b/common/log/logger.go @@ -36,7 +36,7 @@ type serverityLogger struct { func NewLogger(logWriterCreator WriterCreator) Handler { return &generalLogger{ creator: logWriterCreator, - buffer: make(chan Message, 16), + buffer: make(chan Message, 128), access: semaphore.New(1), done: done.New(), } @@ -46,7 +46,7 @@ func ReplaceWithSeverityLogger(serverity Severity) { w := CreateStdoutLogWriter() g := &generalLogger{ creator: w, - buffer: make(chan Message, 16), + buffer: make(chan Message, 128), access: semaphore.New(1), done: done.New(), } diff --git a/proxy/wireguard/bind.go b/proxy/wireguard/bind.go index 8e7ecb04..f4ec6f25 100644 --- a/proxy/wireguard/bind.go +++ b/proxy/wireguard/bind.go @@ -2,27 +2,23 @@ package wireguard import ( "context" - "errors" + gonet "net" "net/netip" + "runtime" "strconv" - "sync" "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/transport/internet" ) type netReadInfo struct { - // status - waiter sync.WaitGroup - // param - buff []byte - // result - bytes int + buff []byte endpoint conn.Endpoint - err error } // reduce duplicated code @@ -32,6 +28,7 @@ type netBind struct { workers int readQueue chan *netReadInfo + closedCh chan struct{} } // SetMark implements conn.Bind @@ -79,27 +76,23 @@ func (bind *netBind) BatchSize() int { // Open implements conn.Bind func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { - bind.readQueue = make(chan *netReadInfo) + bind.closedCh = make(chan struct{}) + errors.LogDebug(context.Background(), "bind opened") fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { - defer func() { - if r := recover(); r != nil { - n = 0 - err = errors.New("channel closed") - } - }() - - r, ok := <-bind.readQueue - if !ok { - return 0, errors.New("channel closed") + select { + case r := <-bind.readQueue: + sizes[0], eps[0] = copy(bufs[0], r.buff), r.endpoint + return 1, nil + case <-bind.closedCh: + errors.LogDebug(context.Background(), "recv func closed") + return 0, gonet.ErrClosed } - - copy(bufs[0], r.buff[:r.bytes]) - sizes[0], eps[0] = r.bytes, r.endpoint - r.waiter.Done() - return 1, r.err } workers := bind.workers + if workers <= 0 { + workers = runtime.NumCPU() + } if workers <= 0 { workers = 1 } @@ -113,8 +106,9 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { // Close implements conn.Bind func (bind *netBind) Close() error { - if bind.readQueue != nil { - close(bind.readQueue) + errors.LogDebug(context.Background(), "bind closed") + if bind.closedCh != nil { + close(bind.closedCh) } return nil } @@ -134,35 +128,35 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error { } endpoint.conn = c - go func(readQueue chan<- *netReadInfo, endpoint *netEndpoint) { - defer func() { - _ = recover() // handle send on closed channel - }() + go func() { for { - buff := make([]byte, 1700) - i, err := c.Read(buff) + buff := make([]byte, device.MaxMessageSize) + n, err := c.Read(buff) - if i > 3 { + if err != nil { + endpoint.conn = nil + c.Close() + return + } + + if n > 3 { buff[1] = 0 buff[2] = 0 buff[3] = 0 } - r := &netReadInfo{ - buff: buff, - bytes: i, + select { + case bind.readQueue <- &netReadInfo{ + buff: buff[:n], endpoint: endpoint, - err: err, - } - r.waiter.Add(1) - readQueue <- r - r.waiter.Wait() - if err != nil { + }: + case <-bind.closedCh: endpoint.conn = nil + c.Close() return } } - }(bind.readQueue, endpoint) + }() return nil } @@ -206,7 +200,8 @@ func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error { } if nend.conn == nil { - return errors.New("connection not open yet") + errors.LogDebug(context.Background(), nend.dst.NetAddr(), " send on closed peer") + return errors.New("peer closed") } for _, buff := range buff { diff --git a/proxy/wireguard/client.go b/proxy/wireguard/client.go index efe833a1..3030ac78 100644 --- a/proxy/wireguard/client.go +++ b/proxy/wireguard/client.go @@ -121,7 +121,8 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer) IPv4Enable: h.hasIPv4, IPv6Enable: h.hasIPv6, }, - workers: int(h.conf.NumWorkers), + workers: int(h.conf.NumWorkers), + readQueue: make(chan *netReadInfo), }, ctx: ctx, dialer: dialer, diff --git a/proxy/wireguard/server.go b/proxy/wireguard/server.go index ce18be22..9e82ed39 100644 --- a/proxy/wireguard/server.go +++ b/proxy/wireguard/server.go @@ -2,8 +2,6 @@ package wireguard import ( "context" - goerrors "errors" - "io" "github.com/xtls/xray-core/common/buf" c "github.com/xtls/xray-core/common/ctx" @@ -51,6 +49,8 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) { IPv4Enable: hasIPv4, IPv6Enable: hasIPv6, }, + workers: int(conf.NumWorkers), + readQueue: make(chan *netReadInfo), }, }, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), @@ -93,25 +93,31 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con reader := buf.NewPacketReader(conn) for { - mpayload, err := reader.ReadMultiBuffer() + mb, err := reader.ReadMultiBuffer() if err != nil { + nep.conn = nil + buf.ReleaseMulti(mb) return err } - for _, payload := range mpayload { - v, ok := <-s.bindServer.readQueue - if !ok { - return nil - } - i, err := payload.Read(v.buff) + for i, b := range mb { + buff := b.Bytes() - v.bytes = i - v.endpoint = nep - v.err = err - v.waiter.Done() - if err != nil && goerrors.Is(err, io.EOF) { + if b.Len() > 3 { + buff[1] = 0 + buff[2] = 0 + buff[3] = 0 + } + + select { + case s.bindServer.readQueue <- &netReadInfo{ + buff: buff, + endpoint: nep, + }: + case <-s.bindServer.closedCh: nep.conn = nil - return nil + buf.ReleaseMulti(mb[i:]) + return errors.New("bind closed") } } } @@ -138,9 +144,11 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { // Currently we have no way to link to the original source address inbound.Source = net.DestinationFromAddr(conn.RemoteAddr()) ctx = session.ContextWithInbound(ctx, &inbound) + content := new(session.Content) if s.info.contentTag != nil { - ctx = session.ContextWithContent(ctx, s.info.contentTag) + content.SniffingRequest = s.info.contentTag.SniffingRequest } + ctx = session.ContextWithContent(ctx, content) ctx = session.SubContextFromMuxInbound(ctx) ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ diff --git a/proxy/wireguard/wireguard.go b/proxy/wireguard/wireguard.go index 0d75ee00..4f489114 100644 --- a/proxy/wireguard/wireguard.go +++ b/proxy/wireguard/wireguard.go @@ -8,25 +8,8 @@ import ( "strings" "github.com/xtls/xray-core/common" - "github.com/xtls/xray-core/common/log" - "golang.zx2c4.com/wireguard/device" ) -var wgLogger = &device.Logger{ - Verbosef: func(format string, args ...any) { - log.Record(&log.GeneralMessage{ - Severity: log.Severity_Debug, - Content: fmt.Sprintf(format, args...), - }) - }, - Errorf: func(format string, args ...any) { - log.Record(&log.GeneralMessage{ - Severity: log.Severity_Error, - Content: fmt.Sprintf(format, args...), - }) - }, -} - func init() { common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { deviceConfig := config.(*DeviceConfig)