WireGuard inbound: Fix multi-peer; Fix potential routing issue (#5843)

Fixes https://github.com/XTLS/Xray-core/pull/5554

Fixes https://github.com/XTLS/Xray-core/issues/4760
This commit is contained in:
LjhAUMEM
2026-03-28 01:30:21 +08:00
committed by GitHub
parent 14524cc3b7
commit 8aacdbd71b
5 changed files with 68 additions and 81 deletions

View File

@@ -36,7 +36,7 @@ type serverityLogger struct {
func NewLogger(logWriterCreator WriterCreator) Handler { func NewLogger(logWriterCreator WriterCreator) Handler {
return &generalLogger{ return &generalLogger{
creator: logWriterCreator, creator: logWriterCreator,
buffer: make(chan Message, 16), buffer: make(chan Message, 128),
access: semaphore.New(1), access: semaphore.New(1),
done: done.New(), done: done.New(),
} }
@@ -46,7 +46,7 @@ func ReplaceWithSeverityLogger(serverity Severity) {
w := CreateStdoutLogWriter() w := CreateStdoutLogWriter()
g := &generalLogger{ g := &generalLogger{
creator: w, creator: w,
buffer: make(chan Message, 16), buffer: make(chan Message, 128),
access: semaphore.New(1), access: semaphore.New(1),
done: done.New(), done: done.New(),
} }

View File

@@ -2,27 +2,23 @@ package wireguard
import ( import (
"context" "context"
"errors" gonet "net"
"net/netip" "net/netip"
"runtime"
"strconv" "strconv"
"sync"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
) )
type netReadInfo struct { type netReadInfo struct {
// status
waiter sync.WaitGroup
// param
buff []byte buff []byte
// result
bytes int
endpoint conn.Endpoint endpoint conn.Endpoint
err error
} }
// reduce duplicated code // reduce duplicated code
@@ -32,6 +28,7 @@ type netBind struct {
workers int workers int
readQueue chan *netReadInfo readQueue chan *netReadInfo
closedCh chan struct{}
} }
// SetMark implements conn.Bind // SetMark implements conn.Bind
@@ -79,27 +76,23 @@ func (bind *netBind) BatchSize() int {
// Open implements conn.Bind // Open implements conn.Bind
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
bind.readQueue = make(chan *netReadInfo) bind.closedCh = make(chan struct{})
errors.LogDebug(context.Background(), "bind opened")
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
defer func() { select {
if r := recover(); r != nil { case r := <-bind.readQueue:
n = 0 sizes[0], eps[0] = copy(bufs[0], r.buff), r.endpoint
err = errors.New("channel closed") return 1, nil
case <-bind.closedCh:
errors.LogDebug(context.Background(), "recv func closed")
return 0, gonet.ErrClosed
} }
}()
r, ok := <-bind.readQueue
if !ok {
return 0, errors.New("channel closed")
}
copy(bufs[0], r.buff[:r.bytes])
sizes[0], eps[0] = r.bytes, r.endpoint
r.waiter.Done()
return 1, r.err
} }
workers := bind.workers workers := bind.workers
if workers <= 0 {
workers = runtime.NumCPU()
}
if workers <= 0 { if workers <= 0 {
workers = 1 workers = 1
} }
@@ -113,8 +106,9 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
// Close implements conn.Bind // Close implements conn.Bind
func (bind *netBind) Close() error { func (bind *netBind) Close() error {
if bind.readQueue != nil { errors.LogDebug(context.Background(), "bind closed")
close(bind.readQueue) if bind.closedCh != nil {
close(bind.closedCh)
} }
return nil return nil
} }
@@ -134,35 +128,35 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
} }
endpoint.conn = c endpoint.conn = c
go func(readQueue chan<- *netReadInfo, endpoint *netEndpoint) { go func() {
defer func() {
_ = recover() // handle send on closed channel
}()
for { for {
buff := make([]byte, 1700) buff := make([]byte, device.MaxMessageSize)
i, err := c.Read(buff) n, err := c.Read(buff)
if i > 3 { if err != nil {
endpoint.conn = nil
c.Close()
return
}
if n > 3 {
buff[1] = 0 buff[1] = 0
buff[2] = 0 buff[2] = 0
buff[3] = 0 buff[3] = 0
} }
r := &netReadInfo{ select {
buff: buff, case bind.readQueue <- &netReadInfo{
bytes: i, buff: buff[:n],
endpoint: endpoint, endpoint: endpoint,
err: err, }:
} case <-bind.closedCh:
r.waiter.Add(1)
readQueue <- r
r.waiter.Wait()
if err != nil {
endpoint.conn = nil endpoint.conn = nil
c.Close()
return return
} }
} }
}(bind.readQueue, endpoint) }()
return nil return nil
} }
@@ -206,7 +200,8 @@ func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
} }
if nend.conn == nil { if nend.conn == nil {
return errors.New("connection not open yet") errors.LogDebug(context.Background(), nend.dst.NetAddr(), " send on closed peer")
return errors.New("peer closed")
} }
for _, buff := range buff { for _, buff := range buff {

View File

@@ -122,6 +122,7 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer)
IPv6Enable: h.hasIPv6, IPv6Enable: h.hasIPv6,
}, },
workers: int(h.conf.NumWorkers), workers: int(h.conf.NumWorkers),
readQueue: make(chan *netReadInfo),
}, },
ctx: ctx, ctx: ctx,
dialer: dialer, dialer: dialer,

