WireGuard inbound: Fix multi-peer; Fix potential routing issue (#5843)

Fixes https://github.com/XTLS/Xray-core/pull/5554

Fixes https://github.com/XTLS/Xray-core/issues/4760
This commit is contained in:
LjhAUMEM
2026-03-28 01:30:21 +08:00
committed by GitHub
parent 14524cc3b7
commit 8aacdbd71b
5 changed files with 68 additions and 81 deletions

View File

@@ -36,7 +36,7 @@ type serverityLogger struct {
func NewLogger(logWriterCreator WriterCreator) Handler {
return &generalLogger{
creator: logWriterCreator,
buffer: make(chan Message, 16),
buffer: make(chan Message, 128),
access: semaphore.New(1),
done: done.New(),
}
@@ -46,7 +46,7 @@ func ReplaceWithSeverityLogger(serverity Severity) {
w := CreateStdoutLogWriter()
g := &generalLogger{
creator: w,
buffer: make(chan Message, 16),
buffer: make(chan Message, 128),
access: semaphore.New(1),
done: done.New(),
}

View File

@@ -2,27 +2,23 @@ package wireguard
import (
"context"
"errors"
gonet "net"
"net/netip"
"runtime"
"strconv"
"sync"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/transport/internet"
)
type netReadInfo struct {
// status
waiter sync.WaitGroup
// param
buff []byte
// result
bytes int
buff []byte
endpoint conn.Endpoint
err error
}
// reduce duplicated code
@@ -32,6 +28,7 @@ type netBind struct {
workers int
readQueue chan *netReadInfo
closedCh chan struct{}
}
// SetMark implements conn.Bind
@@ -79,27 +76,23 @@ func (bind *netBind) BatchSize() int {
// Open implements conn.Bind
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
bind.readQueue = make(chan *netReadInfo)
bind.closedCh = make(chan struct{})
errors.LogDebug(context.Background(), "bind opened")
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
defer func() {
if r := recover(); r != nil {
n = 0
err = errors.New("channel closed")
}
}()
r, ok := <-bind.readQueue
if !ok {
return 0, errors.New("channel closed")
select {
case r := <-bind.readQueue:
sizes[0], eps[0] = copy(bufs[0], r.buff), r.endpoint
return 1, nil
case <-bind.closedCh:
errors.LogDebug(context.Background(), "recv func closed")
return 0, gonet.ErrClosed
}
copy(bufs[0], r.buff[:r.bytes])
sizes[0], eps[0] = r.bytes, r.endpoint
r.waiter.Done()
return 1, r.err
}
workers := bind.workers
if workers <= 0 {
workers = runtime.NumCPU()
}
if workers <= 0 {
workers = 1
}
@@ -113,8 +106,9 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
// Close implements conn.Bind
func (bind *netBind) Close() error {
if bind.readQueue != nil {
close(bind.readQueue)
errors.LogDebug(context.Background(), "bind closed")
if bind.closedCh != nil {
close(bind.closedCh)
}
return nil
}
@@ -134,35 +128,35 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
}
endpoint.conn = c
go func(readQueue chan<- *netReadInfo, endpoint *netEndpoint) {
defer func() {
_ = recover() // handle send on closed channel
}()
go func() {
for {
buff := make([]byte, 1700)
i, err := c.Read(buff)
buff := make([]byte, device.MaxMessageSize)
n, err := c.Read(buff)
if i > 3 {
if err != nil {
endpoint.conn = nil
c.Close()
return
}
if n > 3 {
buff[1] = 0
buff[2] = 0
buff[3] = 0
}
r := &netReadInfo{
buff: buff,
bytes: i,
select {
case bind.readQueue <- &netReadInfo{
buff: buff[:n],
endpoint: endpoint,
err: err,
}
r.waiter.Add(1)
readQueue <- r
r.waiter.Wait()
if err != nil {
}:
case <-bind.closedCh:
endpoint.conn = nil
c.Close()
return
}
}
}(bind.readQueue, endpoint)
}()
return nil
}
@@ -206,7 +200,8 @@ func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
}
if nend.conn == nil {
return errors.New("connection not open yet")
errors.LogDebug(context.Background(), nend.dst.NetAddr(), " send on closed peer")
return errors.New("peer closed")
}
for _, buff := range buff {

View File

@@ -121,7 +121,8 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer)
IPv4Enable: h.hasIPv4,
IPv6Enable: h.hasIPv6,
},
workers: int(h.conf.NumWorkers),
workers: int(h.conf.NumWorkers),
readQueue: make(chan *netReadInfo),
},
ctx: ctx,
dialer: dialer,

View File

@@ -2,8 +2,6 @@ package wireguard
import (
"context"
goerrors "errors"
"io"
"github.com/xtls/xray-core/common/buf"
c "github.com/xtls/xray-core/common/ctx"
@@ -51,6 +49,8 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
IPv4Enable: hasIPv4,
IPv6Enable: hasIPv6,
},
workers: int(conf.NumWorkers),
readQueue: make(chan *netReadInfo),
},
},
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
@@ -93,25 +93,31 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
reader := buf.NewPacketReader(conn)
for {
mpayload, err := reader.ReadMultiBuffer()
mb, err := reader.ReadMultiBuffer()
if err != nil {
nep.conn = nil
buf.ReleaseMulti(mb)
return err
}
for _, payload := range mpayload {
v, ok := <-s.bindServer.readQueue
if !ok {
return nil
}
i, err := payload.Read(v.buff)
for i, b := range mb {
buff := b.Bytes()
v.bytes = i
v.endpoint = nep
v.err = err
v.waiter.Done()
if err != nil && goerrors.Is(err, io.EOF) {
if b.Len() > 3 {
buff[1] = 0
buff[2] = 0
buff[3] = 0
}
select {
case s.bindServer.readQueue <- &netReadInfo{
buff: buff,
endpoint: nep,
}:
case <-s.bindServer.closedCh:
nep.conn = nil
return nil
buf.ReleaseMulti(mb[i:])
return errors.New("bind closed")
}
}
}
@@ -138,9 +144,11 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
// Currently we have no way to link to the original source address
inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
ctx = session.ContextWithInbound(ctx, &inbound)
content := new(session.Content)
if s.info.contentTag != nil {
ctx = session.ContextWithContent(ctx, s.info.contentTag)
content.SniffingRequest = s.info.contentTag.SniffingRequest
}
ctx = session.ContextWithContent(ctx, content)
ctx = session.SubContextFromMuxInbound(ctx)
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{

View File

@@ -8,25 +8,8 @@ import (
"strings"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/log"
"golang.zx2c4.com/wireguard/device"
)
var wgLogger = &device.Logger{
Verbosef: func(format string, args ...any) {
log.Record(&log.GeneralMessage{
Severity: log.Severity_Debug,
Content: fmt.Sprintf(format, args...),
})
},
Errorf: func(format string, args ...any) {
log.Record(&log.GeneralMessage{
Severity: log.Severity_Error,
Content: fmt.Sprintf(format, args...),
})
},
}
func init() {
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
deviceConfig := config.(*DeviceConfig)