View File

@@ -2,8 +2,6 @@ package wireguard
import ( import (
"context" "context"
goerrors "errors"
"io"
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
c "github.com/xtls/xray-core/common/ctx" c "github.com/xtls/xray-core/common/ctx"
@@ -51,6 +49,8 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
IPv4Enable: hasIPv4, IPv4Enable: hasIPv4,
IPv6Enable: hasIPv6, IPv6Enable: hasIPv6,
}, },
workers: int(conf.NumWorkers),
readQueue: make(chan *netReadInfo),
}, },
}, },
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
@@ -93,25 +93,31 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
reader := buf.NewPacketReader(conn) reader := buf.NewPacketReader(conn)
for { for {
mpayload, err := reader.ReadMultiBuffer() mb, err := reader.ReadMultiBuffer()
if err != nil { if err != nil {
nep.conn = nil
buf.ReleaseMulti(mb)
return err return err
} }
for _, payload := range mpayload { for i, b := range mb {
v, ok := <-s.bindServer.readQueue buff := b.Bytes()
if !ok {
return nil
}
i, err := payload.Read(v.buff)
v.bytes = i if b.Len() > 3 {
v.endpoint = nep buff[1] = 0
v.err = err buff[2] = 0
v.waiter.Done() buff[3] = 0
if err != nil && goerrors.Is(err, io.EOF) { }
select {
case s.bindServer.readQueue <- &netReadInfo{
buff: buff,
endpoint: nep,
}:
case <-s.bindServer.closedCh:
nep.conn = nil nep.conn = nil
return nil buf.ReleaseMulti(mb[i:])
return errors.New("bind closed")
} }
} }
} }
@@ -138,9 +144,11 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
// Currently we have no way to link to the original source address // Currently we have no way to link to the original source address
inbound.Source = net.DestinationFromAddr(conn.RemoteAddr()) inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
ctx = session.ContextWithInbound(ctx, &inbound) ctx = session.ContextWithInbound(ctx, &inbound)
content := new(session.Content)
if s.info.contentTag != nil { if s.info.contentTag != nil {
ctx = session.ContextWithContent(ctx, s.info.contentTag) content.SniffingRequest = s.info.contentTag.SniffingRequest
} }
ctx = session.ContextWithContent(ctx, content)
ctx = session.SubContextFromMuxInbound(ctx) ctx = session.SubContextFromMuxInbound(ctx)
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{

View File

@@ -8,25 +8,8 @@ import (
"strings" "strings"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/log"
"golang.zx2c4.com/wireguard/device"
) )
var wgLogger = &device.Logger{
Verbosef: func(format string, args ...any) {
log.Record(&log.GeneralMessage{
Severity: log.Severity_Debug,
Content: fmt.Sprintf(format, args...),
})
},
Errorf: func(format string, args ...any) {
log.Record(&log.GeneralMessage{
Severity: log.Severity_Error,
Content: fmt.Sprintf(format, args...),
})
},
}
func init() { func init() {
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
deviceConfig := config.(*DeviceConfig) deviceConfig := config.(*DeviceConfig)